From 20911289bdedc8b9d0b11a80fa06a3eb8637bebf Mon Sep 17 00:00:00 2001 From: sunboquan Date: Tue, 17 Jun 2025 20:40:13 +0800 Subject: [PATCH] ascend adapter --- .../.github/workflows/e2e_ascend.yml | 51 +++++++++++++++++-- .../recipe/dapo/src/main_dapo.py | 1 + .../rl/VeRL_for_PyTorch/requirements-npu.txt | 10 ++-- .../rl/VeRL_for_PyTorch/verl/__init__.py | 6 ++- .../verl/models/transformers/qwen2_vl.py | 1 + .../rl/VeRL_for_PyTorch/verl/protocol.py | 1 + .../verl/single_controller/base/worker.py | 1 + .../verl/trainer/main_generation.py | 4 +- .../utils/checkpoint/checkpoint_manager.py | 12 ++++- .../verl/workers/rollout/hf_rollout.py | 4 +- 10 files changed, 76 insertions(+), 15 deletions(-) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/.github/workflows/e2e_ascend.yml b/PyTorch/built-in/rl/VeRL_for_PyTorch/.github/workflows/e2e_ascend.yml index 24b7f06c18..fc10ff7c04 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/.github/workflows/e2e_ascend.yml +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/.github/workflows/e2e_ascend.yml @@ -17,11 +17,11 @@ jobs: test: name: verl Ascend test (self-host) runs-on: [self-hosted, npu-0] - timeout-minutes: 5 # Increase this timeout value as needed + timeout-minutes: 30 # Increase this timeout value as needed env: HF_HUB_ENABLE_HF_TRANSFER: 1 container: - image: quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py3.10 + image: quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10 volumes: - /usr/local/dcmi:/usr/local/dcmi - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi @@ -35,6 +35,13 @@ jobs: --device /dev/hisi_hdc --privileged --network "host" + --shm-size 2g + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP}} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS}} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" #this is more stable steps: - name: Check npu and CANN info run: | @@ -42,6 +49,42 @@ jobs: npu-smi info - name: Checkout volcengine/verl repo uses: actions/checkout@v4 - - name: Run test + - name: Install torch + run: | + pip install torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu + pip install torch-npu==2.5.1 + pip install /usr/local/Ascend/ascend-toolkit/latest/lib64/te-0.4.0-py3-none-any.whl + - name: Install vllm + run: | + apt-get update && apt-get install -y git + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git vllm-npu + cd vllm-npu + pip install -r requirements-build.txt + VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ + - name: Install vllm-ascend + run: | + pip list + pip show torch + git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git + cd vllm-ascend + export COMPILE_CUSTOM_KERNELS=1 + python setup.py install + - name: Install the current repository + run: | + pip3 install hf_transfer peft + pip3 install -r requirements-npu.txt + pip install -e . + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running gsm8k e2e training tests with LoRA on ASCEND NPU + run: | + ray stop --force + bash tests/e2e/sft/run_sft.sh + rm -rf $HOME/ckpts + - name: Running gsm8k e2e training tests with GRPO on ASCEND NPU run: | - lscpu + ray stop --force + bash tests/npu/run_qwen2_5_05b_grpo.sh + rm -rf $HOME/ckpts diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/main_dapo.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/main_dapo.py index 2a98dc81b1..33bf980544 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/main_dapo.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/recipe/dapo/src/main_dapo.py @@ -15,6 +15,7 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ from .dapo_ray_trainer import RayDAPOTrainer +from verl.utils.device import is_cuda_available import os import ray diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt b/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt index a01d1514a4..003243e03d 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/requirements-npu.txt @@ -4,15 +4,17 @@ codetiming datasets dill hydra-core -numpy==1.26.4 +numpy pandas peft pyarrow>=15.0.0 pybind11 pylatexenc ray -tensordict<0.6 -transformers>=4.51.0 +tensordict<=0.6.2 +transformers>=4.52.0 +wandb mathruler torchdata -wandb \ No newline at end of file +einops +qwen_vl_utils \ No newline at end of file diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py index a867d6a22d..0ec1972f84 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/__init__.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os -import math import pkg_resources from pkg_resources import DistributionNotFound from packaging.version import parse as parse_version +from .protocol import DataProto +from .utils.logging_utils import set_basic_config from .utils.device import is_npu_available + version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) with open(os.path.join(version_folder, 'version/version')) as f: @@ -56,3 +59,4 @@ if os.getenv('VERL_USE_MODELSCOPE', 'False').lower() == 'true': # Patch hub to download models from modelscope to speed up. from modelscope.utils.hf_util import patch_hub patch_hub() + diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py index 97ee7f3466..9ebb21ef0c 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/models/transformers/qwen2_vl.py @@ -24,6 +24,7 @@ from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_head try: from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) except ImportError: flash_attn_varlen_func = None diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py index d68895cdd3..b2cea32a42 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/protocol.py @@ -29,6 +29,7 @@ from tensordict import TensorDict from torch.utils.data import DataLoader, Dataset from verl.utils.py_functional import union_two_dict +from verl.utils.torch_functional import allgather_dict_tensors from verl.utils.device import get_torch_device __all__ = ['DataProto', 'union_tensor_dict'] diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py index fc42078e25..a03bfd7e1a 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/single_controller/base/worker.py @@ -19,6 +19,7 @@ import socket from dataclasses import dataclass from ...utils.device import get_device_name from .decorator import register, Dispatch, Execute +from verl.utils.device import get_torch_device @dataclass diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_generation.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_generation.py index 75d986c83c..9395ceb576 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_generation.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/trainer/main_generation.py @@ -18,7 +18,7 @@ import ray import numpy as np import hydra import os - +from verl.utils.device import is_cuda_available os.environ['NCCL_DEBUG'] = 'WARN' os.environ['TOKENIZERS_PARALLELISM'] = 'true' # os.environ['TORCH_COMPILE_DISABLE'] = '1' @@ -76,7 +76,7 @@ def main_task(config): ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout') resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu") wg.init_model() total_samples = len(dataset) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/checkpoint_manager.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/checkpoint_manager.py index af32e85d97..3d8befb4e8 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/checkpoint_manager.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/checkpoint/checkpoint_manager.py @@ -22,7 +22,7 @@ from transformers import PreTrainedTokenizer, ProcessorMixin import numpy as np import random import re - +from verl.utils.device import is_cuda_available, is_npu_available class BaseCheckpointManager: """ @@ -111,14 +111,22 @@ class BaseCheckpointManager: 'numpy': np.random.get_state(), 'random': random.getstate(), } + if is_cuda_available: + rng_state["cuda"] = torch.cuda.get_rng_state() + elif is_npu_available: + rng_state["npu"] = torch.npu.get_rng_state() + return rng_state @staticmethod def load_rng_state(rng_state): torch.set_rng_state(rng_state['cpu']) - torch.cuda.set_rng_state(rng_state['cuda']) np.random.set_state(rng_state['numpy']) random.setstate(rng_state['random']) + if is_cuda_available: + torch.cuda.set_rng_state(rng_state["cuda"]) + elif is_npu_available: + torch.npu.set_rng_state(rng_state["npu"]) def find_latest_ckpt_path(path, directory_format="global_step_{}"): diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/hf_rollout.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/hf_rollout.py index 061a76180c..c009872284 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/hf_rollout.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/workers/rollout/hf_rollout.py @@ -26,7 +26,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl import DataProto from verl.utils.torch_functional import get_response_mask from .base import BaseRollout - +from verl.utils.device import get_torch_device from transformers import GenerationConfig __all__ = ['HFRollout'] @@ -136,7 +136,7 @@ class HFRollout(BaseRollout): batch_size=batch_size) # empty cache before compute old_log_prob - torch.cuda.empty_cache() + get_torch_device().empty_cache() self.module.train() return DataProto(batch=batch) -- Gitee