Prompt tuning T5 모델에서 텍스트 분류 작업을 위해 개발되었으며, 모든 다운스트림 작업은 텍스트 생성 작업으로 변환됨.
프롬프트는 토큰의 시리즈로 입력에 추가됨.
Prompt tuning의 ****핵심 아이디어는 프롬프트 토큰이 독립적으로 업데이트되는 자체 매개변수를 가짐으로써 사전 훈련된 모델의 매개변수를 고정된 상태로 유지하고, 프롬프트 토큰 임베딩의 그래디언트(gradient)만 업데이트할 수 있음.
구현 예시💻
from peft import PromptTuningConfig, PromptTuningInit, get_peft_model
prompt_tuning_init_text = "Classify if the tweet is a complaint or no complaint.\\n"
peft_config = PromptTuningConfig(
task_type="CAUSAL_LM",
prompt_tuning_init=PromptTuningInit.TEXT,
num_virtual_tokens=len(tokenizer(prompt_tuning_init_text)["input_ids"]),,
prompt_tuning_init_text=prompt_tuning_init_text,
tokenizer_name_or_path="bigscience/bloomz-560m",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
"trainable params: 8,192 || all params: 559,222,784 || trainable%: 0.0014648902430985358"