Supervised Fine-Tuning of Gemma-2b-it with MLflow and DeepSpeed ZeRO-3


This guide explains the basics of performing supervised fine-tuning of Google’s Gemma model, using a dataset based on the Databricks Dolly 15K Q&A pairs.

Gemma is a lightweight, state-of-the-art open model from Google. Read more here.

Dolly is an open-source dataset of instruction-following records generated by thousands of Databricks employees, covering brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. Read more about the dataset here.
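
To get a feel for the data, you can load the ChatML-formatted variant of Dolly used later in this guide and print a single record. This is a quick sanity check only, assuming the philschmid/dolly-15k-oai-style dataset, which stores each conversation as a list of role/content messages:

from datasets import load_dataset

# Peek at one record of the OAI/ChatML-style Dolly 15k dataset
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
print(dataset)     # split size and column names
print(dataset[0])  # a single conversation as role/content messages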

Setup:

Launch a VSCode (or Jupyter) notebook server, specifying 4 NVIDIA GPUs.

Use the following commands to set up an environment:

mkdir my-project
cd my-project
pip install --upgrade uv

uv init --python 3.12
uv add ipykernel
uv sync
source .venv/bin/activate
uv run python -m ipykernel install --user --name=my-project-kernel

uv add torch tensorboard

uv add \
    "transformers" \
    "datasets" \
    "accelerate" \
    "evaluate" \
    "bitsandbytes" \
    "trl" \
    "peft"

uv add ninja packaging

# nvcc (the CUDA compiler) is needed to build flash-attn from source
conda update -y -n base -c conda-forge conda
conda install -y -c nvidia cuda-compiler

uv add flash-attn --no-build-isolation
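
Before writing any training code, it is worth confirming that the environment sees all four GPUs and that flash-attn built successfully. A minimal check, run from the activated virtual environment:

import torch

# CUDA should be available and all four GPUs visible
print(torch.cuda.is_available(), torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))

# flash-attn should import cleanly if the build succeeded
import flash_attn
print(flash_attn.__version__)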

Training Code:

gemma-sft-chatml.py

import os

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer
# Uncomment the following lines to log in to the Hugging Face Hub if your token is not already configured.
# from huggingface_hub import login
# login(token="YOUR_HUGGINGFACE_TOKEN", add_to_git_credential=True)
os.environ["MLFLOW_TRACKING_URI"] = "http://mlflow-server.mlflow.svc.cluster.local:80"
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map={"": Accelerator().local_process_index},
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=(
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    ),
)
# Gemma tokenizer with a ChatML chat template
tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")
tokenizer.padding_side = "right"  # to prevent warnings
args = TrainingArguments(
    output_dir="gemma-2b-dolly-chatml",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_checkpointing=False,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    bf16=True,
    tf32=True,
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to=["tensorboard", "mlflow"],
)
max_seq_length = 512
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=(
        LoraConfig(
            lora_alpha=8,
            lora_dropout=0.05,
            r=6,
            bias="none",
            target_modules="all-linear",
            task_type="CAUSAL_LM",
        )
    ),
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False,  # No need to add additional separator token
    },
)
trainer.train()
trainer.save_model()
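
Because report_to includes mlflow and MLFLOW_TRACKING_URI points at the in-cluster server, training metrics and parameters are logged to MLflow automatically. Once a run has been launched (see the training script below), a sketch like the following pulls recent runs back from the same tracking server; note that the transformers MLflow callback logs to the default (active) experiment unless MLFLOW_EXPERIMENT_NAME is set:

import mlflow

# Point the client at the same tracking server used during training
mlflow.set_tracking_uri("http://mlflow-server.mlflow.svc.cluster.local:80")

# List the most recent runs in the currently active (default) experiment
runs = mlflow.search_runs(order_by=["start_time DESC"])
print(runs[["run_id", "status", "start_time"]].head())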

DeepSpeed Config:

deepspeed_zero3.yaml

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Training Script:

train.sh

accelerate launch \
--config_file deepspeed_zero3.yaml \
--multi_gpu \
--num_machines 1 \
--num_processes 4 \
gemma-sft-chatml.py
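
After training finishes, trainer.save_model() writes the LoRA adapter to the gemma-2b-dolly-chatml output directory. The following is a minimal sketch of a quick generation test with the saved adapter, assuming a single GPU and that the base model can be re-downloaded from the Hub:

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load the base model plus the trained LoRA adapter from the output directory
model = AutoPeftModelForCausalLM.from_pretrained(
    "gemma-2b-dolly-chatml", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("philschmid/gemma-tokenizer-chatml")

# Build a ChatML prompt and generate a short answer
messages = [{"role": "user", "content": "Give me three ideas for a team offsite."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))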