!pip install accelerate==1.0.1 rouge-score==0.1.2 nltk==3.9.1 ms-swift[llm]==2.4.2.post2 evalscope==0.5.5rc1
下载模型
!mkdir ./model
!modelscope download --model qwen/Qwen2.5-1.5B-Instruct --local_dir './model'
这里已经提前准备好了一份线上教育公司相关数据库问题查询的数据集,其中大概500+作为训练集,100+作为测试集。
训练集:用于在训练阶段进行模型参数的训练。
{
"messages": [
{
"role": "system",
"content": "#背景#
数据库信息:{'column_names': [[-1, '*', 'text'], [0, 'address id', 'number'], [0, 'address details', 'text'], [1, 'staff id', 'number'], [1, 'staff gender', 'text'], [1, 'staff name', 'text'], [2, 'supplier id', 'number'], [2, 'supplier name', 'text'], [2, 'supplier phone', 'text'], [3, 'department store chain id', 'number'], [3, 'department store chain name', 'text'], [4, 'customer id', 'number'], [4, 'payment method code', 'text'], [4, 'customer code', 'text'], [4, 'customer name', 'text'], [4, 'customer address', 'text'], [4, 'customer phone', 'text'], [4, 'customer email', 'text'], [5, 'product id', 'number'], [5, 'product type code', 'text'], [5, 'product name', 'text'], [5, 'product price', 'number'], [6, 'supplier id', 'number'], [6, 'address id', 'number'], [6, 'date from', 'time'], [6, 'date to', 'time'], [7, 'customer id', 'number'], [7, 'address id', 'number'], [7, 'date from', 'time'], [7, 'date to', 'time'], [8, 'order id', 'number'], [8, 'customer id', 'number'], [8, 'order status code', 'text'], [8, 'order date', 'time'], [9, 'department store id', 'number'], [9, 'department store chain id', 'number'], [9, 'store name', 'text'], [9, 'store address', 'text'], [9, 'store phone', 'text'], [9, 'store email', 'text'], [10, 'department id', 'number'], [10, 'department store id', 'number'], [10, 'department name', 'text'], [11, 'order item id', 'number'], [11, 'order id', 'number'], [11, 'product id', 'number'], [12, 'product id', 'number'], [12, 'supplier id', 'number'], [12, 'date supplied from', 'time'], [12, 'date supplied to', 'time'], [12, 'total amount purchased', 'text'], [12, 'total value purchased', 'number'], [13, 'staff id', 'number'], [13, 'department id', 'number'], [13, 'date assigned from', 'time'], [13, 'job title code', 'text'], [13, 'date assigned to', 'time']], 'foreign_keys': [[22, 6], [23, 1], [26, 11], [27, 1], [31, 11], [35, 9], [41, 34], [45, 18], [44, 30], [46, 18], [47, 6], [52, 3], [53, 40]], 'primary_keys': [1, 3, 6, 9, 11, 18, 22, 26, 30, 34, 40, 43, 46, 52], 'table_names': ['addresses', 'staff', 'suppliers', 'department store chain', 'customers', 'products', 'supplier addresses', 'customer addresses', 'customer orders', 'department stores', 'departments', 'order items', 'product suppliers', 'staff department assignments']}
#受众#
Mysql数据库
#输出#
只输出SQL查询语句"
},
{
"role": "user",
"content": "#目的#
将问题\"What are the names and ids of customers whose address contains TN?\"转换为转化为SQL查询语句"
},
{
"role": "assistant",
"content": "SELECT customer_name , customer_id FROM customers WHERE customer_address LIKE \"%TN%\""
}
]
}
{
"messages": [
{
"role": "system",
"content": "#背景#
数据库信息:{'column_names': [[-1, '*', 'text'], [0, 'address id', 'number'], [0, 'address details', 'text'], [1, 'staff id', 'number'], [1, 'staff gender', 'text'], [1, 'staff name', 'text'], [2, 'supplier id', 'number'], [2, 'supplier name', 'text'], [2, 'supplier phone', 'text'], [3, 'department store chain id', 'number'], [3, 'department store chain name', 'text'], [4, 'customer id', 'number'], [4, 'payment method code', 'text'], [4, 'customer code', 'text'], [4, 'customer name', 'text'], [4, 'customer address', 'text'], [4, 'customer phone', 'text'], [4, 'customer email', 'text'], [5, 'product id', 'number'], [5, 'product type code', 'text'], [5, 'product name', 'text'], [5, 'product price', 'number'], [6, 'supplier id', 'number'], [6, 'address id', 'number'], [6, 'date from', 'time'], [6, 'date to', 'time'], [7, 'customer id', 'number'], [7, 'address id', 'number'], [7, 'date from', 'time'], [7, 'date to', 'time'], [8, 'order id', 'number'], [8, 'customer id', 'number'], [8, 'order status code', 'text'], [8, 'order date', 'time'], [9, 'department store id', 'number'], [9, 'department store chain id', 'number'], [9, 'store name', 'text'], [9, 'store address', 'text'], [9, 'store phone', 'text'], [9, 'store email', 'text'], [10, 'department id', 'number'], [10, 'department store id', 'number'], [10, 'department name', 'text'], [11, 'order item id', 'number'], [11, 'order id', 'number'], [11, 'product id', 'number'], [12, 'product id', 'number'], [12, 'supplier id', 'number'], [12, 'date supplied from', 'time'], [12, 'date supplied to', 'time'], [12, 'total amount purchased', 'text'], [12, 'total value purchased', 'number'], [13, 'staff id', 'number'], [13, 'department id', 'number'], [13, 'date assigned from', 'time'], [13, 'job title code', 'text'], [13, 'date assigned to', 'time']], 'foreign_keys': [[22, 6], [23, 1], [26, 11], [27, 1], [31, 11], [35, 9], [41, 34], [45, 18], [44, 30], [46, 18], [47, 6], [52, 3], [53, 40]], 'primary_keys': [1, 3, 6, 9, 11, 18, 22, 26, 30, 34, 40, 43, 46, 52], 'table_names': ['addresses', 'staff', 'suppliers', 'department store chain', 'customers', 'products', 'supplier addresses', 'customer addresses', 'customer orders', 'department stores', 'departments', 'order items', 'product suppliers', 'staff department assignments']}
#受众#
Mysql数据库
#输出#
只输出SQL查询语句"
},
{
"role": "user",
"content": "#目的#
将问题\"What are the highest and lowest prices of products, grouped by and alphabetically ordered by product type?\"转换为转化为SQL查询语句"
},
{
"role": "assistant",
"content": "SELECT max(product_price) , min(product_price) , product_type_code FROM products GROUP BY product_type_code ORDER BY product_type_code"
}
]
}
验证集:用于在训练阶段评估模型训练的效果,从而调整模型的超参数(如学习率、退火策略等)和监控过拟合情况。
训练集文件会被按照4:1的比例拆分为训练集与验证集,测试集文件则对应测试集。
测试集:用于在训练阶段完成后评估模型训练的效果。本课程提供的实验数据中,测试集涉及的数据库均没有在训练、验证数据集中出现过。
{
"history": [],
"query": "#背景#
数据库信息:{'column_names': [[-1, '*', 'text'], [0, 'school id', 'text'], [0, 'school name', 'text'], [0, 'location', 'text'], [0, 'mascot', 'text'], [0, 'enrollment', 'number'], [0, 'ihsaa class', 'text'], [0, 'ihsaa football class', 'text'], [0, 'county', 'text'], [1, 'school id', 'number'], [1, 'year', 'number'], [1, 'budgeted', 'number'], [1, 'total budget percent budgeted', 'number'], [1, 'invested', 'number'], [1, 'total budget percent invested', 'number'], [1, 'budget invested percent', 'text'], [2, 'endowment id', 'number'], [2, 'school id', 'number'], [2, 'donator name', 'text'], [2, 'amount', 'number']], 'foreign_keys': [[9, 1], [17, 1]], 'primary_keys': [1, 9, 16], 'table_names': ['school', 'budget', 'endowment']}
#受众#
Mysql数据库
#输出#
只输出SQL查询语句
#目的#
将问题\"Count the number of schools.\"转换为转化为SQL查询语句\n",
"response": "SELECT count(*) FROM school"
}
{
"history": [],
"query": "#背景#
数据库信息:{'column_names': [[-1, '*', 'text'], [0, 'school id', 'text'], [0, 'school name', 'text'], [0, 'location', 'text'], [0, 'mascot', 'text'], [0, 'enrollment', 'number'], [0, 'ihsaa class', 'text'], [0, 'ihsaa football class', 'text'], [0, 'county', 'text'], [1, 'school id', 'number'], [1, 'year', 'number'], [1, 'budgeted', 'number'], [1, 'total budget percent budgeted', 'number'], [1, 'invested', 'number'], [1, 'total budget percent invested', 'number'], [1, 'budget invested percent', 'text'], [2, 'endowment id', 'number'], [2, 'school id', 'number'], [2, 'donator name', 'text'], [2, 'amount', 'number']], 'foreign_keys': [[9, 1], [17, 1]], 'primary_keys': [1, 9, 16], 'table_names': ['school', 'budget', 'endowment']}
#受众#
Mysql数据库
#输出#
只输出SQL查询语句
#目的#
将问题\"Show all donor names.\"转换为转化为SQL查询语句",
"response": "SELECT DISTINCT donator_name FROM endowment"
}
先评测
!mkdir -p eval_outputs
%env LOG_LEVEL=WARNING
!swift eval \
--model_id_or_path './model' \
--model_type 'qwen2_5-1_5b-instruct' \
--eval_dataset no \
--custom_eval_config 'resources/2_4/data/config_eval.json' \
--max_length -1 \
--system '' \
--infer_backend 'pt' \
--name 'pre_train_evaluation' \
--eval_output_dir './eval_outputs'
在完成了模型的评测后,就可以开始对模型进行微调了。swift提供了操作简单的微调工具,你只需要传入训练集、模型路径等参数。同时swift框架在微调时默认使用LoRA方法,在命令中不需要额外声明。
微调命令:
%env CUDA_VISIBLE_DEVICES=0
%env LOG_LEVEL=INFO
!mkdir -p logs
!swift sft \
--dataset 'resources/2_4/data/train.jsonl' \
--learning_rate '1e-4' \
--eval_steps '10' \
--batch_size '4' \
--model_type 'qwen2_5-1_5b-instruct' \
--max_length 2048 \
--model_id_or_path './model' \
--num_train_epochs 3
微调后一般会保存两个checkpoint文件,分别是best_model_checkpoint(在验证集表现最佳的微调模型)与last_model_checkpoint(最后一次保存的checkpoint)。
[INFO:swift] last_model_checkpoint: /mnt/workspace/aliyun_acp_learning/大模型ACP认证教程/p2_构造大模型问答系统/output/qwen2_5-1_5b-instruct/v1-20250106-161819/checkpoint-102
[INFO:swift] best_model_checkpoint: /mnt/workspace/aliyun_acp_learning/大模型ACP认证教程/p2_构造大模型问答系统/output/qwen2_5-1_5b-instruct/v1-20250106-161819/checkpoint-90
微调之后再测评
!mkdir -p eval_outputs
%env LOG_LEVEL=WARNING
!swift eval \
--ckpt_dir 'output/qwen2_5-1_5b-instruct/v1-20250106-161819/checkpoint-102' \
--eval_dataset no \
--custom_eval_config 'resources/2_4/data/config_eval.json' \
--max_length -1 \
--system '' \
--infer_backend 'pt' \
--name 'pre_train_evaluation' \
--eval_output_dir './eval_outputs'
可以看到,Rouge-L分数都有了非常大的提升,几乎翻倍,说明经过微调后的模型NL2SQL的能力得到了大幅度的加强。
模型微调训练完成后,将基础模型与微调得到的低秩参数融合,获得一个完整的、更新了参数的模型,再调用融合了的模型。
%env LOG_LEVEL=INFO
!swift export \
--ckpt_dir 'output/qwen2_5-1_5b-instruct/v1-20250106-161819/checkpoint-102' \
--merge_lora true
微调数据集构建策略
一般来说,在比较复杂的场景中,微调至少需要1000+条优质的训练集数据。构建数据集时,请确认以下几点:
数据质量:确保数据集的准确性和相关性,避免模糊和错误内容。
数据多样性:覆盖任务的所有关键方面和潜在变化,包括不同场景、语境和专业术语。
平衡性:如果任务涉及多种类别场景,确保各类别样本均衡,防止模型偏向于某一类。
持续迭代:微调是一个迭代过程,根据模型在验证集上的表现反馈,不断优化和扩大数据集。
而如果你在进行模型微调时缺乏数据,建议你使用知识库检索来增强模型能力。
在很多复杂的业务场景中,可以综合采用模型调优和知识库检索相结合的技术方案。
你也可以采用以下策略扩充数据集:
联系数据专家或数据团队基于已有数据,制作更多的典型场景数据。
让能力更强的大模型模拟生成特定业务/场景的相关内容,辅助你生成更多可用于微调的数据。
通过应用场景收集、网络爬虫、社交媒体和在线论坛、公开数据集、合作伙伴与行业资源、用户贡献等方式获取更多数据。