From 5f16caa3e75a56db49d3f843e36180825e42a875 Mon Sep 17 00:00:00 2001 From: liutongtong27 Date: Tue, 27 May 2025 20:12:28 +0800 Subject: [PATCH 1/2] add profiler for verl --- .../rl/VeRL_for_PyTorch/docs/profiler.md | 123 ++++++++++++++++++ .../verl/trainer/config/ppo_trainer.yaml | 4 + .../config/profiler_config/profiler.yaml | 17 +++ .../verl/trainer/ppo/core_algos.py | 4 + .../verl/trainer/ppo/ray_trainer.py | 2 + .../VeRL_for_PyTorch/verl/utils/profiler.py | 113 ++++++++++++++++ .../verl/workers/actor/dp_actor.py | 5 + .../verl/workers/fsdp_workers.py | 29 ++++- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 3 + 9 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md create mode 100644 PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/profiler_config/profiler.yaml create mode 100644 PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md b/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md new file mode 100644 index 0000000000..e61cb2af49 --- /dev/null +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md @@ -0,0 +1,123 @@ +## 概述 + +性能调优模块为强化学习训练流程提供了性能数据采集、分析能力,可帮助用户识别训练过程中的性能瓶颈并进行优化。 + +## 配置选项 + +性能调优工具通过`verl/trainer/config/profiler_config/profiler.yaml`配置文件中的 `profiler_config` 部分进行配置: + +```yaml +actor_rollout_ref: + profile: false + mstx: false + role: actor_rollout_ref + stage: all + profile_save_path: ./profiler_data + profile_export_type: text + profile_step_start: 1 + profile_step_end: 2 + profile_level: level1 + profile_with_memory: false + profile_record_shapes: false + profile_with_npu: true + profile_with_cpu: true + profile_with_module: false + profile_analysis: false + profile_ranks: all +``` + +### 主要配置参数说明 + +| 参数 | 说明 | 可选值 | +|-----------------------|---------------|-------------------------------------------------------------------------------------------| +| profile | 性能分析开关 | 
true/false,所有性能数据采集均依赖该开关开启 | +| mstx | 轻量级打点采集模式 | true/false,启用/关闭轻量级打点采集 | +| role | 角色 | 不需要修改 | +| stage | 性能数据采集阶段 | all(表示采集所有阶段)、 actor_generate, actor_compute_log_prob, ref_compute_log_prob, actor_update | +| profile_save_path | 性能数据输出目录 | 任意有效路径,默认为"./profiler_data" | +| profile_export_type | 导出格式 | text、db (可减少约70%磁盘空间),默认值text | +| profile_step_start | 开启采集数据的步骤 | 任意正整数,默认为1,profile_step_start从1开始 | +| profile_step_end | 结束采集数据的步骤 | 任意正整数,默认为2,实际采集步数为 profile_step_end-profile_step_start,不包含profile_step_end | +| profile_level | 采集级别 | level_none、level0、level1、level2,默认值level0 | +| profile_with_memory | 内存分析开关 | true/false,默认值false,启用/关闭内存分析 | +| profile_record_shapes | 张量形状记录开关 | true/false,默认值false,是否记录张量形状 | +| profile_with_cpu | Host侧性能数据开关 | true/false,默认值true,是否包含Host侧性能数据 | +| profile_with_npu | Device侧性能数据开关 | true/false,默认值true,是否包含NPU侧性能数据 | +| profile_with_module | Python调用栈信息开关 | true/false,默认值false,是否包含Python侧调用栈信息 | +| profile_analysis | 自动解析开关 | true/false,默认值false,是否在采集后自动解析数据 | +| profile_ranks | 采集数据的卡号 | all表示所有rank, 默认值all,可以通过列表指定,如[0, 1] | + +## 性能数据采集 + +### 1. 按训练阶段分段采集 + +- **关键配置**: + ```yaml + profile: true + mstx: false + profile_level: level1 + stage: actor_generate + ``` + + - `stage`参数可选值: + - all + - actor_generate + - actor_compute_log_prob + - ref_compute_log_prob + - actor_update + +- **适用场景**: 需要查看训练某一特定阶段的详细计算、通信profiling数据 + +### 2.
使用轻量化采集模式 + +- **关键配置**: + ```yaml + profile: true + mstx: true + profile_level: level_none + profile_with_cpu: false + profile_with_npu: true + ``` + +- **适用场景**: 目前已集成ActorRolloutRefWorker的update_actor、generate_sequences等关键函数打点。如需查看某代码片段在timeline中的执行耗时,可通过以下两种方式在代码中添加自定义打点: + + ```python + # 方式一:使用装饰器装饰函数 + from verl.utils.profiler import mstx_timer_decorator + + @mstx_timer_decorator + def your_function(): + # 函数代码 + pass + + # 方式二:框住代码片段 + import torch_npu + + id = torch_npu.npu.mstx.range_start("your_tag_name") + result = complex_operation() # 需要记录打点时间片的代码 + torch_npu.npu.mstx.range_end(id) + ``` + +## 性能数据解析 + +性能数据采集后需要进行解析才能查看,可通过以下两种方式: + +### 1. 离线解析 + +适用于大规模集群,使用如下脚本对性能数据进行解析。 + +```python +import torch_npu +# 在性能数据采集完成后,可对所有性能数据执行离线解析("./profiler_data"可包含多份性能数据,解析可并行进行) +torch_npu.profiler.profiler.analyse(profiler_path="./profiler_data") +``` + +### 2. 在线解析 + +设置`profile_analysis=true`在采集后自动解析。注意:当性能数据量较大时,解析时间可能较长。 + +## 结果可视化 + +解析后的性能数据保存在`profile_save_path`指定目录中,可通过以下工具进行可视化: + +- **MindStudio Insight**:提供丰富的性能分析视图,包括时间线视图、算子分析、通信分析。详细使用指南可参考[MindStudio文档](https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/index/index.html) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml index 59e9b9db57..10000a58b2 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/ppo_trainer.yaml @@ -1,3 +1,7 @@ +defaults: + - _self_ + - profiler_config/profiler + data: tokenizer: null train_files: ~/data/rlhf/gsm8k/train.parquet diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/profiler_config/profiler.yaml b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/profiler_config/profiler.yaml new file mode 100644 index 0000000000..1d3571c8bc --- /dev/null +++ 
b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/config/profiler_config/profiler.yaml @@ -0,0 +1,17 @@ +actor_rollout_ref: + profile: false + mstx: false + role: actor_rollout_ref + stage: actor_update + profile_save_path: ./profiler_data + profile_export_type: text + profile_step_start: 1 + profile_step_end: 2 + profile_level: level1 + profile_with_memory: false + profile_record_shapes: false + profile_with_npu: true + profile_with_cpu: true + profile_with_module: false + profile_analysis: false + profile_ranks: all \ No newline at end of file diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/core_algos.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/core_algos.py index 1a31621077..5ec8844df3 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/core_algos.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/core_algos.py @@ -23,6 +23,7 @@ import torch from collections import defaultdict import verl.utils.torch_functional as verl_F +from verl.utils.profiler import mstx_timer_decorator class AdaptiveKLController: @@ -306,6 +307,7 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): return token_level_scores - kl * kl_ratio +@mstx_timer_decorator def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): """ Aggregate the loss matrix into a scalar. @@ -334,6 +336,7 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str return loss +@mstx_timer_decorator def compute_policy_loss(old_log_prob, log_prob, advantages, @@ -447,6 +450,7 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value): return vf_loss, vf_clipfrac +@mstx_timer_decorator def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: """Compute KL divergence given logprob and ref_logprob. 
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py index 0f28f31315..4462bd0b76 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/ppo/ray_trainer.py @@ -620,6 +620,7 @@ class RayPPOTrainer(object): resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.actor_rollout_ref, + profiler_config=self.config.profiler_config, role='actor_rollout') self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls else: @@ -636,6 +637,7 @@ class RayPPOTrainer(object): resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, + profiler_config=self.config.profiler_config, role='ref') self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py new file mode 100644 index 0000000000..b12cf761e4 --- /dev/null +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py @@ -0,0 +1,113 @@ +import os +from functools import wraps + +import torch +import torch_npu + +def get_grpo_profiler(profiler_config, role: str = None): + if not profiler_config or not profiler_config.profile: + return None + + profiler_this_rank = False + if profiler_config.profile_ranks == "all": + profiler_this_rank = True + else: + try: + ranks = list(profiler_config.profile_ranks) + except (TypeError, AttributeError): + ranks = [0] + if (torch.distributed.get_rank() in ranks): + profiler_this_rank = True + if not 
profiler_this_rank: + return None + + if profiler_config.profile_level == 'level_none': + profiler_level = torch_npu.profiler.ProfilerLevel.Level_none + elif profiler_config.profile_level == 'level0': + profiler_level = torch_npu.profiler.ProfilerLevel.Level0 + elif profiler_config.profile_level == 'level1': + profiler_level = torch_npu.profiler.ProfilerLevel.Level1 + elif profiler_config.profile_level == 'level2': + profiler_level = torch_npu.profiler.ProfilerLevel.Level2 + else: + raise ValueError(f"profiler_level only supports level0," + f" 1, 2, and level_none, but gets {profiler_config.profile_level}") + + if profiler_config.profile_export_type == 'text': + profile_export_type = torch_npu.profiler.ExportType.Text + elif profiler_config.profile_export_type == 'db': + profile_export_type = torch_npu.profiler.ExportType.Db + else: + raise ValueError(f"profile_export_type only supports text or db," + f" but gets {profiler_config.profile_export_type}") + + base_path = profiler_config.profile_save_path + if role: + profile_save_path = os.path.join(base_path, role) + else: + profile_save_path = base_path + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=profiler_level, + export_type=profile_export_type, + data_simplification=True, + msprof_tx=profiler_config.mstx + ) + if profiler_config.stage == "all": + skip_first = profiler_config.profile_step_start + active = profiler_config.profile_step_end - profiler_config.profile_step_start + else: + skip_first = 0 + active = 1 + + activites = [] + if profiler_config.profile_with_npu: + activites.append(torch_npu.profiler.ProfilerActivity.NPU) + if profiler_config.profile_with_cpu: + activites.append(torch_npu.profiler.ProfilerActivity.CPU) + + prof = torch_npu.profiler.profile( + with_modules=profiler_config.profile_with_module, + record_shapes=profiler_config.profile_record_shapes, + profile_memory=profiler_config.profile_with_memory, +
activities=activites, + schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=active, repeat=1, skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, + analyse_flag=profiler_config.profile_analysis), + experimental_config=experimental_config) + + return prof + + +def mstx_timer_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + range_id = torch_npu.npu.mstx.range_start(func.__qualname__) + result = func(*args, **kwargs) + torch_npu.npu.mstx.range_end(range_id) + return result + return wrapper + + +def profiler_start(profiler_config, role="profiler_data", profiler_iteration=None): + if not profiler_config: + return None + if profiler_iteration is not None and ( + profiler_iteration < profiler_config.profile_step_start or + profiler_iteration >= profiler_config.profile_step_end): + return None + if profiler_config.stage == "all" and role != profiler_config.role: + return None + if profiler_config.stage != "all" and role != profiler_config.stage: + return None + profiler = get_grpo_profiler(profiler_config, role) + if not profiler: + return None + profiler.start() + return profiler + + +def profiler_step(profiler): + if profiler: + profiler.step() diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py index 798855ac3f..4decfd0ebc 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/actor/dp_actor.py @@ -29,6 +29,7 @@ from verl.utils.py_functional import append_to_dict from verl.utils.torch_functional import logprobs_from_logits, masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +from verl.utils.profiler import mstx_timer_decorator import verl.utils.torch_functional as verl_F from 
verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available @@ -69,6 +70,7 @@ class DataParallelPPOActor(BasePPOActor): else entropy_from_logits) self.device_name = get_device_name() + @mstx_timer_decorator def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns: @@ -172,6 +174,7 @@ class DataParallelPPOActor(BasePPOActor): return entropy, log_probs + @mstx_timer_decorator def _optimizer_step(self): assert self.config.grad_clip is not None @@ -188,6 +191,7 @@ class DataParallelPPOActor(BasePPOActor): self.actor_optimizer.step() return grad_norm + @mstx_timer_decorator def compute_log_prob(self, data: DataProto) -> torch.Tensor: """Compute the log probability of the responses given input_ids, attention_mask and position_ids @@ -246,6 +250,7 @@ class DataParallelPPOActor(BasePPOActor): return log_probs + @mstx_timer_decorator def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py index 2ca6556abe..0babbc6bea 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/fsdp_workers.py @@ -40,6 +40,7 @@ from verl.utils.flops_counter import FlopsCounter from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from verl.utils.device import get_device_name, is_cuda_available, get_torch_device, is_npu_available +from verl.utils.profiler import mstx_timer_decorator, profiler_start, profiler_step from codetiming import Timer @@ -76,9 +77,10 @@ class ActorRolloutRefWorker(Worker): or a hybrid engine based on the config.rollout """ - def __init__(self, config: DictConfig, role: str): + def __init__(self, config: DictConfig, 
profiler_config: DictConfig, role: str): super().__init__() self.config = config + self.profiler_config = profiler_config import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") @@ -141,6 +143,9 @@ class ActorRolloutRefWorker(Worker): self.ulysses_sequence_parallel_size) self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + self.actor_rollout_ref_profiler = profiler_start(self.profiler_config.actor_rollout_ref, "actor_rollout_ref") + self.prof_iteration = 1 + def _build_model_optimizer(self, model_path, fsdp_config, @@ -447,9 +452,13 @@ class ActorRolloutRefWorker(Worker): processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.actor.checkpoint.contents) + @mstx_timer_decorator @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares + update_actor_profiler = profiler_start(self.profiler_config.actor_rollout_ref, role="actor_update", + profiler_iteration=self.prof_iteration) + data = data.to(get_torch_device().current_device()) assert self._is_actor @@ -491,11 +500,18 @@ class ActorRolloutRefWorker(Worker): if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) + profiler_step(update_actor_profiler) + profiler_step(self.actor_rollout_ref_profiler) + self.prof_iteration += 1 + return output + @mstx_timer_decorator @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): # Support all hardwares + generate_sequences_profiler = profiler_start(self.profiler_config.actor_rollout_ref, role="actor_generate", + profiler_iteration=self.prof_iteration) prompts = prompts.to(get_torch_device().current_device()) assert self._is_rollout @@ -531,11 +547,15 @@ class ActorRolloutRefWorker(Worker): # clear kv cache log_gpu_memory_usage('After 
generate_sequences', logger=logger) + profiler_step(generate_sequences_profiler) return output + @mstx_timer_decorator @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_log_prob(self, data: DataProto): assert self._is_actor + compute_log_prob_profiler = profiler_start(self.profiler_config.actor_rollout_ref, role="actor_compute_log_prob", + profiler_iteration=self.prof_iteration) if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) @@ -565,12 +585,17 @@ class ActorRolloutRefWorker(Worker): offload_fsdp_model_to_cpu(self.actor_module_fsdp) log_gpu_memory_usage('After compute_log_prob', logger=logger) + profiler_step(compute_log_prob_profiler) return output + @mstx_timer_decorator @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): assert self._is_ref + compute_ref_log_prob_profiler = profiler_start(self.profiler_config.actor_rollout_ref, + role="ref_compute_log_prob", + profiler_iteration=self.prof_iteration) # Support all hardwares data = data.to(get_torch_device().current_device()) @@ -592,6 +617,8 @@ class ActorRolloutRefWorker(Worker): if self.world_size > 1: self.ref_policy.actor_module._handle.reshard(True) + profiler_step(compute_ref_log_prob_profiler) + self.prof_iteration += 1 return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 0b5fbe3d37..fb18edbe24 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -39,6 +39,7 @@ from verl.workers.rollout.base import BaseRollout from vllm.distributed import parallel_state as vllm_ps from vllm import LLM, SamplingParams from verl.third_party.vllm import vllm_version +from 
verl.utils.profiler import mstx_timer_decorator # TODO # 1. support pp in vllm @@ -151,6 +152,7 @@ class vLLMRollout(BaseRollout): self.pad_token_id = tokenizer.pad_token_id + @mstx_timer_decorator @contextmanager def update_sampling_params(self, **kwargs): # update sampling params @@ -167,6 +169,7 @@ class vLLMRollout(BaseRollout): for key, value in old_sampling_params_args.items(): setattr(self.sampling_params, key, value) + @mstx_timer_decorator @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # rebuild vllm cache engine -- Gitee From 32de33afa67d1d9f77ca65ad654e5bcf046e4361 Mon Sep 17 00:00:00 2001 From: liutongtong27 Date: Tue, 10 Jun 2025 22:02:52 +0800 Subject: [PATCH 2/2] add profiler for verl --- .../rl/VeRL_for_PyTorch/docs/profiler.md | 2 +- .../rl/VeRL_for_PyTorch/verl/utils/profiler.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md b/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md index e61cb2af49..8100df9e39 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/docs/profiler.md @@ -37,7 +37,7 @@ actor_rollout_ref: | profile_save_path | 性能数据输出目录 | 任意有效路径,默认为"./profiler_data" | | profile_export_type | 导出格式 | text、db (可减少约70%磁盘空间),默认值text | | profile_step_start | 开启采集数据的步骤 | 任意正整数,默认为1,profile_step_start从1开始 | -| profile_step_end | 结束采集数据的步骤 | 任意正整数,默认为2,实际采集步数为 profile_step_end-profile_step_start,不包含profile_step_end | +| profile_step_end | 结束采集数据的步骤 | 任意正整数,默认为2,实际采集步数不包含profile_step_end,采集总步数为profile_step_end-profile_step_start | | profile_level | 采集级别 | level_none、level0、level1、level2,默认值level0 | | profile_with_memory | 内存分析开关 | true/false,默认值false,启用/关闭内存分析 | | profile_record_shapes | 张量形状记录开关 | true/false,默认值false,是否记录张量形状 | diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py 
index b12cf761e4..763df78da8 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/profiler.py @@ -1,9 +1,25 @@ +# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied from https://gitee.com/ascend/MindSpeed-RL/blob/master/mindspeed_rl/utils/utils.py import os from functools import wraps import torch import torch_npu + def get_grpo_profiler(profiler_config, role: str = None): if not profiler_config or not profiler_config.profile: return None -- Gitee