This guide explains the basics of performing supervised fine-tuning of Google’s Gemma model, using a dataset based on the Databricks Dolly 15K Q&A pairs.
Gemma is a lightweight, state-of-the-art open model from Google. Read more here.
Dolly is an open-source dataset of instruction-following records generated by thousands of Databricks employees, covering brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. Read more about the dataset here.
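The training script below uses philschmid/dolly-15k-oai-style, a version of Dolly 15K reformatted into OpenAI-style chat messages. As a quick, informal way to see what a training record looks like (the exact field layout is worth confirming yourself), you can print one example:

from datasets import load_dataset

dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
# Each record is expected to hold a "messages" list of {"role", "content"} turns,
# which is what the ChatML tokenizer and SFTTrainer below consume.
print(dataset[0])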
Setup:
Launch a VSCode (or Jupyter) notebook server, and specify 4 NVIDIA GPUs:
Use the following commands to set up the environment:
mkdir my-project
cd my-project
pip install --upgrade uv
uv init --python 3.12
uv add ipykernel
uv sync
source .venv/bin/activate
uv run python -m ipykernel install --user --name=my-project-kernel
uv add torch tensorboard
uv add \
"transformers" \
"datasets" \
"accelerate" \
"evaluate" \
"bitsandbytes" \
"trl" \
"peft"
uv add ninja packaging
conda update -y -n base -c conda-forge conda
conda install -y -c nvidia cuda-compiler
uv add flash-attn --no-build-isolation
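Before moving on, it is worth confirming that PyTorch can see all four GPUs and that the flash-attn build succeeded. A minimal sanity check, run inside the activated environment:

import torch
import flash_attn  # this import fails if the flash-attn build did not succeed

print("torch:", torch.__version__, "CUDA:", torch.version.cuda)
print("GPUs visible:", torch.cuda.device_count())
print("bf16 supported:", torch.cuda.is_bf16_supported())
print("flash-attn:", flash_attn.__version__)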
Training Code:
gemma-sft-chatml.py
import os
from transformers import TrainingArguments
from trl import SFTTrainer
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from accelerate import Accelerator
# Log in to the Hugging Face Hub. Uncomment the next two lines if your Hugging Face token is not already configured on this machine.
# from huggingface_hub import login
# login(token="YOUR_HUGGINGFACE_TOKEN", add_to_git_credential=True)
os.environ["MLFLOW_TRACKING_URI"] = "http://mlflow-server.mlflow.svc.cluster.local:80"
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map={"": Accelerator().local_process_index},
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")
tokenizer.padding_side = "right" # to prevent warnings
args = TrainingArguments(
    output_dir="gemma-2b-dolly-chatml",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_checkpointing=False,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    tf32=True,
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to=["tensorboard", "mlflow"],
)
max_seq_length = 512
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
        r=6,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
    ),
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False,  # No need to add additional separator token
    },
)
trainer.train()
trainer.save_model()
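After training, trainer.save_model() writes the LoRA adapter to the output directory. A minimal sketch for trying the fine-tuned model afterwards, assuming the output directory above and that the ChatML tokenizer provides a chat template (the prompt is illustrative):

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load the bf16 base model with the trained LoRA adapter applied on top.
model = AutoPeftModelForCausalLM.from_pretrained(
    "gemma-2b-dolly-chatml",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")

# Format a prompt with the same ChatML template used during training.
messages = [{"role": "user", "content": "What is Gemma?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))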
Deepspeed Config:
deepspeed_zero3.yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Training Script:
train.sh
accelerate launch \
  --config_file deepspeed_zero3.yaml \
  --multi_gpu \
  --num_machines 1 \
  --num_processes 4 \
  gemma-sft-chatml.py