diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md index 2cb8191853368778180eb11533fdaffb5729e8b6..f05e808722b34f26ab0f96dcabd1febbe1d2f61c 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/README.md @@ -37,7 +37,6 @@ OpenRLHF is a high-performance RLHF framework built on Ray, DeepSpeed and HF Tra More details are in [Slides](https://docs.google.com/presentation/d/1JRhB1d7csofx0PIZBmfyBdMluxNd5JLPpUHrrvVhGnk/edit?usp=sharing) | [Technical Report](https://arxiv.org/abs/2405.11143) | [Documents](https://openrlhf.readthedocs.io/) ## News -- [2025/1] HKUST reproduced the [DeepSeek-R1-Zero and DeepSeek-R1 training on small models using OpenRLHF](https://github.com/hkust-nlp/simpleRL-reason) - [2024/12] We "proposed" 😊 the [REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models](https://www.researchgate.net/publication/387487679_REINFORCE_A_SIMPLE_AND_EFFICIENT_APPROACH_FOR_ALIGNING_LARGE_LANGUAGE_MODELS). - [2024/12] We analyzed the PPO, REINFORCE++, GRPO and RLOO in the [Notion Blogpost](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05). @@ -203,9 +202,6 @@ deepspeed --module openrlhf.cli.train_sft \ # --ring_attn_size 2 \ # --ring_head_stride 2 \ -# Multi-turn fine-tuning loss -# --multiturn - # Can also be used for continued pre-training # --pretrain_mode ``` diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json new file mode 100644 index 0000000000000000000000000000000000000000..16886c6651d542cc62c210459d87e8f6c8ab34f6 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/llava_zh_300k.json @@ -0,0 +1,12 @@ +{ + "columns": { + "messages": "messages", + "images": "images" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "user", + "assistant_tag": "assistant" + } +} \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json new file mode 100644 index 0000000000000000000000000000000000000000..6db7beec37274d6ff3bc313ceac75a31f7ab75e1 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/rlhf_v.json @@ -0,0 +1,9 @@ +{ + "ranking": true, + "columns": { + "messages": "conversations", + "chosen": "chosen", + "rejected": "rejected", + "images": "images" + } +} \ No newline at end of file diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b906bc489e222a84dc66f7bc405cdb4da5bfde9 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/examples/vision_scripts/train_dpo_qwen2vl.sh @@ -0,0 +1,48 @@ +set -x + +export ACLNN_CACHE_LIMIT=100000 +export COMBINED_ENABLE=1 +export TASK_QUEUE_ENABLE=2 +export HF_DATASETS_OFFLINE=1 + +read -r -d '' training_commands <= 1 and args.enable_prefix_caching: - import vllm - if vllm.__version__ < "0.7.0": - args.enable_prefix_caching = False - print("[Warning] Disable prefix cache because vLLM updates 
weights without updating the old KV Cache for vLLM version below 0.7.0.") + args.enable_prefix_caching = False + print("[Warning] Disable prefix cache because vLLM updates weights without updating the old KV Cache.") if args.input_template and "{}" not in args.input_template: print("[Warning] {} not in args.input_template, set to None") diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py index 843e37adadd4225fe173aeb0ac5c0b4347dc9969..ae6e301821dea08b99722a1b8fe255355f4b59e9 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_sft.py @@ -63,7 +63,6 @@ def train(args): pretrain_mode=args.pretrain_mode, input_template=args.input_template, multiple_of=args.ring_attn_size, - multiturn=args.multiturn, ) eval_dataset = SFTDataset( eval_data, @@ -73,7 +72,6 @@ def train(args): pretrain_mode=args.pretrain_mode, input_template=args.input_template, multiple_of=args.ring_attn_size, - multiturn=args.multiturn, ) # prepare dataloader @@ -207,7 +205,6 @@ if __name__ == "__main__": parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") - parser.add_argument("--multiturn", action="store_true", default=False, help="Use compacted multiturn dataset") parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key") @@ -235,9 +232,6 @@ if __name__ == "__main__": args = parser.parse_args() - if args.multiturn: - assert args.apply_chat_template, "apply_chat_template must be enabled when using multiturn format" - if args.input_template and "{}" not in args.input_template: print("[Warning] {} not in args.input_template, set to None") args.input_template = None @@ -252,6 +246,7 @@ if __name__ == "__main__": print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") args.flash_attn = True + # TODO: [packing samples] if args.ring_attn_size > 1: assert args.packing_samples, "packing_samples must be enabled when using ring attention" diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_dpo.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3efc7efc7db66fc7e1bb1082392aaaf8f7aaa9 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_dpo.py @@ -0,0 +1,270 @@ +import argparse +import math +import os +from datetime import datetime + +import torch +from transformers.trainer import get_scheduler + +from openrlhf.datasets import build_train_and_valid_datasets, build_data_collator +from openrlhf.models import Actor +from openrlhf.trainer import VLDPOTrainer +from openrlhf.utils import ( + get_strategy, + get_tokenizer, + get_vision_processor, + get_qwen2_vl_utils, + add_vision_args, +) + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + if torch.distributed.get_rank() == 0: + print(f"Running args {args}") + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + 
use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.target_modules, + ds_config=strategy.get_ds_train_config(is_actor=True), + packing_samples=args.packing_samples, + create_vison_model=args.model_arch in ['qwen2_vl'], + ) + + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + strategy.print(model) + + # configure processor + vision_processor = get_vision_processor(args, args.pretrain, tokenizer) + if args.model_arch == "qwen2_vl": + encoder_utils = get_qwen2_vl_utils(args) + else: + raise NotImplementedError(f"no support model arch {args.model_arch=}") + + # load weights for ref model + ref_model = Actor( + args.ref_pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=strategy.get_ds_eval_config(offload=args.ref_offload), + packing_samples=args.packing_samples, + create_vison_model=args.model_arch in ['qwen2_vl'], + ) + if args.ref_offload: + ref_model._offload = True + get_tokenizer(args.pretrain, ref_model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # repace datasets + assert args.task_type == "dpo", f"the script is used for DPO training" + train_dataset, eval_dataset = build_train_and_valid_datasets( + args, tokenizer, processor=vision_processor, encoder_utils=encoder_utils, strategy=strategy) + + data_collator = build_data_collator(args, tokenizer, encoder_utils, vision_processor) + + # prepare dataloader + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + True, + collate_fn=data_collator, + ) + + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + args.micro_train_batch_size, + True, + False, + collate_fn=data_collator, + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + args.lr_scheduler, + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # strategy prepare + ((model, optim, scheduler), ref_model) = strategy.prepare((model, optim, scheduler), ref_model) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # batch_size here is expected to be C(k,2), k means # response of each prompt + # be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1 + trainer = VLDPOTrainer( + model=model, + ref_model=ref_model, + tokenizer=tokenizer, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + 
scheduler=scheduler, + max_norm=args.max_norm, + beta=args.beta, + max_epochs=args.max_epochs, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoints + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_dpo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + + # DeepSpeed + parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--load_checkpoint", action="store_true", default=False) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--ref_offload", action="store_true", default=False) + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", type=str, default="eager", help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + + # DPO + parser.add_argument("--lr_scheduler", type=str, default="cosine_with_min_lr") + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--beta", type=float, default=0.1) + parser.add_argument("--ipo", action="store_true", default=False) # IPO https://arxiv.org/pdf/2310.12036v2.pdf + parser.add_argument("--label_smoothing", type=float, default=0.0) # cDPO https://arxiv.org/pdf/2305.18290.pdf + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument( + "--nll_loss_coef", type=float, default=0, help="Regularization with NLL loss, see LLama 3.1 tech report." 
+ ) + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # Context Parallel + parser.add_argument("--ring_attn_size", type=int, default=1, help="Ring attention group size") + parser.add_argument( + "--ring_head_stride", + type=int, + default=1, + help="the number of heads to do ring attention each time. " + "It should be a divisor of the number of heads. " + "A larger value may results in faster training but will consume more memory.", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # packing samples using Flash Attention2 + parser.add_argument("--packing_samples", action="store_true", default=False) + + # Custom dataset + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--ref_pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + + parser.add_argument("--prompt_key", type=str, default=None) + parser.add_argument("--chosen_key", type=str, default="chosen") + parser.add_argument("--rejected_key", type=str, default="rejected") + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=512) + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_dpo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="exp_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + parser = add_vision_args(parser) + args = parser.parse_args() + print(args) + + if args.ref_pretrain is None or args.ref_pretrain == "": + args.ref_pretrain = args.pretrain + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + if args.packing_samples and not args.flash_attn: + print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") + args.flash_attn = True + + if args.ring_attn_size > 1: + assert args.packing_samples, "packing_samples must be enabled when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..f98a7f8f7065e9174ce97179450c482e155c2e90 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/cli/train_vl_sft.py @@ -0,0 +1,242 @@ +import argparse +import math +import os +from datetime import datetime + +import torch +from transformers.trainer import get_scheduler + +from openrlhf.datasets import build_train_and_valid_datasets, build_data_collator +from openrlhf.models import Actor +from openrlhf.trainer import VLSFTTrainer +from openrlhf.utils import ( + get_strategy, + get_tokenizer, + get_vision_processor, + get_qwen2_vl_utils, + add_vision_args, +) + + +def train(args): + # configure strategy + strategy = get_strategy(args) + strategy.setup_distributed() + if torch.distributed.get_rank() == 0: + print(f"Running args {args}") + + # configure model + # load huggingface model + model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=strategy.get_ds_train_config(is_actor=True), + packing_samples=args.packing_samples, + create_vison_model=args.model_arch in ['qwen2_vl'], + ) + # configure tokenizer + tokenizer = get_tokenizer(args.pretrain, model.model, "right", strategy, use_fast=not args.disable_fast_tokenizer) + strategy.print(model) + + # configure processor + vision_processor = get_vision_processor(args, args.pretrain, tokenizer) + if args.model_arch == "qwen2_vl": + encoder_utils = get_qwen2_vl_utils(args) + else: + raise NotImplementedError(f"no support model arch {args.model_arch=}") + + # gradient_checkpointing + if args.gradient_checkpointing: + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + # configure optimizer + optim = strategy.create_optimizer(model, lr=args.learning_rate, betas=args.adam_betas, weight_decay=args.l2) + + # repace datasets + assert args.task_type == "sft", f"the script is used for SFT training" + train_dataset, eval_dataset = build_train_and_valid_datasets( + args, tokenizer, processor=vision_processor, encoder_utils=encoder_utils, strategy=strategy) + + data_collator = build_data_collator(args, tokenizer, encoder_utils, vision_processor) + + # prepare dataloader + train_dataloader = strategy.setup_dataloader( + train_dataset, + args.micro_train_batch_size, + True, + shuffle=True, + collate_fn=data_collator, + ) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + args.micro_train_batch_size, + True, + False, + collate_fn=data_collator, + ) + + # scheduler + num_update_steps_per_epoch = len(train_dataset) // args.train_batch_size + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + + scheduler = get_scheduler( + args.lr_scheduler, + optim, + num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), + num_training_steps=max_steps, + 
scheduler_specific_kwargs={"min_lr": args.learning_rate * 0.1}, + ) + + # prepare models + (model, optim, scheduler) = strategy.prepare((model, optim, scheduler)) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(args.ckpt_path): + _, states = strategy.load_ckpt(model.model, args.ckpt_path) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + + # configure Trainer + trainer = VLSFTTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + scheduler=scheduler, + max_norm=args.max_norm, + pretrain_mode=args.pretrain_mode, + batch_size=args.train_batch_size, + max_epochs=args.max_epochs, + tokenizer=tokenizer, + ) + + trainer.fit(args, consumed_samples, num_update_steps_per_epoch) + + # save model checkpoint after fitting on only rank0 + strategy.save_model(model, tokenizer, args.save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_sft") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DeepSpeed + parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--flash_attn", type=str, default="eager", help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + + # SFT + parser.add_argument("--max_epochs", type=int, default=2) + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--learning_rate", type=float, default=5e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--pretrain_mode", action="store_true", default=False, help="Use pretrain loss") + parser.add_argument("--lr_scheduler", type=str, 
default="cosine_with_min_lr") + parser.add_argument("--l2", type=float, default=0, help="weight decay loss") + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # ring-attention + parser.add_argument("--ring_attn_size", type=int, default=1, help="Ring attention group size") + parser.add_argument( + "--ring_head_stride", + type=int, + default=1, + help="the number of heads to do ring attention each time. " + "It should be a divisor of the number of heads. " + "A larger value may results in faster training but will consume more memory.", + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # packing SFT samples without CrossAttention + parser.add_argument("--packing_samples", action="store_true", default=False) + + # custom dataset + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets") + parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset") + parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset") + + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key") + parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ") + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + parser.add_argument("--tokenizer_chat_template", type=str, default=None) + parser.add_argument("--max_samples", type=int, default=1e8, help="Max number of samples") + parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft") + parser.add_argument( + "--wandb_run_name", + type=str, + default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + parser = add_vision_args(parser) + args = parser.parse_args() + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." 
+ ) + + if args.packing_samples and not args.flash_attn: + print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.") + args.flash_attn = True + + if args.ring_attn_size > 1: + assert args.packing_samples, "packing_samples must be enabled when using ring attention" + + train(args) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py index bbb762f1eafe99ddf26ae7cf493822885f023a90..5877886453aff7574e60a39b2bc711363819086a 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/__init__.py @@ -3,5 +3,6 @@ from .prompts_dataset import PromptDataset from .reward_dataset import RewardDataset from .sft_dataset import SFTDataset from .unpaired_preference_dataset import UnpairedPreferenceDataset +from .vl_dataset import build_train_and_valid_datasets, build_data_collator -__all__ = ["ProcessRewardDataset", "PromptDataset", "RewardDataset", "SFTDataset", "UnpairedPreferenceDataset"] +__all__ = ["ProcessRewardDataset", "PromptDataset", "RewardDataset", "SFTDataset", "UnpairedPreferenceDataset", "build_train_and_valid_datasets", "build_data_collator"] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py index 6e031f70abbf4dc8acca656cf1ee07c034c3fca9..e5e0c004e97b8ce297f67d4e555c10e821184395 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/sft_dataset.py @@ -7,7 +7,7 @@ from torch.utils.data import Dataset from .utils import zero_pad_sequences -def preprocess_data(data, input_template=None, input_key="input", output_key=None, apply_chat_template=None, multiturn=False): +def preprocess_data(data, input_template=None, input_key="input", output_key=None, apply_chat_template=None): if apply_chat_template: if output_key: prompt_message = data[input_key] @@ -51,7 +51,6 @@ class SFTDataset(Dataset): pretrain_mode=False, num_processors=8, # Specify the number of processors you want to use multiple_of=1, - multiturn=False, ) -> None: super().__init__() self.tokenizer = tokenizer @@ -59,7 +58,6 @@ class SFTDataset(Dataset): self.pretrain_mode = pretrain_mode self.max_length = max_length self.multiple_of = multiple_of - self.multiturn = multiturn # chat template self.input_template = input_template @@ -75,9 +73,7 @@ class SFTDataset(Dataset): # Parallel loading datasets processed_dataset = dataset.map( - self.process_data, - remove_columns=dataset.column_names, - num_proc=num_processors, + self.process_data, remove_columns=dataset.column_names, num_proc=num_processors ) processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) @@ -85,51 +81,15 @@ class SFTDataset(Dataset): self.prompts = processed_dataset["prompt"] self.responses = processed_dataset["response"] self.prompt_ids_lens = processed_dataset["prompt_ids_len"] - self.response_ranges = processed_dataset["response_ranges"] if self.multiturn else None def process_data(self, data): - if self.multiturn and self.output_key: - data[self.input_key].append(data[self.output_key]) - data[self.output_key] = None - - if self.multiturn: - assert not self.output_key or not data[self.output_key], "You should put the whole trajactory into data[input_key] and do 
not set output_key" - input_key = self.input_key - apply_chat_template = self.apply_chat_template - response_ranges = [] - for idx, message in enumerate(data[input_key]): - if message['role'] == 'assistant': - prompt = apply_chat_template(data[input_key][: idx], tokenize=False, add_generation_prompt=True) - response = apply_chat_template(data[input_key][: idx + 1], tokenize=False)[len(prompt):] - - start_idx = self.tokenizer( - prompt, - max_length=self.max_length, - padding=False, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - )["attention_mask"].int().sum().item() - - end_idx = start_idx + self.tokenizer( - response, - max_length=self.max_length, - padding=False, - truncation=True, - return_tensors="pt", - add_special_tokens=False, - )["attention_mask"].int().sum().item() - 1 - response_ranges.append((start_idx, end_idx)) # left close right open - prompt, response = preprocess_data( data, None if self.pretrain_mode else self.input_template, self.input_key, self.output_key, apply_chat_template=None if self.pretrain_mode else self.apply_chat_template, - multiturn=self.multiturn, ) - if not self.pretrain_mode: prompt_token = self.tokenizer( prompt, @@ -147,7 +107,7 @@ class SFTDataset(Dataset): else: prompt_ids_len = 0 - return {"prompt": prompt, "response": response, "prompt_ids_len": prompt_ids_len, "response_ranges": response_ranges if self.multiturn else None} + return {"prompt": prompt, "response": response, "prompt_ids_len": prompt_ids_len} def __len__(self): length = len(self.prompts) @@ -178,7 +138,7 @@ class SFTDataset(Dataset): # to avoid EOS_token truncation input_token["input_ids"][0][-1] = self.tokenizer.eos_token_id input_token["attention_mask"][0][-1] = True - info = {"input": prompt, "output": response, "input_length": input_token["attention_mask"].int().sum().item(), "response_ranges": self.response_ranges[idx] if self.multiturn else None} + info = {"input": prompt, "output": response, "input_length": input_token["attention_mask"].int().sum().item()} return prompt_ids_len, input_token["input_ids"], input_token["attention_mask"], info @@ -203,19 +163,14 @@ class SFTDataset(Dataset): packed_input_ids = [] packed_attention_masks = [] prompt_ids_lens = [] - infos = {"input_length": [], "response_ranges": [] if self.multiturn else None} + infos = {"input_length": []} + index = 1 for prompt_ids_len, input_id, attention_mask, info in item_list: packed_input_ids.append(input_id.flatten()) packed_attention_masks.append(torch.full_like(input_id.flatten(), index)) prompt_ids_lens.append(prompt_ids_len) infos["input_length"].append(info["input_length"]) - if self.multiturn: - if len(infos["response_ranges"]) >= 1: - for i in range(len(info["response_ranges"])): - info["response_ranges"][i][0] += infos["response_ranges"][-1][-1][1] # end_index of the last response of the last item - info["response_ranges"][i][1] += infos["response_ranges"][-1][-1][1] - infos["response_ranges"].append(info["response_ranges"]) index += 1 packed_input_ids = torch.cat(packed_input_ids, dim=0).unsqueeze(0) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3a41960c951f48d9bb7ef66a1c6f543a306065d4 --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/datasets/vl_dataset.py @@ -0,0 +1,513 @@ +import os +from typing import Any, Dict, Literal, Optional, Sequence, Union, List, 
Tuple +from functools import partial +from dataclasses import dataclass +from collections import defaultdict + +import torch +from transformers import DataCollatorForSeq2Seq, ProcessorMixin + +from openrlhf.utils.utils import blending_datasets +from openrlhf.utils.vision_utils import ( + IGNORE_INDEX, ImageInput, + VisionEncoderUtils, DatasetAttr, + get_dataset_attr, +) + + +def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: + r""" + Computes the real sequence length after truncation by the cutoff_len. + """ + if target_len * 2 < cutoff_len: # truncate source + max_target_len = cutoff_len + elif source_len * 2 < cutoff_len: # truncate target + max_target_len = cutoff_len - source_len + else: # truncate both + max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) + + new_target_len = min(max_target_len, target_len) + max_source_len = max(cutoff_len - new_target_len, 0) + new_source_len = min(max_source_len, source_len) + return new_source_len, new_target_len + + +def _convert_images( + images: Union[ImageInput, Sequence[ImageInput]], + dataset_attr: DatasetAttr, +) -> Optional[List[ImageInput]]: + r""" + Optionally concatenates image path to dataset dir when loading from local disk. + """ + if not isinstance(images, list): + images = [images] + elif len(images) == 0: + return None + else: + images = images[:] + + return images + + +def convert_sharegpt( + example: Dict[str, Any], + dataset_attr: DatasetAttr +) -> Dict[str, Any]: + r""" + Converts sharegpt format dataset to the standard format. + """ + tag_mapping = { + dataset_attr.user_tag: "user", + dataset_attr.assistant_tag: "assistant", + dataset_attr.observation_tag: "observation", + dataset_attr.function_tag: "function", + dataset_attr.system_tag: "system", + } + odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) + even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) + accept_tags = (odd_tags, even_tags) + messages = example[dataset_attr.messages] + if ( + dataset_attr.system_tag + and len(messages) != 0 + and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag + ): + system = messages[0][dataset_attr.content_tag] + messages = messages[1:] + else: + system = example[dataset_attr.system] if dataset_attr.system else "" + + aligned_messages = [] + broken_data = False + for turn_idx, message in enumerate(messages): + if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + print(f"Invalid role tag in {message}.") + broken_data = True + + aligned_messages.append( + {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + ) + + if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + print(f"Invalid message count in {messages}.") + broken_data = True + + if ( + dataset_attr.ranking + and isinstance(example[dataset_attr.chosen], dict) + and isinstance(example[dataset_attr.rejected], dict) + ): # pairwise example + chosen = example[dataset_attr.chosen] + rejected = example[dataset_attr.rejected] + if ( + chosen[dataset_attr.role_tag] not in accept_tags[-1] + or rejected[dataset_attr.role_tag] not in accept_tags[-1] + ): + print(f"Invalid role tag in {[chosen, rejected]}.") + broken_data = True + + prompt = aligned_messages + response = [ + {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, + {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": 
rejected[dataset_attr.content_tag]}, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + if broken_data: + print("Skipping this abnormal example.") + prompt, response = [], [] + + convert_images = partial(_convert_images, dataset_attr=dataset_attr) + output = { + "_prompt": prompt, + "_response": response, + "_system": system, + "_tools": example[dataset_attr.tools] if dataset_attr.tools else "", + "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None, + "_videos": None, + } + return output + + +def _encode_supervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence[ImageInput], + videos: Sequence, + encoder: VisionEncoderUtils, + tokenizer, + processor, + cutoff_len: int, + train_on_prompt: bool, + mask_history: bool, +) -> Tuple[List[int], List[int]]: + messages = encoder.mm_plugin.process_messages(prompt + response, images, videos, processor) + input_ids, labels = encoder.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor) + encoded_pairs = encoder.encode_multiturn(tokenizer, messages, system, tools) + total_length = len(input_ids) + (1 if encoder.efficient_eos else 0) + if mask_history: + encoded_pairs = encoded_pairs[::-1] # high priority for last turns + + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= cutoff_len: + break + + source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + + if train_on_prompt: + source_label = source_ids + elif encoder.efficient_eos: + source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) + else: + source_label = [IGNORE_INDEX] * source_len + + if mask_history and turn_idx != 0: # train on the last turn only + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids + + if mask_history: # reversed sequences + input_ids = source_ids + target_ids + input_ids + labels = source_label + target_label + labels + else: + input_ids += source_ids + target_ids + labels += source_label + target_label + + if encoder.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + return input_ids, labels + + +def _encode_pairwise_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + images: Sequence[ImageInput], + videos: Sequence, + encoder: VisionEncoderUtils, + tokenizer, + processor, + cutoff_len: int, +) -> Tuple[List[int], List[int], List[int], List[int]]: + chosen_messages = encoder.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor) + rejected_messages = encoder.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor) + prompt_ids, chosen_ids = encoder.encode_oneturn(tokenizer, chosen_messages, system, tools) + _, rejected_ids = encoder.encode_oneturn(tokenizer, rejected_messages, system, tools) + + if encoder.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + prompt_ids, _ = encoder.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor) + # consider the response is more important + source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), 
cutoff_len) + prompt_ids = prompt_ids[:source_len] + chosen_ids = chosen_ids[:target_len] + rejected_ids = rejected_ids[:target_len] + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels + + +def preprocess_supervised_dataset( + examples: Dict[str, List[Any]], + encoder: VisionEncoderUtils, + tokenizer, + processor, + data_args, +) -> Dict[str, List[Any]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: + print( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + input_ids, labels = _encode_supervised_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + encoder=encoder, + tokenizer=tokenizer, + processor=processor, + cutoff_len=data_args.max_len, + train_on_prompt=data_args.train_on_prompt, + mask_history=data_args.mask_history, + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + + return model_inputs + + +def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]], + encoder: VisionEncoderUtils, + tokenizer, + processor, + data_args, +) -> Dict[str, List[Any]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = defaultdict(list) + for i in range(len(examples["_prompt"])): + if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: + print( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) + continue + + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( + prompt=examples["_prompt"][i], + response=examples["_response"][i], + system=examples["_system"][i], + tools=examples["_tools"][i], + images=examples["_images"][i] or [], + videos=examples["_videos"][i] or [], + encoder=encoder, + tokenizer=tokenizer, + processor=processor, + cutoff_len=data_args.max_len, + ) + model_inputs["chosen_input_ids"].append(chosen_input_ids) + model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + model_inputs["images"].append(examples["_images"][i]) + model_inputs["videos"].append(examples["_videos"][i]) + + return model_inputs + + +def get_preprocessed_dataset(args, data_list, encoder, tokenizer, processor): + train_data, eval_data = data_list + + dataset_attr = get_dataset_attr(args.dataset_config_path) + + kwargs = dict( + num_proc=args.processing_num_workers, + load_from_cache_file=(not args.overwrite_cache) or (args.local_process_index != 0), + desc="Converting 
format of dataset", + ) + convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) + column_names = list(next(iter(train_data)).keys()) + train_data = train_data.map(convert_func, batched=False, remove_columns=column_names, + **kwargs) + if eval_data is not None: + eval_data = eval_data.map(convert_func, batched=False, remove_columns=column_names, + **kwargs) + + if args.task_type == "sft": + process_dataset_class = preprocess_supervised_dataset + elif args.task_type == "dpo": + process_dataset_class = preprocess_pairwise_dataset + else: + raise NotImplementedError(f"Unknown task_type: {args.task_type}") + + preprocess_func = partial( + process_dataset_class, + encoder=encoder, + tokenizer=tokenizer, + processor=processor, + data_args=args, + ) + kwargs.update({"desc": "Running tokenizer on dataset"}) + column_names = list(next(iter(train_data)).keys()) + train_data = train_data.map(preprocess_func, batched=True, + batch_size=args.preprocessing_batch_size, + remove_columns=column_names, **kwargs) + if eval_data is not None: + eval_data = eval_data.map(preprocess_func, batched=True, + batch_size=args.preprocessing_batch_size, + remove_columns=column_names, **kwargs) + return train_data, eval_data + + +# Copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/collator.py +@dataclass +class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): + r""" + Data collator that supports VLMs. + + Features should contain input_ids, attention_mask, labels and images. + """ + + encoder_utils: Optional[VisionEncoderUtils] = None + vision_processor: Optional[ProcessorMixin] = None + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], [] + for feature in features: + images = feature.pop("images", None) or [] + videos = feature.pop("videos", None) or [] + batch_images.extend(images) + batch_videos.extend(videos) + batch_imglens.append(len(images)) + batch_vidlens.append(len(videos)) + batch_input_ids.append(feature["input_ids"]) + + mm_inputs = self.encoder_utils.mm_plugin.get_mm_inputs( + batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.vision_processor + ) + if "token_type_ids" in mm_inputs: + token_type_ids = mm_inputs.pop("token_type_ids") + for i, feature in enumerate(features): + feature["token_type_ids"] = token_type_ids[i] + + features: Dict[str, torch.Tensor] = super().__call__(features) + features.update(mm_inputs) + if isinstance(features.get("pixel_values"), list): # for pixtral inputs + features = features.data # use default_collate() instead of BatchEncoding.to() + return features + +# Copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/collator.py + + +@dataclass +class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): + r""" + Data collator for 4d attention mask. 
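+    When block_diag_attn is enabled and the attention implementation is not flash_attention_2, the packed 2D attention mask is expanded into a 4D block-diagonal causal mask in compute_dtype.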
+ """ + + block_diag_attn: bool = False + attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" + compute_dtype: "torch.dtype" = torch.float32 + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + features = super().__call__(features) + if self.block_diag_attn and self.attn_implementation != "flash_attention_2": + features["attention_mask"] = self.prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) + return features + + def prepare_4d_attention_mask(self, attention_mask_with_indices, dtype) -> torch.Tensor: + r""" + Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. + + e.g. + ```python + # input + [[1, 1, 2, 2, 2, 0]] + # output + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, x, x, x, x], + ] + ] + ] + ``` + where `o` equals to `0.0`, `x` equals to `min_dtype`. + """ + bsz, seq_len = attention_mask_with_indices.size() + min_dtype = torch.finfo(dtype).min + expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len) + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + padding_mask = torch.where(expanded_mask != 0, 1, 0) + # Create a block-diagonal mask. + attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask + # Use the lower triangular mask to zero out the upper triangular part + attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long)) + # Invert the attention mask. + attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype) + return attention_mask_4d + +# Copied from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llamafactory/data/collator.py + + +@dataclass +class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. 
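+        The images and videos of each feature are attached to both the chosen and the rejected copies so that the multimodal inputs stay aligned after concatenation.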
+ """ + concatenated_features = [] + for key in ("chosen", "rejected"): + for feature in features: + target_feature = { + "input_ids": feature[f"{key}_input_ids"], + "attention_mask": feature[f"{key}_attention_mask"], + "labels": feature[f"{key}_labels"], + "images": feature["images"], + "videos": feature["videos"], + } + concatenated_features.append(target_feature) + + return super().__call__(concatenated_features) + + +def build_train_and_valid_datasets(args, tokenizer, processor, encoder_utils, strategy): + train_ds, eval_ds = blending_datasets( + args.dataset, + args.dataset_probs, + strategy, + args.seed, + max_count=args.max_samples, + train_split=args.train_split, + eval_split=args.eval_split, + ) + + return get_preprocessed_dataset(args, [train_ds, eval_ds], encoder_utils, tokenizer, processor) + + +def build_data_collator(args, tokenizer, encoder_utils, vision_processor): + collator_class = None + kwargs = {} + if args.task_type == "dpo": + collator_class = PairwiseDataCollatorWithPadding + elif args.task_type == "sft": + collator_class = SFTDataCollatorWith4DAttentionMask + kwargs = { + "block_diag_attn": args.neat_packing, + "attn_implementation": "flash_attention_2" if args.flash_attn else None, + "compute_dtype": torch.bfloat16 if args.bf16 else torch.float16 + } + else: + raise NotImplementedError(f"Task type {args.task_type} not supported.") + + data_collator = collator_class( + encoder_utils=encoder_utils, + vision_processor=vision_processor, + pad_to_multiple_of=8, + label_pad_token_id=IGNORE_INDEX, + tokenizer=tokenizer, + **kwargs + ) + return data_collator diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py index 68009603c64906f08fddd931ad9b318cf9be6d92..13f265f157f9658712672fb4b7e7a8faf35bc1bf 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/actor.py @@ -5,7 +5,7 @@ import torch.distributed as dist import torch.nn as nn from peft import LoraConfig, TaskType, get_peft_model from peft.tuners.lora import LoraLayer -from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForVision2Seq from transformers.integrations.deepspeed import HfDeepSpeedConfig from .ring_attn_utils import convert_ring_attn_params @@ -45,12 +45,19 @@ class Actor(nn.Module): ds_config=None, device_map=None, packing_samples=False, + create_vison_model=False, + freeze_vision_tower=True, **kwargs, ) -> None: super().__init__() if isinstance(pretrain_or_model, str): - attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" + if use_flash_attention_2 == "fa2": + attn_implementation = "flash_attention_2" + elif use_flash_attention_2 == "sdpa": + attn_implementation = "sdpa" + else: + attn_implementation = "eager" # Note: dschf is defined in function scope to avoid global effects # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration @@ -69,8 +76,12 @@ class Actor(nn.Module): ) else: nf4_config = None - - self.model = AutoModelForCausalLM.from_pretrained( + model_class = None + if create_vison_model: + model_class = AutoModelForVision2Seq + else: + model_class = AutoModelForCausalLM + self.model = model_class.from_pretrained( pretrain_or_model, trust_remote_code=True, attn_implementation=attn_implementation, @@ -78,6 +89,9 @@ class Actor(nn.Module): 
torch_dtype=torch.bfloat16 if bf16 else "auto", device_map=device_map, ) + self.model_type = getattr(self.model.config, "model_type", None) + if create_vison_model: + self.prepare_model(self.model, freeze_vision_tower) # LoRA if lora_rank > 0: @@ -188,9 +202,22 @@ class Actor(nn.Module): return_output=False, ring_attn_group: Optional[dist.ProcessGroup] = None, packed_seq_lens: Optional[list[int]] = None, + **kwargs ) -> torch.Tensor: """Returns action log probs""" - if not self.packing_samples: + if self.model_type == "qwen2_vl": + # Before Transformers version 4.47, when using the Qwen2VL model, + # the position IDs needed to be externally provided in a specific mrope format + # during the forward pass. Therefore, it was decided to consistently pass them + # externally through the model. + position_ids, rope_deltas = self.model.get_rope_index( + input_ids=sequences, + image_grid_thw=kwargs.get("image_grid_thw", None), + video_grid_thw=kwargs.get("video_grid_thw", None), + attention_mask=attention_mask, + ) + kwargs["rope_deltas"] = rope_deltas + elif not self.packing_samples: # https://github.com/OpenRLHF/OpenRLHF/issues/217 position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) @@ -205,7 +232,8 @@ class Actor(nn.Module): # explicitly ignore attention_mask for packing_samples attention_mask = None - output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids) + output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids, + **kwargs) # https://github.com/OpenRLHF/OpenRLHF/pull/634 output["logits"] = output["logits"].to(torch.float32) @@ -240,3 +268,19 @@ class Actor(nn.Module): def print_trainable_parameters(self): self.model.print_trainable_parameters() + + def prepare_model(self, model, freeze_vision_tower): + freeze_modules = set() + if self.model_type == "qwen2_vl": + if freeze_vision_tower: + freeze_modules.add("visual") + elif self.model_type == None: + pass + else: + raise NotImplementedError("TODO: Implement prepare_model for model_type: {}".format(self.model_type)) + + for name, param in model.named_parameters(): + if not any(freeze_mod in name for freeze_mod in freeze_modules): + pass + else: + param.requires_grad_(False) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py index 3d2102dc9403c435c01775b8ca4ab5f74bb39009..a7a71ae5b79ccd230facfcf1bf15983c80216662 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/models/model.py @@ -3,17 +3,20 @@ from typing import Optional, Union import deepspeed import torch import torch.nn as nn -from flash_attn.utils.distributed import all_gather from peft import LoraConfig, get_peft_model from peft.tuners.lora import LoraLayer from transformers import AutoConfig, AutoModel, BitsAndBytesConfig from transformers.integrations.deepspeed import HfDeepSpeedConfig +from transformers.utils import is_flash_attn_2_available from openrlhf.utils.logging_utils import init_logger from .ring_attn_utils import convert_ring_attn_params from .utils import reset_position_ids +if is_flash_attn_2_available(): + from flash_attn.utils.distributed import all_gather + logger = init_logger(__name__) @@ -68,7 +71,12 @@ def get_llm_for_sequence_regression( config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) config.normalize_reward = 
normalize_reward - config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" + if use_flash_attention_2 == "fa2": + config._attn_implementation = "flash_attention_2" + elif use_flash_attention_2 == "sdpa": + config._attn_implementation = "sdpa" + else: + config._attn_implementation = "eager" # Prioritize using the value_head_prefix in the model configuration. value_head_prefix = getattr(config, "value_head_prefix", value_head_prefix) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py index a26d247b6357ce400e56103bc398550717246c02..4eaeca47a900f890033d3a432e99f81997502c84 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/__init__.py @@ -1,10 +1,10 @@ -from .dpo_trainer import DPOTrainer +from .dpo_trainer import DPOTrainer, VLDPOTrainer from .kd_trainer import KDTrainer from .kto_trainer import KTOTrainer from .ppo_trainer import PPOTrainer from .prm_trainer import ProcessRewardModelTrainer from .rm_trainer import RewardModelTrainer -from .sft_trainer import SFTTrainer +from .sft_trainer import SFTTrainer, VLSFTTrainer __all__ = [ "DPOTrainer", @@ -14,4 +14,6 @@ __all__ = [ "ProcessRewardModelTrainer", "RewardModelTrainer", "SFTTrainer", + "VLSFTTrainer", + "VLDPOTrainer", ] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py index 55088cd559de16ac675ef9d9236f882147f4268f..3b8c45bac42fd9f0394bc7d48e49de71148b997c 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/dpo_trainer.py @@ -1,14 +1,20 @@ import os +import time + from abc import ABC import torch -from flash_attn.utils.distributed import all_gather from torch.nn import functional as F from torch.optim import Optimizer from tqdm import tqdm +from transformers.utils import is_flash_attn_2_available from openrlhf.models import DPOLoss from openrlhf.utils.distributed_sampler import DistributedSampler +from openrlhf.utils.vision_utils import IGNORE_INDEX + +if is_flash_attn_2_available(): + from flash_attn.utils.distributed import all_gather class DPOTrainer(ABC): @@ -197,7 +203,6 @@ class DPOTrainer(ABC): logs_dict["nll_loss"] = nll_loss.item() # step bar logs_dict = self.strategy.all_reduce(logs_dict) - step_bar.set_postfix(logs_dict) step_bar.update() # logs/checkpoints/evaluation @@ -223,9 +228,11 @@ class DPOTrainer(ABC): def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): # logs if global_step % args.logging_steps == 0: + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + if self.strategy.is_rank_0(): + step_bar.write(str(logs)) # wandb if self._wandb is not None and self.strategy.is_rank_0(): - logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} self._wandb.log(logs) # TensorBoard elif self._tensorboard is not None and self.strategy.is_rank_0(): @@ -476,3 +483,238 @@ class DPOTrainer(ABC): index = index + seq_len return torch.stack(logprobs_sums), torch.stack(logprobs_means) + + +class VLDPOTrainer(DPOTrainer): + """ + Trainer for Direct Preference Optimization (DPO) training. 
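+
+    This vision-language variant reuses the DPOTrainer loss and logging logic, but
+    overrides the train/eval loops and the forward pass so that the multimodal batch
+    fields (pixel_values, image_grid_thw) are moved to the current device and fed
+    to the actor together with input_ids, attention_mask and labels.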
+ + Args: + model (torch.nn.Module): The primary model to be trained. + ref_model (torch.nn.Module): The reference model for comparing and guiding preference. + strategy (Strategy): The strategy to use for training. + tokenizer (Tokenizer): The tokenizer for processing input data. + optim (Optimizer): The optimizer for training the model. + train_dataloader (DataLoader): The dataloader for the training dataset. + eval_dataloader (DataLoader): The dataloader for the evaluation dataset. + scheduler (Scheduler): The learning rate scheduler to control learning rate during training. + max_norm (float, defaults to 0.5): Maximum gradient norm for gradient clipping. + beta (float, defaults to 0.01): Coefficient for regularizing the preference loss. + max_epochs (int, defaults to 2): Maximum number of training epochs. + """ + def data_to_device(self, input_data): + for key, value in input_data.items(): + input_data[key] = value.to(torch.cuda.current_device()) + return input_data + + def concatenated_forward(self, model, input_ids, attn_masks, labels, pixel_values, image_grid_thw): + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + + output = model(input_ids, + attention_mask=attn_masks, + return_output=True, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + all_logits = output["logits"] + all_logps_sum, all_logps_mean = self._get_batch_logps( + all_logits, labels, attn_masks, None, average_log_prob=False + ) + assert input_ids.shape[0] % 2 == 0 + batch_size = input_ids.shape[0] // 2 + chosen_logps = all_logps_sum[:batch_size] + rejected_logps = all_logps_sum[batch_size:] + aux_loss = output.aux_loss if "aux_loss" in output else [] + return chosen_logps, rejected_logps, aux_loss, -all_logps_mean[: batch_size].mean() + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
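+
+        Note:
+            Labels are first shifted one position against the logits (the leading
+            label is dropped), then positions equal to IGNORE_INDEX (-100) are
+            excluded from both the sum and the mean. For example, shifted labels
+            [-100, 5, 7] contribute only the log-probs of tokens 5 and 7, so the
+            per-sequence mean divides by 2. Both the per-sequence sum and the
+            per-sequence mean are returned.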
+ """ + assert average_log_prob == False + assert logits.shape[:-1] == labels.shape + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_masks = (labels != IGNORE_INDEX) + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == IGNORE_INDEX] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + logprobs_sums = (per_token_logps * loss_masks).sum(-1) + logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + return logprobs_sums, logprobs_means + + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + self.model.train() + self.ref_model.eval() + acc_mean = 0 + loss_mean = 0 + + assert self.strategy.ring_attn_group is None, f"Ring attention is not supported on vision models currently" + + # train + for input_data in self.train_dataloader: + start_time = time.time() + data = self.data_to_device(input_data) + + chosen_logps, rejected_logps, aux_loss, nll_loss = self.concatenated_forward( + self.model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward( + self.ref_model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + + # loss function + preference_loss, chosen_reward, reject_reward = self.loss_fn( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + # mixtral + if not self.aux_loss: + aux_loss = 0 + # nll loss + if not self.nll_loss: + nll_loss = 0 + + loss = preference_loss + aux_loss * self.args.aux_loss_coef + nll_loss * self.args.nll_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + acc = (chosen_reward > reject_reward).float().mean().item() + acc_mean = acc_mean * 0.9 + 0.1 * acc + loss_mean = loss_mean * 0.9 + 0.1 * preference_loss.item() + # dpo logs + logs_dict = { + "loss": preference_loss.item(), + "acc": acc, + "chosen_reward": chosen_reward.mean().item(), + "reject_reward": reject_reward.mean().item(), + "loss_mean": loss_mean, + "acc_mean": acc_mean, + "lr": self.scheduler.get_last_lr()[0], + } + grad_norm = self.model.model.get_global_grad_norm() + if grad_norm is not None: + logs_dict.update({ + "grad_norm": grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) 
else grad_norm}) + + if self.nll_loss: + logs_dict["nll_loss"] = nll_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.update() + end_time = time.time() + step_time = end_time - start_time + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + logs_dict["step_time"] = f"{step_time:.3f}s" + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + def evaluate(self, eval_dataloader, steps=0): + self.model.eval() + with torch.no_grad(): + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of global_step %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + acc_sum = 0 + loss_sum = 0 + times = 0 + for input_data in eval_dataloader: + data = self.data_to_device(input_data) + + chosen_logps, rejected_logps, aux_loss, _ = self.concatenated_forward( + self.model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + with torch.no_grad(): + reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward( + self.ref_model, data["input_ids"], data["attention_mask"], + data["labels"], data["pixel_values"], data["image_grid_thw"], + ) + + loss, chosen_reward, reject_reward = self.loss_fn( + chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + acc_sum += (chosen_reward > reject_reward).float().mean().item() + loss_sum += loss.item() + times += 1 + step_bar.update() + + logs = { + "eval_loss": loss_sum / times, + "acc_mean": acc_sum / times, + } + logs = self.strategy.all_reduce(logs) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() # reset model state + diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py index 9661b0edb0d60591bc780405ff3536805d2ce1d9..1d8bb59c459b112cd23dc87b67ddf16cb1bbeb18 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/ppo_actor.py @@ -82,37 +82,24 @@ class ActorPPOTrainer(PPOTrainer): world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 backend = getattr(self.strategy.args, "vllm_sync_backend", "nccl") - use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) - group_name = "openrlhf" refs = [ engine.init_process_group.remote( master_address, master_port, i * vllm_tensor_parallel_size + 1, world_size, - group_name, + "openrlhf", backend=backend, - use_ray=use_ray, ) for i, engine in enumerate(self.vllm_engines) ] - if use_ray: - import ray.util.collective as collective - collective.init_collective_group( - world_size=world_size, - rank=0, - backend=backend, - group_name=group_name - ) - self._model_update_group = group_name - else: - self._model_update_group 
= init_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - ) + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) ray.get(refs) @@ -149,15 +136,8 @@ class ActorPPOTrainer(PPOTrainer): return self.training_step_actor(experience) def _broadcast_to_vllm(self): - use_prefix_cache = getattr(self.strategy.args, "enable_prefix_caching", False) - cache_reset_refs = [] - if use_prefix_cache and torch.distributed.get_rank() == 0: - # clear prefix cache - for engine in self.vllm_engines: - cache_reset_refs.append(engine.reset_prefix_cache.remote()) # avoid OOM torch.cuda.empty_cache() - use_ray = getattr(self.strategy.args, "vllm_sync_with_ray", False) model = self.actor.model.module count, num_params = 0, len(list(model.named_parameters())) for name, param in model.named_parameters(): @@ -174,14 +154,8 @@ class ActorPPOTrainer(PPOTrainer): # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): if torch.distributed.get_rank() == 0: - if use_ray: - import ray.util.collective as collective - collective.broadcast(param.data, 0, group_name=self._model_update_group) - else: - torch.distributed.broadcast(param.data, 0, group=self._model_update_group) + torch.distributed.broadcast(param.data, 0, group=self._model_update_group) ray.get(refs) - if cache_reset_refs: - ray.get(cache_reset_refs) torch.distributed.barrier() def _save_checkpoint(self, args, tag, client_states): diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py index 889b03424265ce4ded3ae59dcda7b52b2e310147..733c57effb320ef0e3be24f21741aff5fda0eb29 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_engine.py @@ -35,11 +35,7 @@ class LLMRayActor: else: # RayGPUExecutor # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 - if vllm.__version__ >= "0.4.3": - # https://github.com/vllm-project/vllm/commit/676a99982fe9aabe72fd52a91e08988a653a7359 - kwargs["distributed_executor_backend"] = "ray" - else: - kwargs["worker_use_ray"] = True + kwargs["worker_use_ray"] = True if vllm.__version__ > "0.6.4.post1": # https://github.com/vllm-project/vllm/pull/10555 @@ -60,14 +56,14 @@ class LLMRayActor: def generate(self, *args, **kwargs): return self.llm.generate(*args, **kwargs) - def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): if self.use_gpu_executor: return self.llm.llm_engine.model_executor.driver_worker.init_process_group( - master_address, master_port, rank_offset, world_size, group_name, backend, use_ray + master_address, master_port, rank_offset, world_size, group_name, backend ) else: return self.llm.llm_engine.model_executor._run_workers( - "init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend, use_ray + "init_process_group", master_address, master_port, rank_offset, world_size, group_name, 
backend ) def update_weight(self, name, dtype, shape, empty_cache=False): @@ -78,14 +74,6 @@ class LLMRayActor: else: return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache) - def reset_prefix_cache(self): - import vllm - if vllm.__version__ < "0.7.0": - # https://github.com/vllm-project/vllm/commit/7206ce4ce112ed117796a59045c968a6d353f691 - logger.warning("Reset prefix cache API is available only from vLLM 0.7.0!") - return - self.llm.llm_engine.reset_prefix_cache() - def stop_remote_worker_execution_loop(self): # Fix error for using 2 communication group # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4 diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py index 2f324793d08ccecad759d094009d34872ef8812f..730dd12b85cec344e2d4accfe02f7907a7d82c5e 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/ray/vllm_worker_wrap.py @@ -8,30 +8,19 @@ logger = init_logger(__name__) class WorkerWrap(Worker): - def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"): """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), f"default torch process group must be initialized" assert group_name != "", f"group name must not be empty" rank = torch.distributed.get_rank() + rank_offset - if use_ray: - import ray.util.collective as collective - collective.init_collective_group( - world_size=world_size, - rank=rank, - backend=backend, - group_name=group_name - ) - self._model_update_group = group_name - else: - self._model_update_group = init_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=rank, - group_name=group_name, - ) - self._model_update_with_ray = use_ray + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) print( f"init_process_group: master_address={master_address}, master_port={master_port}, ", f"rank={rank}, world_size={world_size}, group_name={group_name}", @@ -44,11 +33,7 @@ class WorkerWrap(Worker): assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" weight = torch.empty(shape, dtype=dtype, device="cuda") - if self._model_update_with_ray: - import ray.util.collective as collective - collective.broadcast(weight, 0, group_name=self._model_update_group) - else: - torch.distributed.broadcast(weight, 0, group=self._model_update_group) + torch.distributed.broadcast(weight, 0, group=self._model_update_group) self.model_runner.model.load_weights(weights=[(name, weight)]) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py index fa92452d5054b3b71c5cc860a1b0ea24ac24ca4c..85231970356f082bd8d9f4e92a8817c955f5e76d 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py 
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/trainer/sft_trainer.py @@ -1,5 +1,8 @@ +import time import os + from abc import ABC +from typing_extensions import override import torch from torch.optim import Optimizer @@ -165,18 +168,10 @@ class SFTTrainer(ABC): if not self.pretrain_mode: if self.packing_samples: - # As response_ranges need to constrain the dataset organization strictly, we handle multiturn feature separately. - if infos["response_ranges"]: - dump_labels = torch.full(labels.size(), self.loss_fn.IGNORE_INDEX).to(labels.device) - for response_ranges in infos["response_ranges"]: - for response_range in response_ranges: - dump_labels[0][response_range[0]: response_range[1]] = labels[0][response_range[0]: response_range[1]] - labels = dump_labels - else: - index = 0 - for input_length, source_len in zip(infos["input_length"], prompt_id_lens): - labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX - index += input_length + index = 0 + for input_length, source_len in zip(infos["input_length"], prompt_id_lens): + labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX + index += input_length else: for label, source_len in zip(labels, prompt_id_lens): label[:source_len] = self.loss_fn.IGNORE_INDEX @@ -195,7 +190,6 @@ class SFTTrainer(ABC): logs_dict["aux_loss"] = aux_loss.item() # step bar logs_dict = self.strategy.all_reduce(logs_dict) - step_bar.set_postfix(logs_dict) step_bar.update() # logs/checkpoints/evaluation @@ -218,9 +212,11 @@ class SFTTrainer(ABC): # logs/checkpoints/evaluation def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}): if global_step % args.logging_steps == 0: + logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} + if self.strategy.is_rank_0(): + step_bar.write(str(logs)) # wandb if self._wandb is not None and self.strategy.is_rank_0(): - logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()} self._wandb.log(logs) # TensorBoard elif self._tensorboard is not None and self.strategy.is_rank_0(): @@ -284,17 +280,10 @@ class SFTTrainer(ABC): if not self.pretrain_mode: if self.packing_samples: - if infos["response_ranges"]: - dump_labels = torch.full(labels.size(), self.loss_fn.IGNORE_INDEX).to(labels.device) - for response_ranges in infos["response_ranges"]: - for response_range in response_ranges: - dump_labels[0][response_range[0]: response_range[1]] = labels[0][response_range[0]: response_range[1]] - labels = dump_labels - else: - index = 0 - for input_length, source_len in zip(infos["input_length"], prompt_id_lens): - labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX - index += input_length + index = 0 + for input_length, source_len in zip(infos["input_length"], prompt_id_lens): + labels[0][index : index + source_len] = self.loss_fn.IGNORE_INDEX + index += input_length else: for label, source_len in zip(labels, prompt_id_lens): label[:source_len] = self.loss_fn.IGNORE_INDEX @@ -316,3 +305,147 @@ class SFTTrainer(ABC): for k, v in logs.items(): self._tensorboard.add_scalar(f"eval/{k}", v, steps) self.model.train() # reset model state + + +class VLSFTTrainer(SFTTrainer): + def data_to_device(self, input_data): + for key, value in input_data.items(): + input_data[key] = value.to(torch.cuda.current_device()) + return input_data + + @override + def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): + # get eval and save steps + if args.eval_steps == -1: + args.eval_steps = 
num_update_steps_per_epoch # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # do not save ckpt + + # Restore step and start_epoch + step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1 + start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch + consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size) + + epoch_bar = tqdm( + range(start_epoch, self.epochs), + desc="Train epoch", + disable=not self.strategy.is_rank_0(), + ) + + for epoch in range(start_epoch, self.epochs): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch( + epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples + ) + + step_bar = tqdm( + range(self.train_dataloader.__len__()), + desc="Train step of epoch %d" % epoch, + disable=not self.strategy.is_rank_0(), + ) + + # train + self.model.train() + loss_mean = 0 + + assert self.strategy.ring_attn_group is None, f"Ring attention is not supported on vision models currently" + + for input_data in self.train_dataloader: + start_time = time.time() + data = self.data_to_device(input_data) + labels = data["labels"] + + output = self.model( + data["input_ids"], + attention_mask=data["attention_mask"], + return_output=True, + pixel_values=data["pixel_values"], + image_grid_thw=data["image_grid_thw"], + ) + + # mixtral + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + + gpt_loss = self.loss_fn(output.logits, labels) + loss = gpt_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(loss, self.model, self.optimizer) + self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler) + + loss_mean = loss_mean * 0.9 + 0.1 * gpt_loss.item() + logs_dict = { + "gpt_loss": gpt_loss.item(), + "loss_mean": loss_mean, + "lr": self.scheduler.get_last_lr()[0], + } + grad_norm = self.model.model.get_global_grad_norm() + if grad_norm is not None: + logs_dict.update({ + "grad_norm": grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm}) + + if self.aux_loss: + logs_dict["aux_loss"] = aux_loss.item() + # step bar + logs_dict = self.strategy.all_reduce(logs_dict) + step_bar.update() + end_time = time.time() + step_time = end_time - start_time + + # logs/checkpoints/evaluation + if step % self.strategy.accumulated_gradient == 0: + global_step = step // self.strategy.accumulated_gradient + client_states = {"consumed_samples": global_step * args.train_batch_size} + logs_dict["step_time"] = f"{step_time:.3f}s" + self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states) + + step += 1 + + epoch_bar.update() + + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() + + def evaluate(self, eval_dataloader, steps=0): + times = 0 + self.model.eval() + with torch.no_grad(): + loss_sum = 0 + step_bar = tqdm( + range(eval_dataloader.__len__()), + desc="Eval stage of steps %d" % steps, + disable=not self.strategy.is_rank_0(), + ) + + for input_data in eval_dataloader: + data = self.data_to_device(input_data) + labels = data["labels"] + + output = self.model( + data["input_ids"], + attention_mask=data["attention_mask"], + return_output=True, + pixel_values=data["pixel_values"], + image_grid_thw=data["image_grid_thw"], + ) + + loss = self.loss_fn(output.logits, labels) + + times += 1 + 
loss_sum += loss.item() + bar_dict = {"eval gpt_loss": loss_sum / times} + step_bar.update() + logs = self.strategy.all_reduce(bar_dict) + step_bar.set_postfix(logs) + + if self.strategy.is_rank_0(): + if self._wandb is not None: + logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()} + self._wandb.log(logs) + elif self._tensorboard is not None: + for k, v in logs.items(): + self._tensorboard.add_scalar(f"eval/{k}", v, steps) + self.model.train() # reset model state diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py index 08ab0a9ba91d8bd770b3694d3cdce94ae0836718..e69a0696ace71b679ad7e6e877253621514e6b12 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/__init__.py @@ -1,5 +1,7 @@ from .processor import get_processor, reward_normalization from .utils import blending_datasets, get_strategy, get_tokenizer +from .vision_args import add_vision_args +from .vision_utils import get_qwen2_vl_utils, get_vision_processor __all__ = [ "get_processor", @@ -7,4 +9,7 @@ __all__ = [ "blending_datasets", "get_strategy", "get_tokenizer", + "get_vision_processor", + "get_qwen2_vl_utils", + "add_vision_args", ] diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py index a69b13ece3d34a974fef364d30af700ad895668b..71c25d77c34c47f05adb53956f3fad460ba134e9 100644 --- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/utils.py @@ -71,10 +71,6 @@ def blending_datasets( ext = "json" data = load_dataset(ext, data_files=dataset) strategy.print(f"loaded {dataset} with data_files={dataset}") - # local dataset saved with `datasets.Dataset.save_to_disk` - elif os.path.isdir(dataset): - data = load_from_disk(dataset) - strategy.print(f"loaded {dataset} from disk") # remote/local folder or common file else: data = load_dataset(dataset, data_dir=data_dir) diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py new file mode 100644 index 0000000000000000000000000000000000000000..be1acde5810653270d0b9bdc01c8ab854b192eae --- /dev/null +++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_args.py @@ -0,0 +1,41 @@ +def add_vision_args(parser): + group = parser.add_argument_group(title='vision args') + group.add_argument("--task_type", type=str, + default="sft", + choices=["sft", "dpo"], + help="task type") + group.add_argument("--model_arch", type=str, choices=["qwen2_vl"], + help="model arch",) + group.add_argument("--dataset_config_path", type=str, default=None, + help="the dataset config") + + group.add_argument("--image_resolution", type=int, default=512 * 512, + help="The number of pixels of image below this resolution.") + group.add_argument("--video_resolution", type=int, default=128 * 128, + help="The number of pixels of video below this resolution.") + group.add_argument("--video_fps", type=float, default=2.0, + help="The frames to sample per second for video inputs.") + group.add_argument("--video_maxlen", type=int, default=64, + help="The maximum number of sampled frames for video inputs.") + + group.add_argument("--efficient_eos", type=bool, 
default=False,
+                       help="the efficient_eos flag of VisionEncoderUtils")
+    group.add_argument("--processing_num_workers", type=int, default=18,
+                       help="number of worker processes used for dataset preprocessing")
+    group.add_argument("--train_on_prompt", type=bool, default=False,
+                       help="also compute the loss on prompt tokens (train_on_prompt)")
+    group.add_argument("--mask_history", type=bool, default=False,
+                       help="mask history turns and train only on the last turn (mask_history)")
+    group.add_argument("--overwrite_cache", type=bool, default=True,
+                       help="overwrite the cached preprocessed dataset")
+    group.add_argument("--local_process_index", type=int, default=0,
+                       help="local process index used during preprocessing")
+    group.add_argument("--preprocessing_batch_size", type=int, default=1000,
+                       help="batch size used for dataset preprocessing")
+    group.add_argument("--neat_packing", action="store_true",
+                       help="enable sequence packing without cross-attention.")
+
+    group.add_argument("--freeze_vision_tower", type=bool, default=True,
+                       help="Whether or not to freeze the vision tower during training. default: True")
+
+    return parser
diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c4e18fc907fc6c8a4a003e37f77ccd09d045eab
--- /dev/null
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/openrlhf/utils/vision_utils.py
@@ -0,0 +1,703 @@
+import json
+import math
+import os
+import re
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from dataclasses import dataclass, field
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from PIL.Image import Image as ImageObject
+from transformers import AutoConfig, AutoProcessor
+from typing_extensions import override
+
+IGNORE_INDEX = -100
+ImageInput = Union[str, bytes, ImageObject]
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "")
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "