Supervised Fine-Tuning (SFT)

Supervised Fine-Tuning (SFT) is typically the first stage of LLM post-training. The model learns from curated instruction-response pairs using standard cross-entropy loss. AlphaApollo's SFT trainer is built on PyTorch FSDP and supports sequence parallelism, LoRA, and multi-turn conversation data.

Overview

The SFT pipeline:

  1. Loads a pretrained model and instruction-response dataset
  2. Fine-tunes the model using standard cross-entropy loss
  3. Supports both single-turn and multi-turn conversation formats
  4. Outputs a checkpoint compatible with the RL training pipeline
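
Conceptually, step 2 masks the prompt so that cross-entropy is computed only on response tokens. A minimal sketch of this loss masking (hypothetical helper, not the trainer's actual code):

```python
# Sketch of SFT loss masking: cross-entropy is computed only on response
# tokens; prompt positions are masked out with an ignore index.
IGNORE_INDEX = -100  # common convention for positions excluded from the loss

def build_labels(prompt_ids, response_ids):
    """Concatenate prompt and response; mask prompt positions in the labels."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

input_ids, labels = build_labels([1, 2, 3], [4, 5])
```

The masked labels are what a standard cross-entropy loss (with `ignore_index=-100`) consumes, so the model is never penalized for "predicting" the prompt.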

Entry point:

torchrun --standalone --nnodes=1 --nproc_per_node=<N_GPUS> \
    -m verl.trainer.fsdp_sft_trainer \
    model.partial_pretrain=<MODEL_PATH> \
    data.train_files=<TRAIN_DATA> \
    ...

Config Reference

The SFT config is defined in verl/trainer/config/sft_trainer.yaml.

Data

data:
  train_batch_size: 256
  micro_batch_size_per_gpu: 4
  train_files: ~/data/gsm8k/train.parquet
  val_files: ~/data/gsm8k/test.parquet
  # Single-turn settings
  prompt_key: question
  response_key: answer
  prompt_dict_keys: ['question']
  response_dict_keys: ['answer']
  # Multi-turn settings
  multiturn:
    enable: false
    messages_key: messages
  max_length: 1024
  truncation: error
  balance_dp_token: False
  chat_template: null
  custom_cls:
    path: null
    name: null
| Parameter | Type | Default | Description |
|---|---|---|---|
| train_batch_size | int | 256 | Global training batch size. |
| micro_batch_size_per_gpu | int | 4 | Per-GPU micro-batch size for the forward/backward pass; gradients are accumulated to reach the global batch size. |
| train_files | str | ~/data/gsm8k/train.parquet | Path to training data (parquet format). |
| val_files | str | ~/data/gsm8k/test.parquet | Path to validation data (parquet format). |
| prompt_key | str | question | Column name for prompts. |
| response_key | str | answer | Column name for responses. |
| prompt_dict_keys | list | ['question'] | Keys to extract from the prompt dict if the column contains dicts. |
| response_dict_keys | list | ['answer'] | Keys to extract from the response dict. |
| max_length | int | 1024 | Maximum sequence length (prompt + response). |
| truncation | str | error | Truncation strategy: error, left, or right. |
| balance_dp_token | bool | False | Balance token counts across data-parallel ranks. |
| chat_template | str | null | Custom chat template; null uses the model's default. |
| multiturn.enable | bool | false | Enable the multi-turn conversation format. |
| multiturn.messages_key | str | messages | Column name for the multi-turn messages list. |
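
When prompt_key points at a dict-valued column, prompt_dict_keys selects the nested field. A sketch of how that extraction behaves (illustrative helper, not verl's implementation):

```python
def extract_field(row, column, dict_keys):
    """Walk into a (possibly nested) dict column using the configured keys."""
    value = row[column]
    for key in dict_keys:
        value = value[key]
    return value

# With prompt_key=extra_info and prompt_dict_keys=['question'], a row whose
# extra_info column holds a dict yields the nested question string.
row = {"extra_info": {"question": "What is 2 + 3?", "answer": "2 + 3 = 5"}}
prompt = extract_field(row, "extra_info", ["question"])
answer = extract_field(row, "extra_info", ["answer"])
```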

Model

model:
  partial_pretrain: ~/models/gemma-1.1-7b-it
  strategy: fsdp2
  fsdp_config:
    wrap_policy:
      min_num_params: 0
    cpu_offload: False
    offload_params: False
  external_lib: null
  enable_gradient_checkpointing: False
  trust_remote_code: False
  lora_rank: 0
  lora_alpha: 16
  target_modules: all-linear
  use_liger: False
| Parameter | Type | Default | Description |
|---|---|---|---|
| partial_pretrain | str | ~/models/gemma-1.1-7b-it | Path to the pretrained model (HuggingFace format). |
| strategy | str | fsdp2 | FSDP strategy: fsdp or fsdp2. |
| fsdp_config.cpu_offload | bool | False | Enable CPU offloading for FSDP. |
| enable_gradient_checkpointing | bool | False | Enable gradient checkpointing to reduce memory. |
| trust_remote_code | bool | False | Allow loading models that ship remote code. |
| lora_rank | int | 0 | LoRA rank; set > 0 to enable LoRA fine-tuning. |
| lora_alpha | int | 16 | LoRA scaling factor. |
| target_modules | str / list | all-linear | LoRA target modules. |
| use_liger | bool | False | Use the Liger kernel for memory-efficient computation. |
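
With lora_rank > 0, each targeted weight is augmented by a trainable low-rank update scaled by lora_alpha / lora_rank. This is the standard LoRA formulation; verl's actual implementation details may differ. A tiny numeric sketch:

```python
# Sketch of the LoRA forward pass (standard formulation, not verl's code):
#   h = W x + (lora_alpha / lora_rank) * B @ A @ x
# where A is (r x d_in), B is (d_out x r). Only A and B are trained; W is frozen.
def matvec(M, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, lora_alpha, lora_rank):
    base = matvec(W, x)                  # frozen pretrained path
    low_rank = matvec(B, matvec(A, x))   # trainable low-rank path
    scale = lora_alpha / lora_rank
    return [b + scale * u for b, u in zip(base, low_rank)]

# rank-1 example: d_in = d_out = 2, r = 1
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]           # r x d_in
B = [[0.5], [0.5]]         # d_out x r
h = lora_forward(W, A, B, [1.0, 2.0], lora_alpha=16, lora_rank=1)
```

Note that the effective update strength scales with lora_alpha / lora_rank, which is why lora_alpha is often retuned when the rank changes.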

Optimizer

optim:
  lr: 1e-5
  betas: [0.9, 0.95]
  weight_decay: 0.01
  warmup_steps_ratio: 0.1
  clip_grad: 1.0
  lr_scheduler: cosine  # cosine or wsd
| Parameter | Type | Default | Description |
|---|---|---|---|
| lr | float | 1e-5 | Learning rate. |
| betas | list | [0.9, 0.95] | Adam optimizer beta parameters. |
| weight_decay | float | 0.01 | Weight decay coefficient. |
| warmup_steps_ratio | float | 0.1 | Fraction of total steps used for learning-rate warmup. |
| clip_grad | float | 1.0 | Gradient clipping norm. |
| lr_scheduler | str | cosine | LR scheduler: cosine (cosine annealing) or wsd (warmup-stable-decay). |
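
A sketch of the warmup-plus-cosine schedule these settings describe (assumes linear warmup and annealing to zero; the trainer's exact schedule may differ in details):

```python
import math

def lr_at_step(step, total_steps, base_lr, warmup_steps_ratio):
    """Linear warmup to base_lr, then cosine annealing toward zero."""
    warmup_steps = max(1, int(total_steps * warmup_steps_ratio))
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps   # linear ramp-up
    # cosine decay over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

# With lr=1e-5 and warmup_steps_ratio=0.1 over 100 steps, the peak is reached
# at the end of the first 10 steps, then the rate decays smoothly.
peak = lr_at_step(9, 100, 1e-5, 0.1)
```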

Trainer

trainer:
  default_local_dir: /tmp/sft_model
  project_name: gsm8k-sft
  experiment_name: test
  total_epochs: 4
  total_training_steps: null
  logger: ['console']
  seed: 1
| Parameter | Type | Default | Description |
|---|---|---|---|
| default_local_dir | str | /tmp/sft_model | Directory to save checkpoints. |
| project_name | str | gsm8k-sft | Project name for logging. |
| total_epochs | int | 4 | Number of training epochs. |
| total_training_steps | int | null | Alternative stopping criterion: stop after N steps (overrides total_epochs). |
| logger | list | ['console'] | Logging backends. |
| seed | int | 1 | Random seed for reproducibility. |
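
When total_training_steps is null, the run length follows from the dataset size, global batch size, and epoch count. A back-of-the-envelope helper (assuming one optimizer step per global batch, with the last partial batch rounded up):

```python
import math

def total_steps(num_examples, train_batch_size, total_epochs):
    """Optimizer steps for a run when total_training_steps is null."""
    steps_per_epoch = math.ceil(num_examples / train_batch_size)
    return steps_per_epoch * total_epochs

# GSM8K's train split has 7473 examples; with the defaults above
# (train_batch_size=256, total_epochs=4):
steps = total_steps(7473, 256, 4)
```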

Sequence Parallelism

ulysses_sequence_parallel_size: 1
use_remove_padding: False
  • ulysses_sequence_parallel_size: Degree of Ulysses sequence parallelism. Set to 2 or more to split long sequences across GPUs.
  • use_remove_padding: Remove padding tokens before computation for better efficiency.
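
The idea behind use_remove_padding can be sketched as packing each batch's non-pad tokens into one flat sequence plus cumulative lengths, so no compute is spent on pad positions (a simplified illustration; the real path typically relies on variable-length attention kernels):

```python
# Simplified sketch of padding removal: pack non-pad tokens into a single
# flat sequence and record cumulative sequence lengths (cu_seqlens), which
# variable-length attention kernels use to find sequence boundaries.
PAD = 0  # assumed pad token id for this illustration

def remove_padding(batch):
    flat, cu_seqlens = [], [0]
    for seq in batch:
        tokens = [t for t in seq if t != PAD]
        flat.extend(tokens)
        cu_seqlens.append(cu_seqlens[-1] + len(tokens))
    return flat, cu_seqlens

flat, cu = remove_padding([[5, 6, PAD, PAD], [7, 8, 9, PAD]])
```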

Examples

Basic SFT on GSM8K

torchrun --standalone --nnodes=1 --nproc_per_node=2 \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    data.prompt_dict_keys=['question'] \
    data.response_dict_keys=['answer'] \
    data.micro_batch_size_per_gpu=4 \
    optim.lr=1e-4 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.default_local_dir=/tmp/sft_output \
    trainer.project_name=gsm8k-sft \
    trainer.experiment_name=qwen-0.5b-sft \
    trainer.logger=['console'] \
    trainer.total_epochs=4 \
    ulysses_sequence_parallel_size=2 \
    use_remove_padding=true

SFT with LoRA

For parameter-efficient fine-tuning:

torchrun --standalone --nnodes=1 --nproc_per_node=2 \
    -m verl.trainer.fsdp_sft_trainer \
    model.partial_pretrain=Qwen/Qwen2.5-7B-Instruct \
    model.lora_rank=32 \
    model.lora_alpha=16 \
    model.target_modules=all-linear \
    optim.lr=1e-4 \
    data.train_files=$HOME/data/train.parquet \
    data.val_files=$HOME/data/test.parquet \
    trainer.total_epochs=3

Multi-Turn SFT

For training on conversation data:

torchrun --standalone --nnodes=1 --nproc_per_node=2 \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/multiturn/train.parquet \
    data.val_files=$HOME/data/multiturn/test.parquet \
    data.multiturn.enable=true \
    data.multiturn.messages_key=messages \
    data.max_length=2048 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.total_epochs=4

Data Format

Single-Turn Data

The training data should be in parquet format with prompt and response columns:

| prompt (or question) | response (or answer) |
|---|---|
| "What is 2 + 3?" | "2 + 3 = 5" |
| "Solve: 4x = 12" | "x = 3" |

When using prompt_dict_keys and response_dict_keys, the column can contain JSON dicts:

{
  "extra_info": {
    "question": "What is 2 + 3?",
    "answer": "2 + 3 = 5"
  }
}

Multi-Turn Data

For multi-turn format, provide a messages column with a list of role-content pairs:

{
  "messages": [
    {"role": "user", "content": "What is 2+3?"},
    {"role": "assistant", "content": "2+3 = 5"},
    {"role": "user", "content": "And 5+7?"},
    {"role": "assistant", "content": "5+7 = 12"}
  ]
}
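
Before writing multi-turn rows to parquet, it can help to sanity-check each record. A hypothetical validator (it assumes strictly alternating user/assistant turns ending with the assistant; relax it if your data includes system messages):

```python
# Hypothetical pre-write check for one multi-turn record: roles must
# alternate user/assistant and the conversation must end on an assistant
# turn, so every example carries at least one supervised response.
def validate_record(record, messages_key="messages"):
    messages = record[messages_key]
    if not messages or messages[-1]["role"] != "assistant":
        return False
    expected = "user"
    for msg in messages:
        if msg["role"] != expected:
            return False
        expected = "assistant" if expected == "user" else "user"
    return True

good = {"messages": [
    {"role": "user", "content": "What is 2+3?"},
    {"role": "assistant", "content": "2+3 = 5"},
]}
bad = {"messages": [{"role": "user", "content": "What is 2+3?"}]}
ok_good = validate_record(good)
ok_bad = validate_record(bad)
```

Validated rows can then be written out with any parquet writer (for example, pandas' DataFrame.to_parquet).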

Data Preparation

Use the provided preprocessing scripts to prepare datasets:

# GSM8K
python3 -m examples.data_preprocess.gsm8k

# Informal Math
python3 -m examples.data_preprocess.prepare_informal_math \
    --data_source DigitalLearningGmbH/MATH-lighteval

# Multi-turn data
python3 -m examples.data_preprocess.multiturn

SFT → RL Pipeline

The SFT checkpoint can be directly used as the starting point for RL training:

# Step 1: SFT
torchrun ... -m verl.trainer.fsdp_sft_trainer \
model.partial_pretrain=Qwen/Qwen2.5-1.5B-Instruct \
trainer.default_local_dir=/tmp/sft_ckpt \
...

# Step 2: RL Training (using SFT checkpoint)
python3 -m verl.trainer.main_ppo \
actor_rollout_ref.model.path=/tmp/sft_ckpt \
algorithm.adv_estimator=grpo \
...
Info: The SFT checkpoint directory can be passed directly to actor_rollout_ref.model.path; no format conversion is needed.