From a25dc434d49267268ae5d28d2c4e0cfd9089f05a Mon Sep 17 00:00:00 2001
From: wangqihui01
Date: Mon, 9 Jun 2025 16:08:37 +0800
Subject: [PATCH] [misc] add support for qwen3 model (dense/moe)

---
 .../tests/verl/test_flops_counter.py          | 126 ++++++++++++++++++
 .../verl/utils/flops_counter.py               |  42 +++++-
 2 files changed, 166 insertions(+), 2 deletions(-)
 create mode 100644 PyTorch/built-in/rl/VeRL_for_PyTorch/tests/verl/test_flops_counter.py

diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/tests/verl/test_flops_counter.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/tests/verl/test_flops_counter.py
new file mode 100644
index 0000000000..437593e831
--- /dev/null
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/tests/verl/test_flops_counter.py
@@ -0,0 +1,126 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import json
+import pytest
+from verl.utils.flops_counter import FlopsCounter
+
+VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"}
+
+
+class Config:
+    def __init__(self, config_dict):
+        for key, value in config_dict.items():
+            setattr(self, key, value)
+
+
+CONFIG = {
+    "llama": {
+        "config": {  # llama2-7B
+            "model_type": "llama",
+            "vocab_size": 32000,
+            "hidden_size": 4096,
+            "intermediate_size": 11008,
+            "num_hidden_layers": 32,
+            "num_attention_heads": 32,
+            "num_key_value_heads": 32,
+        },
+        "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
+        "expected_flops_tuple": (153555818250240 / 1e12, 575955114393600 / 1e12),
+    },
+    "qwen2": {
+        "config": {  # Qwen/Qwen2.5-7B-Instruct
+            "model_type": "qwen2",
+            "vocab_size": 152064,
+            "hidden_size": 3584,
+            "intermediate_size": 18944,
+            "num_hidden_layers": 28,
+            "num_attention_heads": 28,
+            "num_key_value_heads": 4,
+        },
+        "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
+        "expected_flops_tuple": (170388331954176 / 1e12, 622070178250752 / 1e12),
+    },
+    "qwen3": {
+        "config": {  # Qwen/Qwen3-8B
+            "model_type": "qwen3",
+            "vocab_size": 151936,
+            "hidden_size": 4096,
+            "intermediate_size": 12288,
+            "num_hidden_layers": 36,
+            "num_attention_heads": 32,
+            "num_key_value_heads": 8,
+            "head_dim": 128,
+        },
+        "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
+        "expected_flops_tuple": (185867930959872 / 1e12, 692924253732864 / 1e12),
+    },
+    "qwen3_moe": {
+        "config": {  # Qwen/Qwen3-30B-A3B-Base
+            "model_type": "qwen3_moe",
+            "hidden_size": 2048,
+            "vocab_size": 151936,
+            "num_hidden_layers": 48,
+            "num_key_value_heads": 4,
+            "num_attention_heads": 32,
+            "head_dim": 128,
+            "moe_intermediate_size": 768,
+            "num_experts_per_tok": 8,
+            "num_experts": 128,
+        },
+        "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
+        "expected_flops_tuple": (85087060230144 / 1e12, 365944098521088 / 1e12),
+    },
+    "deepseek_v3": {
+        "config": {  # deepseek-ai/DeepSeek-Prover-V2-671B
+            "model_type": "deepseek_v3",
+            "hidden_size": 7168,
+            "vocab_size": 129280,
+            "moe_intermediate_size": 2048,
+            "num_hidden_layers": 61,
+            "first_k_dense_replace": 3,
"num_attention_heads": 128, + "n_routed_experts": 256, + "num_experts_per_tok": 8, + "n_shared_experts": 1, + "kv_lora_rank": 512, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "intermediate_size": 18432, + "qk_nope_head_dim": 128, + "q_lora_rank": 1536 + + }, + "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), + "expected_flops_tuple": (906535995703296 / 1e12, 3674028304760832 / 1e12), + }, +} + + +@pytest.mark.parametrize( + "config_type", + ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"], +) +def test_flops_counter(config_type: str): + test_config = CONFIG[config_type] + config = Config(test_config["config"]) + flops_counter = FlopsCounter(config) + for batch_seqlens, expected_flops in zip(test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"]): + # set delta time to 1 to get the flops + counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1) + print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}") + assert math.isclose( + counted_flops, expected_flops + ), f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py index 795fb18366..a1693c35b3 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/flops_counter.py @@ -16,7 +16,7 @@ import torch from transformers import PretrainedConfig from verl.utils.device import get_torch_device -VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "deepseek_v3"} +VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "deepseek_v3"} def get_device_flops(unit="T"): @@ -72,6 +72,8 @@ class FlopsCounter: 'llama': self._estimate_qwen2_flops, 'qwen2_vl': self._estimate_qwen2_flops, 'qwen2_5_vl': self._estimate_qwen2_flops, + "qwen3": self._estimate_qwen2_flops, + "qwen3_moe": self._estimate_qwen3_moe_flops, 'deepseek_v3': self._estimate_deepseek_v3_flops, } self.config = config @@ -87,7 +89,7 @@ class FlopsCounter: num_attention_heads = self.config.num_attention_heads intermediate_size = self.config.intermediate_size - head_dim = hidden_size // num_attention_heads + head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) q_size = num_attention_heads * head_dim k_size = num_key_value_heads * head_dim v_size = num_key_value_heads * head_dim @@ -162,6 +164,42 @@ class FlopsCounter: return flops_achieved + def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time): + hidden_size = self.config.hidden_size + vocab_size = self.config.vocab_size + num_hidden_layers = self.config.num_hidden_layers + num_key_value_heads = self.config.num_key_value_heads + num_attention_heads = self.config.num_attention_heads + moe_intermediate_size = self.config.moe_intermediate_size + moe_topk = self.config.num_experts_per_tok + num_experts = self.config.num_experts + + head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # non-attn per layer parm + # gate + moe export + moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + 
emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + def estimate_flops(self, batch_seqlens, delta_time): """ Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. -- Gitee
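
Reviewer note: to sanity-check the expected constants in expected_flops_tuple without running pytest, note that both estimators reduce to a closed form: 6 * params * token_count for the linear-layer (non-attention) work plus 12 * sum(seqlen^2) * head_dim * num_heads * num_layers for attention, forward and backward combined. The following is a minimal standalone sketch of that formula; it is not part of the patch, and the names dense_flops/moe_flops are illustrative only.

# Standalone sanity check (illustrative, not part of the patch).
# Re-derives two expected_flops_tuple entries from the formula the estimators
# implement: 6 * params * tokens (GEMMs, fwd+bwd)
#            + 12 * sum(seqlen^2) * head_dim * heads * layers (attention, fwd+bwd).

def dense_flops(vocab, hidden, inter, layers, heads, kv_heads, batch_seqlens, head_dim=None):
    head_dim = head_dim or hidden // heads
    attn_linear = hidden * (2 * heads * head_dim + 2 * kv_heads * head_dim)  # Wq, Wo + Wk, Wv
    mlp = hidden * inter * 3                                                 # gate, up, down
    params = (mlp + attn_linear) * layers + vocab * hidden * 2               # + embedding & lm_head
    attn = 12 * sum(s * s for s in batch_seqlens) * head_dim * heads * layers
    return (6 * params * sum(batch_seqlens) + attn) / 1e12

def moe_flops(vocab, hidden, moe_inter, layers, heads, kv_heads, head_dim, topk, n_experts, batch_seqlens):
    moe_mlp = hidden * topk * moe_inter * 3 + hidden * n_experts             # activated experts + router gate
    attn_linear = hidden * (2 * heads * head_dim + 2 * kv_heads * head_dim)
    params = (moe_mlp + attn_linear) * layers + vocab * hidden * 2
    attn = 12 * sum(s * s for s in batch_seqlens) * head_dim * heads * layers
    return (6 * params * sum(batch_seqlens) + attn) / 1e12

# llama2-7B, batch_seqlens=[512, 1024, 2048] -> 153.55581825024 (== 153555818250240 / 1e12)
print(dense_flops(32000, 4096, 11008, 32, 32, 32, [512, 1024, 2048]))
# Qwen3-30B-A3B, batch_seqlens=[512, 1024, 2048] -> 85.087060230144 (== 85087060230144 / 1e12)
print(moe_flops(151936, 2048, 768, 48, 32, 4, 128, 8, 128, [512, 1024, 2048]))

Any config change feeds straight through this closed form, so it is also a quick way to regenerate the expected constants if a test config is ever updated.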