diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/config/dapo_trainer.yaml b/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/config/dapo_trainer.yaml
index bcafe90f69122a9ef64cf37ebf02b9929d64d587..d06566950b257ef8f4202eff2b0ecde1c7360101 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/config/dapo_trainer.yaml
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/config/dapo_trainer.yaml
@@ -1,3 +1,11 @@
+hydra:
+  searchpath:
+    - file://verl/trainer/config
+
+defaults:
+  - ppo_trainer
+  - _self_
+
 data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/runtime_env_32b.yaml b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/runtime_env_32b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3d76cfb0a5d1212b4b68e9dbae03ff7d6012063
--- /dev/null
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/runtime_env_32b.yaml
@@ -0,0 +1,5 @@
+working_dir: ./
+excludes: ["/.git/"]
+env_vars:
+  HCCL_CONNECT_TIMEOUT: "1500"
+  HCCL_EXEC_TIMEOUT: "1500"
\ No newline at end of file
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_DAPO_performance_32p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_DAPO_performance_32p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8df492ae9708b0192206f40da77e0b49e9048b0e
--- /dev/null
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_32b_instruct_DAPO_performance_32p.sh
@@ -0,0 +1,143 @@
+#!/usr/bin/env bash
+set -euxo pipefail
+
+project_name='DAPO'
+exp_name='DAPO-Qwen2.5-32B-Instruct'
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 20))
+enable_overlong_buffer=True
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=16
+gen_prompt_bsz=$((train_prompt_bsz * 2))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=1
+
+# Ray
+RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/test/runtime_env_32b.yaml"}
+NNODES=${NNODES:-2}
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/Qwen2.5-32B-Instruct"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/DAPO-Math-17k/data/dapo-math-17k.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/AIME-2024/data/aime-2024.parquet"}
+
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Performance Related Parameter
+sp_size=8
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
+offload=True
+gen_tp=4
+
+ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
+    --working-dir "${WORKING_DIR}" \
+    -- python3 -m recipe.dapo.src.main_dapo \
+    data.train_files="${TRAIN_FILE}" \
+    data.val_files="${TEST_FILE}" \
+    data.prompt_key=prompt \
+    data.truncation='left' \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.gen_batch_size=${gen_prompt_bsz} \
+    data.train_batch_size=${train_prompt_bsz} \
+    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    algorithm.filter_groups.enable=${enable_filter_groups} \
+    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+    algorithm.filter_groups.metric=${filter_groups_metric} \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.model.path="${MODEL_PATH}" \
+    +actor_rollout_ref.model.override_config.attention_dropout=0. \
+    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+    +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.50 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+    actor_rollout_ref.rollout.enable_chunked_prefill=True \
+    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.rollout.temperature=${temperature} \
+    actor_rollout_ref.rollout.top_p=${top_p} \
+    actor_rollout_ref.rollout.top_k="${top_k}" \
+    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+    reward_model.reward_manager=dapo \
+    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+    reward_model.overlong_buffer.len=${overlong_buffer_len} \
+    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+    trainer.logger=['console'] \
+    trainer.project_name="${project_name}" \
+    trainer.experiment_name="${exp_name}" \
+    trainer.n_gpus_per_node=16 \
+    trainer.nnodes="${NNODES}" \
+    trainer.val_before_train=False \
+    trainer.test_freq=5 \
+    trainer.save_freq=-1 \
+    trainer.total_epochs=1 \
+    trainer.default_local_dir="${CKPTS_DIR}" \
+    trainer.resume_mode=auto \
+    data.shuffle=False \
+    actor_rollout_ref.actor.use_torch_compile=False \
+    actor_rollout_ref.ref.use_torch_compile=False \
+    actor_rollout_ref.actor.entropy_checkpointing=True \
+    actor_rollout_ref.ref.entropy_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \
+    actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \
+    actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \
+    actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \
+    actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True
\ No newline at end of file
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_DAPO_performance_16p.sh b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_DAPO_performance_16p.sh
index 62e43355f465824cb4f010724f65cbaa80b5a617..6376ac00931e707d2988a95e2400034141eea59f 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_DAPO_performance_16p.sh
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/test/train_qwen2_5_7b_instruct_DAPO_performance_16p.sh
@@ -132,12 +132,13 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
     trainer.resume_mode=auto \
     data.shuffle=False \
     actor_rollout_ref.actor.use_torch_compile=False \
-    +actor_rollout_ref.ref.use_torch_compile=False \
-    +actor_rollout_ref.actor.entropy_checkpointing=True \
-    +actor_rollout_ref.ref.entropy_checkpointing=True \
-    +actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
-    +actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \
-    +actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \
-    +actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \
-    +actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \
-    +actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True
\ No newline at end of file
+    actor_rollout_ref.ref.use_torch_compile=False \
+    actor_rollout_ref.actor.entropy_checkpointing=True \
+    actor_rollout_ref.ref.entropy_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \
+    actor_rollout_ref.actor.fsdp_config.backward_prefetch=BACKWARD_PRE \
+    actor_rollout_ref.ref.fsdp_config.backward_prefetch=BACKWARD_PRE \
+    actor_rollout_ref.actor.use_entropy_from_logits_with_chunking=True \
+    actor_rollout_ref.ref.use_entropy_from_logits_with_chunking=True \
+    actor_rollout_ref.rollout.seed=1234
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml
index 59e9b9db570856ec34a2692c913ee845a14c20d3..85d5978759f7a8f853dc91721d627cf71f897882 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml
@@ -63,9 +63,15 @@ actor_rollout_ref:
       param_offload: False
       optimizer_offload: False
       fsdp_size: -1
+      forward_prefetch: False
+      backward_prefetch: None
+    entropy_checkpointing: False
+    use_entropy_from_logits_with_chunking: False
   ref:
     fsdp_config:
       param_offload: False
+      forward_prefetch: False
+      backward_prefetch: None
       wrap_policy:
         # transformer_layer_cls_to_wrap: None
         min_num_params: 0
@@ -75,6 +81,8 @@ actor_rollout_ref:
     log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
     log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
     ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+    entropy_checkpointing: False
+    use_entropy_from_logits_with_chunking: False
   rollout:
     name: vllm
     temperature: 1.0
@@ -104,6 +112,7 @@ actor_rollout_ref:
     do_sample: True
     # number of responses (i.e. num sample times)
     n: 1 # > 1 for grpo
+    seed: 0
    val_kwargs:
      # sampling parameters for validation
      top_k: -1 # 0 for hf rollout, -1 for vllm rollout
diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index 0b5fbe3d372d6c05a1aae69b1d609bab5f79e3fa..f3584de66a737708cc5b42b471aa495e5a744643 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -126,6 +126,7 @@ class vLLMRollout(BaseRollout):
             enable_chunked_prefill=config.enable_chunked_prefill,
             enable_prefix_caching=True,
             trust_remote_code=trust_remote_code,
+            seed=config.seed
         )
 
         # Offload vllm model to reduce peak memory usage