阿里妹导读
This article introduces the basic concepts of fine-tuning and how to fine-tune a language model.
1. What is fine-tuning
1.1. Why fine-tune
1.2. Distinguishing related concepts
1.3. Summary
2. How to fine-tune
2.1. The basic principle of fine-tuning
2.2. What is LoRA
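The core idea of LoRA is to freeze the pretrained weight matrix W and learn only a low-rank update BA (with rank r much smaller than the layer dimensions), scaled by alpha/r and added to the frozen output. A minimal PyTorch sketch of that idea (the class name, sizes, and initialization here are illustrative, not taken from this article's code):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative sketch: a frozen Linear layer plus a trainable low-rank update."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                      # pretrained weights stay frozen
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at step 0
        self.scaling = alpha / r

    def forward(self, x):
        # frozen path + scaled low-rank trainable path
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

layer = LoRALinear(nn.Linear(768, 768), r=8, alpha=32)
out = layer(torch.randn(2, 768))                         # only lora_A and lora_B receive gradients

In the walkthrough in section 2.4 below, the peft library does this wrapping for us via LoraConfig and get_peft_model.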
2.3. The fine-tuning process
1. Prepare the data: collect labeled data relevant to the target task, split it into a training set and a validation set, and tokenize it (see the sketch after this list).
2. Set the fine-tuning parameters: configure the LoRA parameters and the training hyperparameters such as the learning rate, so that the model converges.
3. Fine-tune the model: train on the training set and adjust the hyperparameters to prevent overfitting.
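As a minimal sketch of step 1 (the toy texts and labels below are illustrative; the full LoRA walkthrough follows in section 2.4), the Hugging Face datasets library can handle both the split and the tokenization:

from datasets import Dataset
from transformers import AutoTokenizer

# a toy labeled dataset (illustrative data only)
raw = Dataset.from_dict({
    "text": ["great movie", "terrible plot", "loved it", "waste of time"],
    "label": [1, 0, 1, 0],
})

# split into training and validation sets
splits = raw.train_test_split(test_size=0.25, seed=42)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)

tokenized = splits.map(tokenize, batched=True)
print(tokenized["train"][0].keys())   # text, label, input_ids, attention_mask

Whatever tooling is used, the quality of the examples matters more than the mechanics: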
1. High quality: this cannot be overstated. "Garbage in, garbage out" and "Textbooks Are All You Need" both stress how much data quality matters.
2. Diversity: much like test cases for code, use data that differs as widely as possible, so that more scenarios are covered.
3. Prefer human-written data: text generated by language models carries an implicit "pattern"; when reading such text, you can often tell that it was generated by a model.
To fine-tune a model, you are required to provide at least 10 examples. We typically see clear improvements from fine-tuning on 50 to 100 training examples with gpt-3.5-turbo but the right number varies greatly based on the exact use case. We recommend starting with 50 well-crafted demonstrations and seeing if the model shows signs of improvement after fine-tuning. In some cases that may be sufficient, but even if the model is not yet production quality, clear improvements are a good sign that providing more data will continue to improve the model. No improvement suggests that you may need to rethink how to set up the task for the model or restructure the data before scaling beyond a limited example set.
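For reference, when fine-tuning a hosted chat model such as gpt-3.5-turbo, each training example is one line of JSONL in the chat-message format. A small sketch of building such a file in Python (the file name and example contents are illustrative):

import json

# each training example is a short conversation that ends with the desired assistant reply
examples = [
    {"messages": [
        {"role": "system", "content": "You are a sentiment classifier."},
        {"role": "user", "content": "Not a fan, don't recommend."},
        {"role": "assistant", "content": "Negative"},
    ]},
    {"messages": [
        {"role": "system", "content": "You are a sentiment classifier."},
        {"role": "user", "content": "Better than the first one."},
        {"role": "assistant", "content": "Positive"},
    ]},
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for ex in examples:
        f.write(json.dumps(ex, ensure_ascii=False) + "\n")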
2.4. Walking through the LoRA fine-tuning code
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install torch
!pip install peft

from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import evaluate
import torch
import numpy as np
# load the IMDB review dataset
imdb_dataset = load_dataset("stanfordnlp/imdb")

# take a random subsample of N examples
N = 1000
rand_idx = np.random.randint(24999, size=N)

x_train = imdb_dataset['train'][rand_idx]['text']
y_train = imdb_dataset['train'][rand_idx]['label']
x_test = imdb_dataset['test'][rand_idx]['text']
y_test = imdb_dataset['test'][rand_idx]['label']

# build a DatasetDict with train and validation splits
dataset = DatasetDict({
    'train': Dataset.from_dict({'label': y_train, 'text': x_train}),
    'validation': Dataset.from_dict({'label': y_test, 'text': x_test})
})

# check the label balance (close to 0.5 means balanced)
np.array(dataset['train']['label']).sum() / len(dataset['train']['label'])   # 0.508
{ "label": 0, "text": "Not a fan, don't recommed."}from transformers import AutoModelForSequenceClassificationmodel_checkpoint = 'distilbert-base-uncased'# model_checkpoint = 'roberta-base'id2label = {0: "Negative", 1: "Positive"}label2id = {"Negative":0, "Positive":1}model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)model
DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features=768, out_features=3072, bias=True)
            (lin2): Linear(in_features=3072, out_features=768, bias=True)
            (activation): GELUActivation()
          )
          (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
  )
  (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)
This is a 6-layer Transformer model; the module LoRA will act on here is the attention's query projection (q_lin), as configured in the LoraConfig below.
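To see which module names are available as LoRA targets, you can walk the model graph (a small sketch reusing the model object and torch import from above):

# list the Linear submodules whose names can be passed to LoraConfig's target_modules
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name)
# prints q_lin, k_lin, v_lin, out_lin and ffn.lin1/lin2 for each of the 6 layers,
# plus the pre_classifier and classifier heads

Next, the reviews are tokenized and a padding collator is set up: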
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# add a pad token if the tokenizer does not define one
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

def tokenize_function(examples):
    text = examples["text"]
    # truncate from the left so the end of each review is kept
    tokenizer.truncation_side = "left"
    tokenized_inputs = tokenizer(
        text,
        return_tensors="np",
        truncation=True,
        max_length=512,        # match the model's maximum input length
        padding='max_length'   # pad shorter sequences to the maximum length
    )
    return tokenized_inputs

tokenized_dataset = dataset.map(tokenize_function, batched=True)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_dataset
# sanity check: run a few example sentences through the untrained base model
model_untrained = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)

# define list of examples
text_list = ["It was good.", "Not a fan, don't recommed.", "Better than the first one.",
             "This is not worth watching even once.", "This one is a pass."]

print("Untrained model predictions:")
print("----------------------------")
for text in text_list:
    # tokenize text
    inputs = tokenizer.encode(text, return_tensors="pt")
    # compute logits
    logits = model_untrained(inputs).logits
    # convert logits to label
    predictions = torch.argmax(logits)
    print(text + " - " + id2label[predictions.tolist()])
Untrained model predictions:
----------------------------
It was good. - Positive
Not a fan, don't recommed. - Positive
Better than the first one. - Positive
This is not worth watching even once. - Positive
This one is a pass. - Positive
# accuracy metric used during evaluation
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy.compute(predictions=predictions, references=labels)}

# LoRA configuration: rank-1 adapters on the query projection only
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    r=1,
    lora_alpha=32,
    lora_dropout=0.01,
    target_modules=['q_lin']
)
LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type='SEQ_CLS', inference_mode=False, r=1, target_modules={'q_lin'}, lora_alpha=32, lora_dropout=0.01, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False))
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
trainable params: 601,346 || all params: 67,556,356 || trainable%: 0.8901
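This number can be reconstructed by hand, which shows exactly where the trainable weights live (a back-of-the-envelope check, not output from the article):

# LoRA on q_lin only: A is (r x 768) and B is (768 x r) with r = 1, in each of the 6 layers
lora_params = 6 * (1 * 768 + 768 * 1)               # 9,216

# for task_type="SEQ_CLS", PEFT also keeps the classification head trainable
head_params = (768 * 768 + 768) + (768 * 2 + 2)     # pre_classifier + classifier = 592,130

print(lora_params + head_params)                    # 601,346 -- matches print_trainable_parameters()

Training then goes through the standard Trainer API: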
# hyperparameters
lr = 1e-3
batch_size = 4
num_epochs = 10

training_args = TrainingArguments(
    output_dir=model_checkpoint + "-lora-text-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
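Re-running the earlier sanity check, now with the fine-tuned model, gives the output below (a sketch mirroring the earlier loop; the original cell is not reproduced in this excerpt):

# run the same five example sentences through the fine-tuned (PEFT-wrapped) model
model.to('cpu')
print("Trained model predictions:")
print("--------------------------")
for text in text_list:
    inputs = tokenizer.encode(text, return_tensors="pt")
    logits = model(inputs).logits
    predictions = torch.argmax(logits)
    print(text + " - " + id2label[predictions.tolist()])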
Trained model predictions:
--------------------------
It was good. - Positive
Not a fan, don't recommed. - Negative
Better than the first one. - Positive
This is not worth watching even once. - Negative
This one is a pass. - Positive
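One practical follow-up: only the adapter weights (plus the classification head) need to be persisted, and they can be re-attached to a fresh copy of the base model at inference time. A minimal sketch, where the directory name is arbitrary:

# save just the trained adapter instead of the full model
model.save_pretrained("distilbert-lora-imdb-adapter")

# later / elsewhere: reload the base model and attach the trained adapter
base_model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)
inference_model = PeftModel.from_pretrained(base_model, "distilbert-lora-imdb-adapter")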
3. Conclusion