How to design a RAG system? This is a question that I have been thinking about for a long time. In this article, I will share some of my thoughts on this topic. To design a system, we need to first clarify the scenario and the requirements.
ToB or ToC (business-facing or consumer-facing): For a ToB system, we mainly need to consider scalability and cost. For a ToC system, we mainly need to consider user experience and latency.
The requirements of the system: What exactly are the requirements? Accuracy, recall, latency, cost, and so on. We need to pin these down before we can design the system.
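Recall becomes a usable requirement once we decide how to measure it. The following is a minimal sketch. It assumes each test query has a hand-labeled set of relevant document IDs. `recall_at_k`, `mean_recall_at_k` and the sample data are hypothetical.

```python
def recall_at_k(retrieved: list[str], relevant: set[str], k: int) -> float:
    """Fraction of the relevant documents that appear in the top-k retrieved results."""
    if not relevant:
        raise ValueError("relevant set must not be empty")
    top_k = set(retrieved[:k])
    return len(top_k & relevant) / len(relevant)


def mean_recall_at_k(runs: list[tuple[list[str], set[str]]], k: int) -> float:
    """Average recall@k over (retrieved, relevant) pairs, one pair per query."""
    return sum(recall_at_k(r, rel, k) for r, rel in runs) / len(runs)


# Hypothetical evaluation data: retrieved IDs ranked best-first, plus labeled relevant IDs.
runs = [
    (["d3", "d7", "d1", "d9"], {"d1", "d2"}),  # one of two relevant docs in the top 3
    (["d4", "d2", "d8", "d5"], {"d4"}),        # the only relevant doc is ranked first
]
print(mean_recall_at_k(runs, k=3))  # 0.75
```

Latency and cost can be tracked the same way, as per-query numbers, so that every requirement has a concrete target.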
The data: What kind of data do we have, how much of it, and how is it distributed? How good is its quality? We need to understand the data before we can design the system.
Without high-quality data, we need a cold-start strategy, and the accuracy of the system may be low at the beginning. We also need a strategy for improving the quality of the data and increasing its amount over time.
Higher-quality data means data that comes from the real task scenario: the more closely it relates to the real task, the higher its quality.
Empirically, pre-training needs at least 10 tokens per parameter. For example, a 40B model needs at least 400B tokens of pre-training data. 400B tokens is around 300B English words, or on the order of 1 to 2 TB of raw text. Instruction tuning needs far less data than pre-training: public instruction-tuning datasets typically range from thousands to a few million examples. For a RAG system, we often only need to fine-tune a small reranking model, and around a thousand query-document pairs may be enough.
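The pre-training rule of thumb is plain arithmetic. The following is a rough sketch. It assumes about 0.75 English words and about 4 bytes of raw text per token. Both figures are typical for BPE tokenizers on English text, but they vary by tokenizer and language. `pretraining_data_budget` is a hypothetical helper.

```python
def pretraining_data_budget(n_params: float, tokens_per_param: float = 10.0,
                            words_per_token: float = 0.75,
                            bytes_per_token: float = 4.0) -> dict[str, float]:
    """Rough pre-training data size implied by a tokens-per-parameter rule of thumb.

    words_per_token and bytes_per_token are assumed averages for English BPE tokenizers.
    """
    tokens = n_params * tokens_per_param
    return {
        "tokens": tokens,
        "words": tokens * words_per_token,
        "bytes": tokens * bytes_per_token,
    }


budget = pretraining_data_budget(40e9)  # 40B-parameter model
print(f"{budget['tokens'] / 1e9:.0f}B tokens, "
      f"{budget['words'] / 1e9:.0f}B words, "
      f"{budget['bytes'] / 1e12:.1f} TB of text")
# 400B tokens, 300B words, 1.6 TB of text
```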
How to iterate: We need to design the system so that we can iterate quickly: test it, collect feedback, and improve it in short cycles.
Specify the project boundary: We need to clarify what is in and what is out of the project's scope. Is it an expert system for a specific domain, or a chatbot for entertainment? We choose the model based on this boundary.
Compute resources:
There are three main factors: cards (GPUs), data, and the model. From experience, the largest model we can load on one A100 80GB card is about 40B parameters in FP16. FP16 uses 2 bytes per parameter, so for 40B parameters we have $$40 \times 10^9 \times 2\ \text{bytes} = 80\ \text{GB}$$ This covers only the model parameters. In practice we also need memory for the KV cache, activations, framework overhead, etc., so 40B is an upper bound rather than a comfortable fit. Full training needs still more resources. For example, a 40B model needs at least 8 A100 80GB cards, because mixed-precision training with the Adam optimizer stores the following per parameter:

| Item | Bytes per parameter |
|---|---|
| Weights (FP16) | 2 |
| Gradients (FP16) | 2 |
| Master weights (FP32) | 4 |
| Adam m (first moment) | 4 |
| Adam v (second moment) | 4 |
| Total | 16 |

For a 40B model we therefore need $$40 \times 10^9 \times 16\ \text{bytes} = 640\ \text{GB}$$ for these states alone, before activations.

Inference time: Whether the system is online or offline, inference time is an important factor. For an online system, we need to consider latency. For an offline system, we need to consider throughput.
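The weight-only and full-training estimates can be wrapped in a small helper. This sketch counts only the per-parameter states from the breakdown above: weights, gradients, FP32 master weights, and the two Adam moments. As the text notes, it ignores KV cache, activations and framework overhead. `cards_needed` also assumes that the states can be sharded evenly across GPUs, as ZeRO-3/FSDP-style training does. Both helper names are hypothetical.

```python
import math

# Bytes per parameter, following the mixed-precision Adam breakdown.
INFERENCE_BYTES_PER_PARAM = 2                  # FP16 weights only
TRAINING_BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4   # weights, grads, master weights, Adam m, Adam v


def model_state_gb(n_params: float, bytes_per_param: int) -> float:
    """Memory for per-parameter states in GB (10**9 bytes); excludes activations and KV cache."""
    return n_params * bytes_per_param / 1e9


def cards_needed(n_params: float, bytes_per_param: int, card_gb: float = 80.0) -> int:
    """Lower bound on GPU count, assuming the states are sharded evenly across cards."""
    return math.ceil(model_state_gb(n_params, bytes_per_param) / card_gb)


n = 40e9  # 40B parameters
print(model_state_gb(n, INFERENCE_BYTES_PER_PARAM))  # 80.0
print(model_state_gb(n, TRAINING_BYTES_PER_PARAM))   # 640.0
print(cards_needed(n, TRAINING_BYTES_PER_PARAM))     # 8
```

The 80.0 GB result fills an 80 GB card exactly, with no headroom left. This is why 40B in FP16 is a ceiling for a single card rather than a practical target.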
System Rules: Explicit rules are essential for explainability and bad-case handling. For example, the key ideas and decisions in a meeting summary should follow rules such as “the meeting summary should include the key decisions and the key action items” and “the meeting summary should not include irrelevant information”.
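Rules like these are useful because they can be checked mechanically, and each failure points to a named rule. The following is a minimal sketch. It assumes the summary is plain text with headed sections. The section names, the off-topic keyword list and `check_summary` are all hypothetical. Keyword matching is only a crude stand-in for judging relevance.

```python
from dataclasses import dataclass


@dataclass
class Violation:
    rule: str     # which rule failed, for explainability
    detail: str   # what exactly was wrong, for bad-case analysis


# Hypothetical rule configuration for a meeting summary.
REQUIRED_SECTIONS = ["Key decisions", "Action items"]
OFF_TOPIC_MARKERS = ["lunch", "weekend plans"]


def check_summary(summary: str) -> list[Violation]:
    """Return every rule the summary breaks; an empty list means it passes."""
    violations = []
    lowered = summary.lower()
    for section in REQUIRED_SECTIONS:
        if section.lower() not in lowered:
            violations.append(Violation("must include key decisions and action items",
                                        f"missing section: {section}"))
    for marker in OFF_TOPIC_MARKERS:
        if marker in lowered:
            violations.append(Violation("must not include irrelevant information",
                                        f"off-topic content: {marker}"))
    return violations


summary = "Key decisions:\n- Ship v2 on Friday\nWe also discussed lunch options."
for v in check_summary(summary):
    print(v.rule, "->", v.detail)
```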
Technical choices
There are generally two types of capabilities:
Knowledge capability: The knowledge a model has about a certain domain. We often strengthen a model's knowledge capability with a RAG system.
Reasoning capability: The model's ability to reason, which is related to its architecture and training strategy. We often strengthen a model's reasoning capability through instruction tuning, RLHF, and prompt engineering techniques such as CoT, ToT, and GoT.
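A minimal retrieve-then-prompt sketch illustrates the knowledge side. It is a toy built on several assumptions. Relevance is scored by word-count cosine similarity instead of a learned embedding model. The corpus is an in-memory list. `build_prompt` only assembles the text that would be sent to an LLM, and no model is actually called. All function names are hypothetical.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    """Lowercase word tokens; a crude stand-in for a real tokenizer."""
    return re.findall(r"[a-z0-9]+", text.lower())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two bag-of-words count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(query: str, docs: list[str], k: int = 2) -> list[str]:
    """Return the k documents most similar to the query (stand-in for a vector search)."""
    q = Counter(tokenize(query))
    ranked = sorted(docs, key=lambda d: cosine(q, Counter(tokenize(d))), reverse=True)
    return ranked[:k]


def build_prompt(query: str, contexts: list[str]) -> str:
    """Put the retrieved passages into the prompt so the LLM answers from them."""
    context_block = "\n".join(f"[{i + 1}] {c}" for i, c in enumerate(contexts))
    return ("Answer using only the context below.\n\n"
            f"Context:\n{context_block}\n\n"
            f"Question: {query}\nAnswer:")


docs = [
    "The A100 80GB GPU has 80 GB of HBM memory.",
    "LoRA trains low-rank update matrices while freezing the base weights.",
    "Meeting summaries should list key decisions and action items.",
]
query = "How does LoRA train the weights?"
top = retrieve(query, docs, k=1)
print(build_prompt(query, top))  # this prompt would then be sent to the LLM
```

In a production system, a reranking model would typically sit between `retrieve` and `build_prompt`. This is the small model the data section suggests fine-tuning.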
CoT: Chain of Thought, a prompting technique that encourages the model to generate intermediate reasoning steps before giving the final answer. This helps the model solve complex problems that require multiple steps of reasoning.
ToT: Tree of Thoughts, a prompting technique that encourages the model to generate multiple reasoning paths, evaluate them, and select the best one. This helps the model explore different possibilities and avoid getting stuck in a local optimum. The key idea of ToT is to let the model backtrack and explore alternative reasoning paths, which makes it more likely to find a good solution. It is like a recursive CoT.
GoT: Graph of Thoughts, a prompting technique that encourages the model to build a graph of reasoning steps and then work over that graph to find the best answer. Unlike a tree, a graph also lets several reasoning paths be merged. This helps the model capture the dependencies between different reasoning steps and reach a better solution.
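ToT's search structure can be sketched without an LLM: propose several thoughts, score them, keep the most promising, and drop dead ends. In this toy, a "thought" is choosing the next number toward a target sum. The problem is assumed purely for illustration. In a real system, `propose` and `score` would be LLM calls. The sketch uses breadth-first search with a beam, which is one of the search strategies ToT uses. Keeping only the best `beam_width` partial paths is what prunes weak lines of reasoning.

```python
from typing import Optional

NUMBERS = [8, 5, 3, 2, 1]  # hypothetical toy problem: pick distinct numbers summing to TARGET
TARGET = 10


def propose(state: tuple[int, ...]) -> list[tuple[int, ...]]:
    """Generate candidate next thoughts (stand-in for an LLM proposal prompt)."""
    return [state + (n,) for n in NUMBERS if n not in state]


def score(state: tuple[int, ...]) -> float:
    """Rate a partial path (stand-in for an LLM value prompt); higher is better."""
    remaining = TARGET - sum(state)
    return float("-inf") if remaining < 0 else float(-remaining)


def tree_of_thoughts(beam_width: int = 2, max_depth: int = 4) -> Optional[tuple[int, ...]]:
    frontier: list[tuple[int, ...]] = [()]
    for _ in range(max_depth):
        candidates = [c for s in frontier for c in propose(s)]
        # Drop dead ends (paths that overshoot), so the search continues from the survivors.
        candidates = [c for c in candidates if score(c) != float("-inf")]
        for c in candidates:
            if sum(c) == TARGET:
                return c
        # Keep only the most promising partial paths for the next level.
        frontier = sorted(candidates, key=score, reverse=True)[:beam_width]
        if not frontier:
            return None
    return None


print(tree_of_thoughts())  # (8, 2)
```

A CoT prompt would correspond to following a single path, `beam_width=1` with no alternatives kept. GoT would additionally allow two partial paths to be merged into one node.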
Fine-tuning methods
Parameter-Efficient Fine-Tuning (PEFT) is a family of methods that fine-tune a large language model (LLM) by training only a small fraction of its parameters, so they need much less compute and memory than full fine-tuning. PEFT methods include LoRA, prefix-tuning, adapter-tuning, P-tuning, and Freeze (freezing most layers and training only a few). At the time of writing, LoRA is the most popular PEFT method. It freezes the original weights, adds trainable low-rank matrices alongside them, and trains only those matrices. This greatly reduces the number of trainable parameters and the cost of fine-tuning while still achieving good performance on downstream tasks. The trained LoRA weights are also small and easy to share, and different LoRA adapters can be swapped onto the same base model they were trained for.
Empirically, when using LoRA, the more specialized the domain, the larger the LoRA rank r should be.
LoRA (Low-Rank Adaptation)
Instead of updating the full weight matrix W, LoRA freezes the pretrained weights and learns a low-rank update.
Original Layer
$$ y = Wx $$
LoRA Modification
$$ y = (W + \Delta W)x $$
Low-Rank Decomposition
$$ \Delta W = BA $$
Where $$W \in \mathbb{R}^{d \times d}$$ is the frozen pretrained matrix, $$A \in \mathbb{R}^{r \times d}$$, $$B \in \mathbb{R}^{d \times r}$$, and $$r \ll d$$.
Parameter Reduction
Instead of training all $$d^2$$ parameters of the full matrix, LoRA trains only $$2dr$$ parameters.
This greatly reduces memory and training cost while largely preserving model performance.
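The equations above map directly onto a few lines of NumPy. This is a sketch of a single LoRA-adapted linear layer, not any library's implementation. Following common practice, B is initialized to zero, so ΔW = BA = 0 and the adapted layer starts out identical to the pretrained one. Libraries such as Hugging Face PEFT additionally scale BA by alpha / r. That scaling is omitted here to match the formulas above. The class name `LoRALinear` is hypothetical.

```python
import numpy as np


class LoRALinear:
    """y = (W + BA)x with W frozen and only A and B trainable."""

    def __init__(self, W: np.ndarray, r: int, seed: int = 0):
        d_out, d_in = W.shape
        rng = np.random.default_rng(seed)
        self.W = W                                      # frozen pretrained weights, (d_out, d_in)
        self.A = rng.normal(0.0, 0.01, size=(r, d_in))  # trainable, (r, d_in)
        self.B = np.zeros((d_out, r))                   # trainable, zero-init so BA = 0 at start

    def forward(self, x: np.ndarray) -> np.ndarray:
        # Compute B(Ax) rather than materializing the full d_out x d_in matrix BA.
        return self.W @ x + self.B @ (self.A @ x)

    def trainable_params(self) -> int:
        return self.A.size + self.B.size


d, r = 1024, 8
layer = LoRALinear(np.random.default_rng(1).normal(size=(d, d)), r=r)
x = np.ones(d)
print(np.allclose(layer.forward(x), layer.W @ x))  # True: the update starts at zero
print(d * d, layer.trainable_params())             # 1048576 16384
```

With d = 1024 and r = 8, the layer trains 16,384 parameters instead of 1,048,576, matching the formula $$2dr$$.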
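Here is an example of fine-tuning an LLM with LoRA on SageMaker:

```python
import sagemaker
from sagemaker.huggingface import HuggingFace

# Works inside SageMaker (notebook instance, Studio); elsewhere, pass an IAM role ARN instead.
role = sagemaker.get_execution_role()

# Passed to train.py as command-line arguments; train.py itself must implement the LoRA training.
hyperparameters = {
    "model_name": "meta-llama/Llama-2-7b-hf",  # gated model: requires approved Hugging Face access
    "epochs": 3,
    "per_device_train_batch_size": 4,
    "learning_rate": 2e-4,
    "lora_rank": 8,
}

huggingface_estimator = HuggingFace(
    entry_point="train.py",         # training script, located in source_dir
    source_dir="./",
    instance_type="ml.g5.2xlarge",  # 1x NVIDIA A10G GPU (24 GB)
    instance_count=1,
    role=role,
    transformers_version="4.36",
    pytorch_version="2.1",
    py_version="py310",
    hyperparameters=hyperparameters,
)

# The "training" channel is made available inside the container via SM_CHANNEL_TRAINING.
huggingface_estimator.fit({"training": "s3://my-bucket/dataset/"})
```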
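The estimator launches `train.py`, but the script itself is not shown. SageMaker passes each entry in `hyperparameters` to the entry point as a `--name value` command-line argument. It also exposes the local path of the training channel in the `SM_CHANNEL_TRAINING` environment variable, and the output directory in `SM_MODEL_DIR`. The following is a minimal sketch of how a hypothetical `train.py` could read these. Model loading and the LoRA training loop are left out. The local fallback paths are assumptions for testing outside SageMaker.

```python
import argparse
import os


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Names must match the keys of the hyperparameters dict passed to the estimator.
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lora_rank", type=int, default=8)
    # SageMaker sets these inside the training container; the defaults are for local runs.
    parser.add_argument("--train_dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAINING", "./data"))
    parser.add_argument("--model_dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "./model"))
    return parser.parse_args(argv)


# Inside train.py, call parse_args() with no arguments; here the CLI is simulated.
args = parse_args(["--model_name", "meta-llama/Llama-2-7b-hf", "--epochs", "3", "--lora_rank", "8"])
print(f"Fine-tuning {args.model_name} with LoRA rank {args.lora_rank} on data in {args.train_dir}")
# Next steps (not shown): load data from args.train_dir, train, save adapters to args.model_dir.
```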
When fine-tuning, we keep domain-specific data at about 20% of the total training data and fill the rest with general data. Empirically, this achieves good performance and helps the model avoid forgetting its general knowledge.
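Holding domain data at about 20% of the mix means pairing every domain example with four general ones. The following is a minimal sketch. It assumes both datasets fit in memory as lists, and that the general pool is large enough to subsample. `mix_datasets` is a hypothetical helper.

```python
import random


def mix_datasets(domain: list, general: list, domain_ratio: float = 0.2, seed: int = 0) -> list:
    """Combine all domain examples with enough sampled general examples that domain data
    makes up `domain_ratio` of the result, then shuffle."""
    if not 0 < domain_ratio <= 1:
        raise ValueError("domain_ratio must be in (0, 1]")
    n_general = round(len(domain) * (1 - domain_ratio) / domain_ratio)
    if n_general > len(general):
        raise ValueError(f"need {n_general} general examples, only {len(general)} available")
    rng = random.Random(seed)
    mixed = list(domain) + rng.sample(general, n_general)
    rng.shuffle(mixed)
    return mixed


domain = [f"domain-{i}" for i in range(200)]
general = [f"general-{i}" for i in range(5000)]
mixed = mix_datasets(domain, general)
print(len(mixed), sum(x.startswith("domain") for x in mixed))  # 1000 200
```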
When fine-tuning, we choose the base model when we have a large amount of high-quality data. If we only have a small dataset, for example around 5k examples, we can start from the chat model, which is already instruction-tuned.