from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# QLoRA: load model in 4-bit
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype="bfloat16",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
)
# LoRA config — rank 16 for domain adaptation
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Load dataset
dataset = load_dataset("json", data_files="/root/train_data.jsonl", split="train")
# Train
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
args=SFTConfig(
output_dir="/root/output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
logging_steps=10,
save_strategy="epoch",
bf16=True,
),
processing_class=tokenizer,
)
trainer.train()
trainer.save_model("/root/output/final")
tokenizer.save_pretrained("/root/output/final")
print("Training complete.")