[NLP] LLM---<训练中文LLama2(五)>对SFT后的LLama2进行DPO训练

2023-09-17 08:42:14

当前关于LLM的共识

大型语言模型(LLM)使 NLP 中微调模型的过程变得更加复杂。最初,当 ChatGPT 等模型首次出现时,最主要的方法是先训练奖励模型,然后优化 LLM 策略。从人类反馈中强化学习(RLHF)极大地推动了NLP的发展,并将NLP中许多长期面临的挑战抛在了一边。基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。

然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

现在主流的LLM,比如chatglm、chinese-alpaca,主要进行了三步操作:

Step1:知识学习,CLM,大规模语料库上的预训练,本步的模型拥有续写的功能

Step2:知识表达,指令微调,在指令数据上进行微调,本步骤可以使用Lora等节省显存的方式,本模型可以听懂人类指令并进行回答的功能

Step3:偏好学习,RLHF或本文所提的DPO,可以让模型的输出更符合人类偏好,通俗说就是同样一句话,得调教的让模型输出人类喜欢的表达方式,好比高情商的人说话让人舒服

第二步,还是多多少少学习了一点知识,第三步则几乎不学知识,只学表达方式了。

RLHF太耗时耗力了,得提前训练好RewardModel,然后PPO阶段,得加载4个模型,2个推理,2个训练,实在是太不友好了。

下图是SFT+RLHF的过程,对应上文的Step2和Step3,主要包括指令微调模型、训练奖励模型和PPO优化。

现在大多数目前开源的LLM模型都只做了前2步:预训练和指令微调。

而其中原因就是第3步人类反馈强化学习(RLHF)实现起来很困难:

1.需要人类反馈数据(很难收集)
2.奖励模型训练(很难训练)
3. PPO强化学习微调(不仅很耗资源,而且也很难训练)

但是能不能不要最后一步呢,一般来说还是有RLHF比较好,有主要有以下几个原因:

  1. 提高安全性和可控性;
  2. 改进交互性;
  3. 克服数据集偏差;
  4. 提供个性化体验;
  5. 符合道德规范;
  6. 持续优化和改进。

RLHF使得ChatGPT这样的大型对话模型既具备强大能力,又能够接受人类价值观的指导,生成更智能、安全、有益的对话回复。这是未来可信赖和可解释AI的重要发展方向。

所以这一步还是非常重要。那如何解决人类反馈强化学习(RLHF)训练这个难题呢?

DPO (Differentiable Policy Optimization) 算法

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

DPO 是为实现对 LLM 的精确控制而引入的一种方法。从人类反馈强化学习(RLHF)的基础是训练奖励模型,然后使用近端策略优化(PPO)使语言模型的输出与人类的偏好相一致。这种方法虽然有效,但既复杂又不稳定。DPO 将受限奖励最大化问题视为人类偏好数据的分类问题。这种方法稳定、高效、计算量小。它无需进行奖励模型拟合、大量采样和超参数调整。

DPO(Direct Preference Optimization)是一种直接偏好优化算法,它与PPO(Proximal Policy Optimization)优化的目标相同。主要思路是:

1.定义policy模型(策略模型)和reference模型(参考模型),Policy模型是需要训练的对话生成模型,reference模型是给定的预训练模型或人工构建的模型。

2.对于给定prompt,计算两模型对正样本和负样本的概率,正样本是人类选择的回复,负样本是被拒绝的回复。

3.通过两个模型概率的差值构建DPO损失函数,惩罚policy模型对正样本概率的下降和负样本概率的上升。通过最小化DPO损失进行模型训练。

相比之下DPO就很友好,只需要加载2个模型,其中一个推理,另外一个训练,直接在偏好数据上进行训练即可:

DPO 拒绝有害问题 实战部分

数据集

数据集其实就是标准的RLHF奖励模型的训练集,下载地址在这

Anthropic/hh-rlhf · Datasets at Hugging Face

dikw/hh_rlhf_cn · Datasets at Hugging Face

其样式就是:一个context,一个选择的正样本,一个拒绝的负样本。希望这些样本能够让LLM 尽可能生成用户选择的无害回复,而不要生成有害的回复。

微调代码
下方这段代码实现了基于DPO (Differentiable Policy Optimization) 的对话模型微调。主要步骤包括:

  1. 加载预训练语言模型(这里使用llama-2-7b)并准备量化训练,采用int4量化的+少量lora 参数。
  2. 定义参考模型(int4量化的模型),也使用同样的预训练模型。
  3. 加载Helpful/Harmless数据集,并转换成所需格式。
  4. 定义DPO训练参数,包括batch size,学习率等。
  5. 定义DPO训练器,传入policy模型,参考模型,训练参数等。
  6. 进行DPO微调训练。
  7. 保存微调后的模型,只保存量lora 参数。

关键点:

1. 使用DPO损失函数实现安全性约束的模型训练。不需要额外在训练一个奖励模型。
2. 这也导致整个训练过程只需要策略模型和参考模型 2个LLM模型,不需要额外的显存去加载奖励模型。
3. 整个训练过程策略模型和参考模型可以进行4int的模型量化 + 少量的lora 参数

综上,这段代码对预训练语言模型进行DPO微调,以实现安全可控的对话生成

#!/usr/bin/env python
# coding: utf-8

from typing import Dict

import torch
from datasets import Dataset, load_dataset
from trl import DPOTrainer
import bitsandbytes as bnb

from transformers import TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)

output_dir1 = "./dpo_output_dir1"
output_dir2 = "./dpo_output_dir2"

base_model = "/home/work/llama-2-7b"

###准备训练数据
dataset = load_dataset("json", data_files="./dpo_dataset/harmless_base_cn_train.jsonl")
train_val = dataset["train"].train_test_split(
    test_size=2000, shuffle=True, seed=42
)
train_data = train_val["train"]
val_data = train_val["test"]


def extract_anthropic_prompt(prompt_and_response):
    final = ""
    for sample in prompt_and_response:
        final += sample["role"] + "\n" + sample["text"]
    final += "\n"
    return final


def get_hh(dataset, split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = dataset
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["context"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"]["role"] + "\n" + sample["chosen"]["text"],
            "rejected": sample["rejected"]["role"] + "\n" + sample["rejected"]["text"],
        }

    return dataset.map(split_prompt_and_responses)


train_dataset = get_hh(train_data, "train", sanity_check=True)
eval_dataset = get_hh(val_data, "test", sanity_check=True)


def find_all_linear_names(model):
    # cls = bnb.nn.Linear8bitLt
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {100 * trainable_params / all_param}"
    )


tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

bnb_4bit_compute_dtype = "float16"
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_4bit_quant_type = "nf4"
use_nested_quant = False

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

model = AutoModelForCausalLM.from_pretrained(base_model,
                                             trust_remote_code=True,
                                             quantization_config=bnb_config,
                                             device_map="auto")
model.config.use_cache = False
model = prepare_model_for_kbit_training(model)

modules = find_all_linear_names(model)
config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=modules,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

###定义参考模型
model_ref = AutoModelForCausalLM.from_pretrained(base_model,
                                                 trust_remote_code=True,
                                                 quantization_config=bnb_config,
                                                 device_map="auto")
###定义dpo训练参数
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    max_steps=100,
    remove_unused_columns=False,
    gradient_accumulation_steps=2,
    learning_rate=3e-4,
    evaluation_strategy="steps",
    output_dir="./test",
)

###定义dpo训练器
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
###训练
dpo_trainer.train()
###模型保存
dpo_trainer.save_model(output_dir1)

dpo_trainer.model.save_pretrained(output_dir2)
tokenizer.save_pretrained(output_dir2)

训练过程

其中看出加载了2遍int4量化的模型到显存中,需要训练的策略模型只有一部分lora参数,而参考模型就是原始模型本身.

模型保存

保存下来的参数也就是lora参数,这部分lora 参数就学会了如何拒绝回答有害问题。

至此,我们就学会了如何利用使用DPO +Qlora 实现在完成RLHF的实战。

使用场景

核心原则:偏好数据集中的good/bad response都是和SFT model的训练数据同分布的,也可以说模型是可以生成good/bad response的。

场景1

已有一个SFT model,为了让它更好,对它的output进行偏好标注,然后使用DPO进行训练,这是最正常的使用场景,但是偏好数据集确实避免不了的

场景2

场景1的改进版本,偏好标注不由人来做,而是让gpt4或者一个reward model来标注好坏,至于reward model怎么来,就各凭本事吧

场景3

没有SFT model只有偏好数据集,那就先在偏好数据即中的进行训练,然后在进行DPO的训练。先SFT就是为了符合上文的核心原则

OpenAI独家绝技RLHF也被开源超越啦?!DPO让小白轻松玩转RLHF![已开源] - 知乎 (zhihu.com)

RLHF中的「RL」是必需的吗?有人用二进制交叉熵直接微调LLM,效果更好 - 知乎 (zhihu.com)

直接偏好优化:你的语言模型其实是一个奖励模型 - 知乎 (zhihu.com)

消费级显卡搞定RLHF——DPO算法+QLora微调LLM拒绝有害问题回答实战 - 知乎 (zhihu.com)

使用 DPO 微调 Llama 2 - 知乎 (zhihu.com)

DPO(Direct Preference Optimization):LLM的直接偏好优化 - 知乎 (zhihu.com)

DPO: Direct Preference Optimization 论文解读及代码实践 - 知乎 (zhihu.com)GitHub - mzbac/llama2-fine-tune: Scripts for fine-tuning Llama2 via SFT and DPO.

DPO——RLHF 的替代之《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》论文阅读 - 知乎 (zhihu.com)

更多推荐

【2023 睿思芯科 笔试题】~ 题目及参考答案

文章目录1.题目&答案单选题编程题问题1:解析1:问题2:解析2:声明名称如标题所示,希望大家正确食用(点赞+转发+评论)本次笔试题以两种形式考察的,分别是:选择题(包括单选和多选)和编程题。这里强调的是笔试全英文!!!其实不光这一家公司,很多都是这样的,所以,英语还是那么的重要!一共5个选择+2个编程题。PS:大家在

跨境电商如何更好地备战销售旺季?

跨境电商秋促来临!不知道各位卖家是否做好准备了呢?据外媒报道,TikTokShop于近日开启了年度最大规模的黑五大促,而为了抢占旺季流量,继周二亚马逊宣布将于10月10日至11日举办第二届“PrimeBigDealDays”大促后,周三沃尔玛、Target和BestBuy也都宣布将推出自己的大促,提前开启假日促销。这几

Docker进阶:mysql 主从复制、redis集群3主3从【扩缩容案例】

Docker进阶:mysql主从复制、redis集群3主3从【扩缩容案例】一、Docker常规软件安装1.1docker安装tomcat(默认最新版)1.2docker指定安装tomcat8.01.3docker安装mysql5.7(数据卷配置)1.4演示--删除mysql容器,里面的数据是否能正常恢复1.5docke

HTML的学习 Day01

《网页设计与制作》是web前端开发技术中静态网页中的内容,主要包括html、css、js中的静态内容部分,是专业基础课程。随着5G时代的到来,人工智能与物联网结合行业的飞速发展,更多的互联网的崛起。这肯定就比如伴随着对移动互联网领域新的开发场景的需求,需要大量的前端和移动前端开发来呈现。【技术越来越成熟就越要想每一个给

认识面向对象-PHP8知识详解

面向对象编程,也叫面向对象程序设计,是在面向过程程序设计的基础上发展而来的,它比面向过程编程具有更强的灵活性和扩展性。它用类、对象、关系、属性等一系列东西来提高编程的效率,其主要的特性是可封装性、可继承性和多态性。面向对象编程的主要好处就是吧编程的重心从处理过程转移到对现实世界实体的表达。这十分符合人们的正常思维方法。

鞋服零售企业如何进行数字化运营

企业进行信息化和数字化都是在解决同一个问题,降本增效,用信息产生数据,用数据解读业务。在信息化行业操盘十多年,在我的大脑中始终有一个伟大的蓝图,那就是“作战室”,特别是在鞋服零售各种解决方案非常成熟,对作战室的建设几率又大大增加。作战室不仅是数据的可视,更重要的是能指挥作战,如线上线下同步,而线下店铺可视,在通过一个销

百货店失去核心竞争力了吗?全靠超市即时零售撑起

近年来,传统百货不景气,不过,从今年上半年的情况看,似乎有些好转。然而,实际情况并不如预期那么好,有专家认为:“百货上市公司业绩和去年相比增长是正常的。实际上,百货业态增长的可能性几乎为零,如果没有新店开张,没有面积的扩大,单是百货这一块,有增长的可能性不是很大,因为很多上市公司百货店,其实很大一块是有超市带动的,超市

新零售商城模式与传统电商和零售的痛点的对比

新零售是一种以消费者体验为中心的数据驱动的泛零售形态,它通过运用大数据、人工智能等先进技术手段,对商品的生产、流通与销售过程进行升级改造,进而重塑业态结构与生态圈,并对线上服务、线下体验以及现代物流进行深度融合的零售新模式。新零售商城模式是新零售的一种具体实现方式,它通过打造线上线下一体化的商城平台,实现商品、用户和员

Android Fragment

基本概念Fragment是Android3.0后引入的一个新的API,他出现的初衷是为了适应大屏幕的平板电脑,普通手机开发也会加入这个Fragment,可以把他看成一个小型的Activity,又称Activity片段!如果一个很大的界面,就一个布局,写起界面会很麻烦,而且如果组件多的话是管理起来也很麻烦!而使用Frag

深度学习如何入门?

深度学习是一种强大的机器学习方法,它在各个领域都有广泛应用。如果你是一个新手,想要入门深度学习,下面是一些步骤和资源,可以帮助你开始学习和实践深度学习。1.学习基本概念在开始深度学习之前,你需要对一些基本概念有所了解。以下是一些你需要学习的重要概念:神经网络:它从信息处理角度对人脑神经元网络进行抽象,建立某种简单模型,

企业电子招标采购系统源码之从供应商管理到采购招投标、采购合同、采购执行的全过程数字化管理

功能描述1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查看所有的立项信息。主要功能包含:招标立项申请、非招标立项申请、采购立项管理。3、采购项目管理:可对项目采购过程全流程管

热文推荐