Axolotl is a library that simplifies fine-tuning LLMs through a configuration-first approach. A single YAML file customizes the training by defining parameters for:

  • Hugging Face models: llama, pythia, falcon, mpt
  • Hugging Face datasets and formats
  • Single- or multi-GPU training via FSDP or DeepSpeed
  • Fine-tuning technique: fullfinetune, lora, qlora, relora, and gptq
  • WandB configuration to log metrics, results, and checkpoints

This tip shows how to fine-tune Llama 2 using axolotl. We will fine-tune Llama 2 on the knowrohit07/know_sql dataset, which is a text-to-SQL dataset.

Below are a few examples from this dataset. Each one contains a context, which defines one or more SQL tables, a question about those tables, and the ground-truth SQL answer.

Example 1
  answer:   SELECT COUNT(district) FROM table_1341586_19 WHERE incumbent = "Lindy Boggs"
  question: how many district with incumbent being lindy boggs
  context:  CREATE TABLE table_1341586_19 (district VARCHAR, incumbent VARCHAR)

Example 2
  answer:   SELECT result FROM table_1341586_19 WHERE candidates = "Billy Tauzin (D) Unopposed"
  question: what’s the result with candidates being billy tauzin (d) unopposed
  context:  CREATE TABLE table_1341586_19 (result VARCHAR, candidates VARCHAR)
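
To make the three fields concrete, the sketch below runs the first example's answer against its context using Python's built-in sqlite3 module. The two inserted rows are made up for illustration and are not part of the dataset, and the string literal is written with single quotes, which is the portable SQL form.

```python
import sqlite3

# Fields of the first example (string literal rewritten with single quotes).
context = "CREATE TABLE table_1341586_19 (district VARCHAR, incumbent VARCHAR)"
answer = "SELECT COUNT(district) FROM table_1341586_19 WHERE incumbent = 'Lindy Boggs'"

conn = sqlite3.connect(":memory:")
conn.execute(context)  # build the table described by the context

# Hypothetical rows, invented only so the query has something to count.
conn.executemany(
    "INSERT INTO table_1341586_19 VALUES (?, ?)",
    [("Louisiana 2", "Lindy Boggs"), ("Louisiana 3", "Billy Tauzin")],
)

count = conn.execute(answer).fetchone()[0]
print(count)  # one of the two made-up rows matches the incumbent
```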

Note: you can choose any other dataset from Hugging Face.

Before starting the fine-tuning, first check that a GPU is available and that bf16 mode is supported:

import torch

print('GPU available?', torch.cuda.is_available())
print('BF16 is supported?', torch.cuda.is_bf16_supported())

Second, install the dependencies, including axolotl and peft:

git clone -b main --depth 1 https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl

pip install packaging
pip install -e '.[flash-attn,deepspeed]'
pip install -U git+https://github.com/huggingface/peft.git

axolotl uses a YAML file to configure the fine-tuning. You can see what such files look like by visiting the examples folder.

In our case, we will use the llama-2/qlora.yml example and apply the following git patch, which sets the base model to meta-llama/Llama-2-7b-hf and changes the dataset path and type:

diff --git a/content/axolotl/examples/llama-2/qlora.yml b/sql.yml
index 5425532..eb395cb 100644
--- a/content/axolotl/examples/llama-2/qlora.yml
+++ b/sql.yml
@@ -1,2 +1,2 @@
-base_model: NousResearch/Llama-2-7b-hf
-base_model_config: NousResearch/Llama-2-7b-hf
+base_model: meta-llama/Llama-2-7b-hf
+base_model_config: meta-llama/Llama-2-7b-hf
@@ -12,2 +12,2 @@ datasets:
-  - path: mhenrichsen/alpaca_2k_test
-    type: alpaca
+  - path: knowrohit07/know_sql
+    type: context_qa2
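
If you would rather not apply a patch, the same changes could be scripted. The sketch below is one possible alternative and not part of axolotl: make_sql_config is a hypothetical helper that does plain string replacement on the example file's text, and the sample string is a trimmed stand-in for the real examples/llama-2/qlora.yml.

```python
# Old -> new values, taken from the patch above.
OVERRIDES = {
    "NousResearch/Llama-2-7b-hf": "meta-llama/Llama-2-7b-hf",
    "path: mhenrichsen/alpaca_2k_test": "path: knowrohit07/know_sql",
    "type: alpaca": "type: context_qa2",
}


def make_sql_config(text: str) -> str:
    """Return the config text with the base model and dataset swapped."""
    for old, new in OVERRIDES.items():
        if old not in text:
            # Fail loudly if the example file no longer matches the patch.
            raise ValueError(f"expected to find {old!r} in the config")
        text = text.replace(old, new)
    return text


# Trimmed stand-in for the example file (assumption: only the patched lines).
sample = """base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
"""

sql_config = make_sql_config(sample)
print(sql_config)
```

On the real file, you would read examples/llama-2/qlora.yml, pass its text through the helper, and write the result to sql.yml.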

Now we can start the fine-tuning (depending on your system, this may take around 1 hour):

accelerate launch -m axolotl.cli.train sql.yml

Once axolotl is done, the QLoRA weights of the model will be available in the qlora-out folder. We need to apply those weights to the original Llama 2 model:

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

qlora_model = './qlora-out'  # adapter weights written by axolotl
base_model = 'meta-llama/Llama-2-7b-hf'
tokr = AutoTokenizer.from_pretrained(base_model)
# Load the base model in bf16 on GPU 0
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16, device_map=0)
# Attach the QLoRA adapter, then fold its weights into the base model
model = PeftModel.from_pretrained(model, qlora_model)
model = model.merge_and_unload()

Then save the new weights so we can reuse them in the future.
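
One way to do this, as a minimal sketch: it assumes the standard save_pretrained method that Hugging Face models and tokenizers both provide. The save_merged helper and the ./sql-model directory are hypothetical names chosen for illustration.

```python
def save_merged(model, tokenizer, out_dir="./sql-model"):
    """Write the merged weights and the tokenizer files to out_dir."""
    # save_pretrained is the standard Hugging Face method on both objects.
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    return out_dir

# With the objects from the previous step this would be:
# save_merged(model, tokr)
```

Afterwards, AutoModelForCausalLM.from_pretrained('./sql-model') and AutoTokenizer.from_pretrained('./sql-model') should load the merged model without going through peft again.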


Now we are ready to test our fine-tuned text-to-SQL model. Let’s define a helper function that formats the prompt we will pass to our model, using the same format as the fine-tuning dataset:

fmt = """SYSTEM: Use the following contextual information to concisely answer the question.

USER: {}
{}
"""

def sql_prompt(context, question):
    return fmt.format(context, question)

Let’s verify that our helper function works as expected with the following example:

context = 'CREATE TABLE farm_competition (Hosts VARCHAR, Theme VARCHAR)'
question = 'Get the count of competition hosts by theme.'
print(sql_prompt(context, question))

The above snippet generates the following output:

SYSTEM: Use the following contextual information to concisely answer the question.

USER: CREATE TABLE farm_competition (Hosts VARCHAR, Theme VARCHAR)
Get the count of competition hosts by theme.

Finally, we can pass this prompt to our fine-tuned model:

toks = tokr(sql_prompt(context, question), return_tensors='pt')
res = model.generate(**toks.to('cuda'), max_new_tokens=250).to('cpu')
# Decode the generated token ids back into text
print(tokr.batch_decode(res)[0])

This outputs something like the following:

<s> SYSTEM: Use the following contextual information to concisely answer the question.

USER: CREATE TABLE farm_competition (Hosts VARCHAR, Theme VARCHAR)
Get the count of competition hosts by theme.
ASSISTANT: SELECT COUNT(Hosts), Theme FROM farm_competition GROUP BY Theme</s>

Note: <s> and </s> are special tokens that denote the start and end of the text respectively. The tokenizer prepends <s> to the prompt, and the model emits </s> when it finishes generating.
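
Because the decoded text repeats the prompt, you may want only the generated SQL. Here is a minimal sketch, assuming the output follows the layout shown above; extract_sql is a hypothetical helper and not part of transformers or axolotl.

```python
def extract_sql(decoded: str) -> str:
    """Return the text after the last 'ASSISTANT:' marker, without <s> or </s>.

    If the marker is absent, the whole text is returned (minus special tokens).
    """
    # Keep only what follows the final marker; the prompt comes before it.
    _, _, tail = decoded.rpartition("ASSISTANT:")
    # Drop the special tokens, then the surrounding whitespace.
    for token in ("<s>", "</s>"):
        tail = tail.replace(token, "")
    return tail.strip()


# The sample output from above, reproduced as a string.
decoded = (
    "<s> SYSTEM: Use the following contextual information to concisely answer the question.\n\n"
    "USER: CREATE TABLE farm_competition (Hosts VARCHAR, Theme VARCHAR)\n"
    "Get the count of competition hosts by theme.\n"
    "ASSISTANT: SELECT COUNT(Hosts), Theme FROM farm_competition GROUP BY Theme</s>"
)
sql = extract_sql(decoded)
print(sql)
```

Alternatively, passing skip_special_tokens=True to the tokenizer's decode or batch_decode call drops <s> and </s> at decode time, leaving only the split on the marker.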