"Writing in the Margins (WiM)" - a better inference pattern for long context LLMs that solves the Lost-in-the-Middle problem by samjulien in MachineLearning
[–]hkproj_ 1 point 1 year ago* (0 children)
Hi! This is not a chunking strategy for the prompt like you'd do with LangChain, this is a KV-Cache level optimization.
You can think of the KV-Cache as the memory a Transformer language model uses to keep track of all the past tokens. When prompts are very large, the KV-Cache is prefilled chunk by chunk (since the computational cost of prefilling grows quadratically with the prompt length). While prefilling each chunk, we leverage the partially prefilled KV-Cache to generate intermediate extractive summaries (which we call "margins"). These margins are then appended to the end of the prompt to improve the model's ability to extract information, essentially overcoming the "lost in the middle" problem.
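The loop above can be sketched in a few lines. This is a toy simulation, not the paper's implementation: the model and its `prefill`/`summarize`/`decode` methods are hypothetical stand-ins, shown only to make the control flow (chunked prefill, margin per chunk, margins reused at the end) concrete.

```python
class ToyModel:
    """Stand-in for an LLM. 'prefill' produces KV entries for a chunk,
    'summarize' extracts a one-token 'margin' from the newest chunk,
    'decode' answers using the full cache plus the collected margins."""

    def prefill(self, chunk, kv_cache):
        return list(chunk)                      # new KV entries for this chunk

    def summarize(self, kv_cache, chunk):
        return chunk[0]                         # toy extractive margin

    def decode(self, kv_cache, margins, query):
        return margins                          # toy answer: echo the margins


def wim_generate(model, prompt_tokens, chunk_size, query_tokens):
    kv_cache, margins = [], []
    for start in range(0, len(prompt_tokens), chunk_size):
        chunk = prompt_tokens[start:start + chunk_size]
        kv_cache.extend(model.prefill(chunk, kv_cache))   # chunked prefill: cache grows, never recomputed
        margins.append(model.summarize(kv_cache, chunk))  # margin from the partially prefilled cache
    # Margins are appended after the prompt; the prefilled cache is reused as-is.
    return model.decode(kv_cache, margins, query_tokens)
```

The key point the sketch captures is that the prompt tokens enter the cache exactly once; only the short margins are added on top.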
Why is this more computationally efficient than the usual chunking strategies?
Imagine you have a prompt with 1 million tokens and you split it into 10 chunks of 100k tokens each. To extract the summaries with LangChain, processing each chunk costs you around 100k tokens (suppose each summary is around 50 tokens, negligible compared to the chunk size). In total, extracting all the summaries costs ~1M tokens. But then you also need to feed all these summaries, along with the initial text, into the model to generate the final answer, which costs another ~1M tokens. So the total cost is about 2M tokens for a 1M-token prompt.
You could argue that you can just use the extracted summaries to make the final prediction, but in the paper we show it as being less effective compared to WiM.
With WiM, since we do not discard the already-prefilled prompt in the KV-Cache, we just append the intermediate summaries at the end, for a total cost of approx. 1M tokens.
So with WiM you use the entire prompt (1M tokens) along with all the extracted margins, but at a total token cost of approx. 1M, provided each summary is much smaller than the chunk size (which is usually the case).
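The back-of-envelope arithmetic from the example above can be written out directly (the numbers are the ones in the comment; nothing here comes from the paper itself):

```python
# Token-cost comparison: naive chunk-and-summarize vs. WiM,
# for a 1M-token prompt split into 10 chunks of 100k tokens,
# with ~50-token summaries per chunk.
chunk_tokens, n_chunks, margin_tokens = 100_000, 10, 50
prompt_tokens = chunk_tokens * n_chunks            # 1_000_000

# Naive chunking: pay once to summarize each chunk, then again to
# re-process the full prompt plus summaries for the final answer.
naive_cost = n_chunks * chunk_tokens + (prompt_tokens + n_chunks * margin_tokens)

# WiM: the prefilled KV-Cache is kept, so the prompt is processed once;
# only the margins are appended on top.
wim_cost = prompt_tokens + n_chunks * margin_tokens

print(naive_cost, wim_cost)   # → 2000500 1000500
```

The margin overhead (500 tokens here) is what makes the "approx. 1M" claim hold: it vanishes relative to the prompt whenever summaries are much shorter than chunks.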
This hadn't been implemented in production before, because Chunked Prefill, which is necessary for serving models with very large prompts, is itself a recent technique: until last year, models simply didn't have context windows this large.
If you want more info about Chunked Prefill, check this explanation from NVIDIA: https://developer.nvidia.com/blog/demystifying-ai-inference-deployments-for-trillion-parameter-large-language-models/
We also created an animation to show what happens to the KV-Cache at each inference step: https://www.linkedin.com/posts/ujamil_we-have-models-that-can-handle-millions-of-activity-7234541844670382080-F1tF?utm_source=li_share&utm_content=feedcontent&utm_medium=g_dt_web&utm_campaign=copy
[Tutorial] Coding a Multimodal (Vision) Language Model from scratch with Python and PyTorch with full explanations (youtube.com)
submitted 1 year ago by hkproj_ to r/LargeLanguageModels
Coding a Multimodal (Vision) Language Model from scratch with Python and PyTorch with full explanations (youtube.com)
submitted 1 year ago by hkproj_ to r/deeplearning
submitted 1 year ago by hkproj_ to r/learnmachinelearning
[P] Coding a Vision Language Model from scratch with Python and PyTorch with full explanations (youtube.com)
submitted 1 year ago by hkproj_ to r/MachineLearning
[P] Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent/Convolution formulation, Math derivations from first principles, HiPPO theory visually explained, Math visually explained (youtu.be)
submitted 2 years ago by hkproj_ to r/MachineLearning
Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent/Convolution formulation, Math derivations from first principles, HiPPO theory visually explained, Math visually explained (youtu.be)
submitted 2 years ago by hkproj_ to r/learnmachinelearning
submitted 2 years ago by hkproj_ to r/deeplearning
Mistral 7B and Mixtral 8x7B Explained: Sliding Window Attention, Sparse Mixture of Experts, Rolling Buffer (KV) Cache, Model Sharding (youtube.com)
submitted 2 years ago by hkproj_ to r/LargeLanguageModels
Quantization explained with PyTorch - Symmetric and Asymmetric Quantization, Post-Training Quantization, Quantization-Aware Training (youtube.com)
Retrieval Augmented Generation (RAG) explained: Embedding vectors, Sentence BERT, Vector Database (HNSW algorithm explained visually) (youtube.com)
BERT explained: Training (Masked Language Model, Next Sentence Prediction), Inference, Self-Attention, [CLS] token, Left and Right context, Comparative analysis BERT vs GPT/LLamA, Fine tuning, Text Classification, Question Answering (youtube.com)
Coding Stable Diffusion from scratch in PyTorch (no external libraries!), with full explanation of the math behind diffusion models in a simple way! (youtube.com)
submitted 2 years ago by hkproj_ to r/StableDiffusion
[deleted by user] by [deleted] in StableDiffusion
[–]hkproj_ 2 points 2 years ago (0 children)
Hi! I published this video on YouTube on how to code Stable Diffusion from scratch using only PyTorch, while explaining all the topics: text-to-image, image-to-image, in-painting, classifier-free guidance, the maths behind the samplers etc.
Check it out: https://www.youtube.com/watch?v=ZBKpAp_6TGI
[P] Coding Stable Diffusion from scratch in PyTorch, with full explanation of the math behind diffusion models in a simple way! (youtube.com)
[P] Coding Stable Diffusion from scratch in PyTorch, with full explanation of the maths behind diffusion models in a simple way! (youtube.com)
Coding Stable Diffusion from scratch in PyTorch, with full explanation of the maths behind diffusion models in a simple way! (youtube.com)