【LLM】Prompt tuning大模型微调实战

2023-07-10 21:48:49

note

  • prompt tuning可看做是prefix tuning的简化版本,在输入层加入prompt tokens,并不需要加入MLP进行调整来解决难训练的问题,作者实验表明随着预训练模型参数量的增加,prompt tuning效果逼近fine tuning效果

一、Propmt tuning

1. peft库中的tuning

  • 之前提到过可以借助peft库(Parameter-Efficient Fine-Tuning)进行微调,支持如下tuning:
    • Adapter Tuning(固定原预训练模型的参数 只对新增的adapter进行微调)
    • Prefix Tuning(在输入token前构造一段任务相关的virtual tokens作为prefix,训练时只更新Prefix部分的参数,而Transformer的其他不分参数固定,和构造prompt类似,只是prompt是人为构造的即无法在模型训练时更新参数,而Prefix可以学习<隐式>的prompt)
    • Prompt Tuning(Prefix Tuning的简化版,只在输入层加入prompt tokens,并不需要加入MLP)
    • P-tuning(将prompt转为可学习的embedding层,v2则加入了prompts tokens作为输入)
    • LoRA(Low-Rank Adaption,为了解决adapter增加模型深度而增加模型推理时间、上面几种tuning中prompt较难训练,减少模型的可用序列长度)
      • 该方法可以在推理时直接用训练好的AB两个矩阵和原预训练模型的参数相加,相加结果替换原预训练模型参数。
      • 相当于用LoRA模拟full-tunetune过程

2. prompt tuning怎么搞

  • 给出好的prompt可以让LLM生成更好的答案,反过来想通过LLM帮我们找到好的prompt就是prompt tuning的思路,训练让模型看到新的例子生成prompt,并把该段prompt作为前缀拼接到我们自己的prompt上,送入LLM得到结果
    • prompt tuning训练的前缀是向量,所以解释性略差
  • 和few show比较:LLM的上下文context长度是有限的(prompt中给出有限的例子,业务复杂时难让模型学习足够多知识),prompt tuning就没有这个限制,只需在训练LLM时给他看足够多例子,之后提问带上一个短的prompt前缀(一般8~20个token)即可
  • 和fine tuning比较:prompt tuning是完全冻结LLM模型参数,只需训练一个几个token的prompt前缀;但是fine tuning精调一个模型很耗资源
  • 为每一个任务额外添加一个或多个embedding,之后拼接query正常输入LLM,并只训练这些embedding。如下图,左图为单任务全参数微调,右图为prompt tuning。
    • prompt tuning将fine tune任务转为mlm任务。自动学习模板:离散的主要包括 Prompt Mining, Prompt Paraphrasing, Gradient-based Search, Prompt Generation 和 Prompt Scoring;连续的则主要包括Prefix Tuning, Tuning Initialized with Discrete Prompts 和 Hard-Soft Prompt Hybrid Tuning。
    • 正常微调举例:[cls]今天天上都出太阳了,阳光明媚。[SEP]
      prompt输入举例:[cls]今天天气是[MASK]。[SEP] 今天天上都出太阳了,阳光明媚。[SEP]

在这里插入图片描述

3. 参数如何选择

prompt tuning论文:The Power of Scale for Parameter-Efficient Prompt Tuning

在这里插入图片描述

  • 作者的对比实验如下,随着预训练模型参数的增加,很简单的参数设置也能达到不错效果:
    • prompt长度,即下面代码中的num_virtual_tokens参数:模型参数达到一定量级时,Prompt 长度为 1 也能达到不错的效果,Prompt 长度为 20 就能达到极好效果。
    • prompt初始化方式,即下面代码中的prompt_tuning_init:初始化方式中random方式稍差于另外的
    • TaskType任务类型:和peft的其他tuning类似,也有这个参数
class TaskType(str, enum.Enum):
    SEQ_CLS = "SEQ_CLS"   常规分类任务
    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" seq2seq任务
    CAUSAL_LM = "CAUSAL_LM"  LM任务
    TOKEN_CLS = "TOKEN_CLS"  token的分类任务:序列标注之类的

二、Prompt tuning代码实战

1. tuning训练

  • 数据:twitter_complaints
  • 模型:bigscience/bloomz-560m模型
  • PromptTuningConfig设置Prompt tuning配置,下面num_virtual_tokens设置prompt前缀的token数,因为token初始化用任务相关文字效果更好,所以下面用Classify if the tweet is a complaint or not:初始化,
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author : andy
@Date   : 2023/7/10 20:37
@Contact: 864934027@qq.com 
@File   : prompt_tuning.py 
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
import torch
from datasets import load_dataset
import os
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "mps"
# device = "cuda"
model_name_or_path = "bigscience/bloomz-560m"
tokenizer_name_or_path = "bigscience/bloomz-560m"
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=tokenizer_name_or_path,
)

dataset_name = "twitter_complaints"
text_column = "Tweet text"
label_column = "text_label"
max_length = 64
learning_rate = 3e-2
num_epochs = 20
batch_size = 8
output_dir = './output'

# 1. load a subset of the RAFT dataset at https://huggingface.co/datasets/ought/raft
dataset = load_dataset("ought/raft", dataset_name)

# get lable's possible values
label_values = [name.replace("_", "") for name in dataset["train"].features["Label"].names]
# append label value to the dataset to make it more readable
dataset = dataset.map(
    lambda x: {label_column: [label_values[label] for label in x["Label"]]},
    batched=True,
    num_proc=1
)
# have a look at the data structure
dataset["train"][0]

在这里插入图片描述

# 2. dataset
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def preprocess_fn(examples):
    tweets = examples[text_column]
    # pad labels with a pad token at the end
    labels = [str(x) + tokenizer.pad_token for x in examples[label_column]]
    # concatenate the tweet with it label
    inputs = [f"{text_column} : {tweet}\nLabel :{label}"
              for tweet, label in zip(tweets, labels)]
    # tokenize input
    model_inputs = tokenizer(inputs,
                           padding='max_length',
                           max_length=max_length,
                           truncation=True,)
    # tokenize label, as -100 not a valid token id, do the padding manually here
    labels_input_ids = []
    for i in range(len(labels)):
        ids = tokenizer(labels[i])["input_ids"]
        padding = [-100] * (max_length - len(ids))
        labels_input_ids.append(padding + ids)
        model_inputs["labels"] = labels_input_ids
        # make model inputs tensor
        model_inputs["input_ids"] = [torch.tensor(ids) for ids in model_inputs["input_ids"]]
        model_inputs["attention_mask"] = [torch.tensor(ids) for ids in model_inputs["attention_mask"]]
        model_inputs["labels"] = [torch.tensor(ids) for ids in model_inputs["labels"]]

    return model_inputs

# have a look at the preprocessing result
# print(preprocess_fn(dataset["train"][:2]))

processed_datasets = dataset.map(
    preprocess_fn,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names, #remove unprocessed column for training
    load_from_cache_file=False,
    desc="Running tokenizer on datasset"
)

test_size = round(len(processed_datasets["train"]) * 0.2)
train_val = processed_datasets["train"].train_test_split(
    test_size=test_size, shuffle=True, seed=42)
train_data = train_val["train"]
val_data = train_val["test"]


# 3. model
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
print(model.print_trainable_parameters())
trainable params: 8192 || all params: 559222784 || trainable%: 0.0014648902430985358

从上面打印结果看出,模型的参数有5.6亿左右,但是需要训练的参数只占0.001%,只有8192个。

# 4. trainer
from transformers import Trainer, TrainingArguments
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=default_data_collator,
    args=TrainingArguments(
      output_dir='./output',
      per_device_train_batch_size=batch_size,
      num_train_epochs=num_epochs,
      learning_rate=learning_rate,
      load_best_model_at_end=True,
      logging_strategy='steps',
      logging_steps=10,
      evaluation_strategy='steps',
      eval_steps=10,
      save_strategy='steps',
      save_steps=10,
    )
  )
trainer.train()

在这里插入图片描述

2. 模型推理比较

# 5. inference
def  inference():
    def generate(inputs, infer_model):
        with torch.no_grad():
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = infer_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=20,
                eos_token_id=3
            )
            print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

    # (1) base model_inference
    base_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    base_model.to(device)
    inputs = tokenizer(
        f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :',
        return_tensors="pt",  # Return PyTorch torch.Tensor objects.
    )
    generate(inputs, base_model)
    print("----------------------------------------")
    shot1 = f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"}\nLabel :complaint\n'
    shot2 = f'{text_column} : {"@HMRCcustomers No this is my first job"}\nLabel :no complaint\n'
    input = f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :'
    inputs_few_shot = tokenizer(
        shot1 + shot2 + input,
        return_tensors="pt",
    )
    generate(inputs_few_shot, base_model)

    # (2) prompt-tuned model_inference
    from peft import PeftModel, PeftConfig
    path = "/content/drive/MyDrive/prompt_tuning"
    config = PeftConfig.from_pretrained(path)
    pretrained_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    prompt_tuned_model = PeftModel.from_pretrained(pretrained_model, path)
    prompt_tuned_model.to(device)
    inputs = tokenizer(
        f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :',
        return_tensors="pt",  # Return PyTorch torch.Tensor objects.
    )
    generate(inputs, prompt_tuned_model)

inference()
  • 上面base model推理结果:
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label : @denny the grocery<?php
/**
 * Copyright © 2016 Google Inc.

----------------------------------------
Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this?
Label :complaint
Tweet text : @HMRCcustomers No this is my first job
Label :no complaint
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label :complaint<?php
/**
 * Copyright © Magento, Inc. All rights reserved.
  • prompt-tuned model推理结果:
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label :complaint

3. 其他tuning技术

在这里插入图片描述

  • prefix tuning和prompt tuning都不需要改LLM模型参数本身,但prefix tuning不进在用户输入该层找到一个前缀,还要在LLM的每层都找到一个前缀并加上,训练成本明显更高
  • p-tuning则不仅可在用户输入的开头加附加信息,也可以在中间或结尾附加信息
  • lora tuning如下图,上一篇博客也讲过

在这里插入图片描述

Reference

[1] https://github.com/jxhe/unify-parameter-efficient-tuning
[2] Continuous Optimization:从Prefix-tuning到更强大的P-Tuning V2
[3] 五万字综述!Prompt-Tuning:深度解读一种新的微调范式
[4] 还在Fine-tune大规模预训练模型?了解下Prompt-tuning
[5] 让天下没有难Tuning的大模型:PEFT技术简介.阿里-风飏
[6] prompt tuning论文:The Power of Scale for Parameter-Efficient Prompt Tuning
[6] 你还弄不清xxxForCausalLM和xxxForConditionalGeneration吗?

更多推荐

Java版分布式微服务云开发架构 Spring Cloud+Spring Boot+Mybatis 电子招标采购系统功能清单

项目说明随着公司的快速发展,企业人员和经营规模不断壮大,公司对内部招采管理的提升提出了更高的要求。在企业里建立一个公平、公开、公正的采购环境,最大限度控制采购成本至关重要。符合国家电子招投标法律法规及相关规范,以及审计监督要求;通过电子化平台提高招投标工作的公开性和透明性;通过电子化招投标,使得招标采购的质量更高、速度

日志输出-查看 SQL:深入分析 MyBatis 执行过程

😀前言在现代软件开发中,数据库操作是不可或缺的一部分,而持久层框架的应用能够极大地简化这一过程。然而,当我们在开发MyBatis程序时,有时候需要深入了解程序底层实际执行的SQL语句,以便更好地分析和优化数据库操作。本文将探讨如何通过配置日志输出,在MyBatis中查看SQL语句的执行情况,以便更深入地了解其执行过程

逻辑回归中对L1\L2正则化的理解

在逻辑回归中,L1和L2正则化是常用的正则化技术,用于控制模型的复杂度并防止过拟合。它们通过在损失函数中引入额外的正则化项来实现。L1正则化(Lasso正则化):L1正则化使用参数权重的绝对值之和作为正则化项。其目标是将一些权重压缩为零,从而实现特征选择的效果。L1正则化的数学形式如下:R(w)=λ∑i=1n∣wi∣\

英语CN专刊《英语教师》简介及投稿须知

英语CN专刊《英语教师》简介及投稿须知《英语教师》杂志是由中华人民共和国新闻出版总署、正式批准公开发行的优秀期刊,《英语教师》系一本面向基础教育和高等教育英语教师的、兼顾理论性与实践性的专业性期刊。《英语教师》的读者对象主要是广大英语教师、英语教研员以及高校外语院系学生。本刊主要刊载有关英语教学和英语教师教育的论文、实

Linux 系统移植(二)--系统调试

文章目录一、编译文件系统1.1下载资源安装包1.2配置模板ARM64目标平台1.3配置交叉编译器1.4配置登录用户名和密码1.5配置Linux控制台1.6配置文件系统格式1.7编译buildroot文件系统二、编译ARM64Linux三、启动QemuLinux系统参考链接:一、编译文件系统1.1下载资源安装包我们使用b

(日积月累版)大数据基础知识点1-关系型数据库

好久不见,甚是想念。笔者最近有时间整理关于大数据的一些基础知识点,整理的目不在于能提升多少技能,关键在于巩固一些很基础的知识点,毕竟互联网就是基础略稳固的人比较有优势,在遇到或发现一些技术问题时,从底层科学的去理解这些问题,说不定会有另一片天下。那么本期带来大数据面是:关系型数据库一、什么是关系型数据库?关系型数据库是

全国职业技能大赛云计算--高职组赛题卷④(私有云)

全国职业技能大赛云计算--高职组赛题卷④(私有云)第一场次题目:OpenStack平台部署与运维任务1基础运维任务(5分)任务3OpenStack云平台运维(15分)任务4OpenStack云平台运维开发(15分,本任务只公布考试范围,不公布赛题)需要环境私信博主!!!第一场次题目:OpenStack平台部署与运维某企

2023,DaaS驶入“AI大航海时代”

2023,“制胜”已然成为所有行业、企业的共同命题,随着数字化行至中程,数据壁垒逐渐被打破,DaaS作为企业增长问题的解法,再次被看到。作者|斗斗编辑|皮爷出品|产业家2002年,在竞争激烈的美国职业棒球联盟,奥克兰运动家队无论在人员和物质配备以及资金实力上都只是“下三流”之列。然而在数据分析高材生的帮助下,经过分析数

RK3568开发笔记(十一):开发版buildroot固件移植一个ffmpeg播放rtsp的播放器Demo

若该文为原创文章,转载请注明原文出处本文章博客地址:https://hpzwl.blog.csdn.net/article/details/133022813红胖子网络科技博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬结合等等)持续更新中…瑞芯微开

HBase基本操作及命令示例

HBase是一种分布式、可扩展、面向列的数据库,它是由Google的Bigtable项目衍生而来,并由Apache软件基金会开发及维护。对于HBase的基本操作类型,主要包括以下几种:创建表:在HBase中,可以创建一个新的表来存储数据。创建表时,需要定义表的名称以及表的列族。命令示例:create'table_nam

ASP.NET Core 8 的 Web App

WebAppWebApp与WebAPI的不同之处在于包含UI部分,所谓的UI就是HTML页面。WebApp支持几种渲染HTML的方式:服务端渲染客户端渲染混合渲染服务端渲染服务端渲染UI是在浏览器请求的时候,服务端生成HTML,然后返回给浏览器。优点是:减轻客户端的压力服务端生成HTML,适配各种浏览器极少从Clien

热文推荐