diff --git a/PyTorch/built-in/foundation/GPT-NeoX/README.md b/PyTorch/built-in/foundation/GPT-NeoX/README.md index a073aaee61b372659cb1699631f6db7812e49bbd..608b1e91605802d3267109ed5b862e1e44137c15 100644 --- a/PyTorch/built-in/foundation/GPT-NeoX/README.md +++ b/PyTorch/built-in/foundation/GPT-NeoX/README.md @@ -1,405 +1,254 @@ -[![GitHub issues](https://img.shields.io/github/issues/EleutherAI/gpt-neox)](https://github.com/EleutherAI/gpt-neox/issues) -[Weights & Biases monitoring](https://wandb.ai/eleutherai/neox) - # GPT-NeoX -This repository records [EleutherAI](https://www.eleuther.ai)'s library for training large-scale language models on GPUs. Our current framework is based on NVIDIA's [Megatron Language Model](https://github.com/NVIDIA/Megatron-LM) and has been augmented with techniques from [DeepSpeed](https://www.deepspeed.ai) as well as some novel optimizations. We aim to make this repo a centralized and accessible place to gather techniques for training large-scale autoregressive language models, and accelerate research into large-scale training. - -For those looking for a TPU-centric codebase, we recommend [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax). - -**If you are not looking to train models with billions of parameters from scratch, this is likely the wrong library to use. For generic inference needs, we recommend you use the Hugging Face `transformers` library instead which supports GPT-NeoX models.** - -# Contents - -* [Quick Start](#quick-start) - * [Environment and Dependencies](#environment-and-dependencies) - * [Usage](#usage) -* [Configuration](#configuration) -* [Datasets](#datasets) - * [Preconfigured Datasets](#preconfigured-datasets) - * [Using Custom Data](#using-custom-data) -* [Training and Finetuning](#training-and-finetuning) - * [Select Pretrained Models](#pretrained-models) - * [GPT-NeoX-20B](#gpt-neox-20b) - * [Pythia](#pythia) - * [Polyglot](#polyglot) - * [Fill-in-the-Middle](#fill-in-the-middle) -* [Inference](#inference) -* [Evaluation](#evaluation) -* [Exporting to Hugging Face](#exporting-to-hugging-face) -* [Monitoring](#monitoring) - * [Weights & Biases](#wandb) - * [TensorBoard](#tensorboard) -* [Administrative Notes](#administrative-notes) - * [Citing GPT-NeoX](#citing-gpt-neox) - * [Licensing](#licensing) - * [Publications](#publications) - * [Acknowledgements](#acknowledgements) - -# Quick Start - -## Environment and Dependencies - -### Host Setup - -First make sure you are in an environment with Python 3.8 with an appropriate version of PyTorch 1.8 or later installed. **Note:** Some of the libraries that GPT-NeoX depends on have not been updated to be compatible with Python 3.10+. Python 3.9 appears to work, but this codebase has been developed and tested for Python 3.8. - -To install the remaining basic dependencies, run: - -```bash -pip install -r requirements/requirements.txt -python ./megatron/fused_kernels/setup.py install # optional if not using fused kernels -``` - -from the repository root. - - - -### TensorBoard -======= -### Flash Attention - -To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details. 
- - -### Containerized Setup - -We also provide a Dockerfile if you prefer to run NeoX in a container. To use this option, first build an image named `gpt-neox` from the repository root directory with `docker build -t gpt-neox -f Dockerfile .`. We also host pre-built images on [Docker Hub at `leogao2/gpt-neox`](https://hub.docker.com/r/leogao2/gpt-neox/tags). - -You can then run a container based on this image. For instance, the below snippet mounts the cloned repository (`gpt-neox`) directory to `/gpt-neox` in the container and uses [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to make four GPUs (numbers 0-3) accessible to the container. [As noted by the NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#sharing-data), both `--shm-size=1g` and `--ulimit memlock=-1` are important to prevent Docker from allocating too little shared memory. -``` -nvidia-docker run --rm -it -e NVIDIA_VISIBLE_DEVICES=0,1,2,3 --shm-size=1g --ulimit memlock=-1 --mount type=bind,src=$PWD,dst=/gpt-neox gpt-neox -``` - -## Usage - -All functionality (inference included), should be launched using `deepy.py`, a wrapper around the `deepspeed` launcher. - -We currently offer three main functions: -1. `train.py` is used for training and finetuning models. -2. `evaluate.py` is used to evaluate a trained model using the [language model evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness). -3. `generate.py` is used to sample text from a trained model. - -which can be launched with: - -```bash -./deepy.py [script.py] [./path/to/config_1.yml] [./path/to/config_2.yml] ... [./path/to/config_n.yml] -``` - -E.G To generate text unconditionally with the GPT-NeoX-20B model, you can use the following: -```bash -./deepy.py generate.py ./configs/20B.yml -``` - -Or optionally pass in a text file (e.g `prompt.txt`) to use as the prompt, which should be a plain `.txt` file with each prompt separated by newline characters, also passing in the path to an output file. - -```bash -./deepy.py generate.py ./configs/20B.yml -i prompt.txt -o sample_outputs.txt -``` - -To reproduce our evaluation numbers on, for example, TriviaQA and PIQA use: - -```bash -./deepy.py evaluate.py ./configs/20B.yml --eval_tasks triviaqa piqa -``` - -You can add an arbitrary list of evaluation tasks here, for details of all tasks available, see [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). - -For more details on each entry point, see the [Training and Finetuning](#training-and-finetuning), [Inference](#inference) and [Evaluation](#evaluation) -# Configuration - -GPT-NeoX parameters are defined in a YAML configuration file which is passed to the deepy.py launcher. We have provided some example .yaml files in [configs](./configs/), including one for GPT-NeoX-20B, and example configuration files for other model sizes. - -These files are generally complete, but non-optimal. For example, depending on your specific GPU configuration, you may need to change some settings such as `pipe-parallel-size`, `model-parallel-size` to increase or decrease the degree of parallelisation, `train_micro_batch_size_per_gpu` or `gradient-accumulation-steps` to modify batch size related settings, or the `zero_optimization` dict to modify how optimizer states are parallelised across workers. 
- -For a more detailed guide to all the features available and how to configure them, see [the configuration README](configs/README.md), and for documentation of every possible argument, see [configs/neox_arguments.md](configs/neox_arguments.md). - -# Datasets - -## Preconfigured Datasets - -Several preconfigured datasets are available, including most components from [the Pile](https://arxiv.org/abs/2101.00027), as well as the Pile train set itself, for straightforward tokenization using the `prepare_data.py` entry point. - -E.G, to download and tokenize the Enron emails corpus with the GPT2 Tokenizer, saving them to `./data` you can run: - -``` -python prepare_data.py -d ./data -``` - -or with the GPT-NeoX-20B tokenizer (assuming you have it saved at `./20B_checkpoints/20B_tokenizer.json`): - -``` -python prepare_data.py -d ./data -t HFTokenizer --vocab-file ./20B_checkpoints/20B_tokenizer.json -``` - -The tokenized data will be saved out to two files: `[data-dir]/[dataset-name]/[dataset-name]_text_document.bin`and `[data-dir]/[dataset-name]/[dataset-name]_text_document.idx`. You will need to add the prefix that both these files share to your training configuration file under the `data-path` field. E.G: - -```yaml - "data-path": "./data/enron/enron_text_document", -``` - -## Using Custom Data - -To prepare your own dataset for training with custom data, format it as one large [jsonl](https://jsonlines.org/)-formatted file with each item in the list of dictionaries being a separate document. The document text should be grouped under one JSON key, i.e `"text"`. Any auxiliary data stored in other fields will not be used. - -Next make sure to download the GPT2 tokenizer vocab, and merge files from the following links: - -- Vocab: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -- Merge: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt - -Or use the 20B tokenizer (for which only a single Vocab file is needed): - -- Vocab: https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json - -(alternatively, you can provide any tokenizer file that can be loaded by Hugging Face's tokenizers library with the `Tokenizer.from_pretrained()` command) - -You can now pretokenize your data using `tools/preprocess_data.py`, the arguments for which are detailed below: - -``` -usage: preprocess_data.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] [--num-docs NUM_DOCS] --tokenizer-type {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer} [--vocab-file VOCAB_FILE] [--merge-file MERGE_FILE] [--append-eod] [--ftfy] --output-prefix OUTPUT_PREFIX - [--dataset-impl {lazy,cached,mmap}] [--workers WORKERS] [--log-interval LOG_INTERVAL] - -optional arguments: - -h, --help show this help message and exit - -input data: - --input INPUT Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated list - --jsonl-keys JSONL_KEYS [JSONL_KEYS ...] - space separate listed of keys to extract from jsonl. Defa - --num-docs NUM_DOCS Optional: Number of documents in the input data (if known) for an accurate progress bar. - -tokenizer: - --tokenizer-type {HFGPT2Tokenizer,HFTokenizer,GPT2BPETokenizer,CharLevelTokenizer} - What type of tokenizer to use. - --vocab-file VOCAB_FILE - Path to the vocab file - --merge-file MERGE_FILE - Path to the BPE merge file (if necessary). - --append-eod Append an token to the end of a document. 
- --ftfy Use ftfy to clean text - -output data: - --output-prefix OUTPUT_PREFIX - Path to binary output file without suffix - --dataset-impl {lazy,cached,mmap} - Dataset implementation to use. Default: mmap - -runtime: - --workers WORKERS Number of worker processes to launch - --log-interval LOG_INTERVAL - Interval between progress updates - -``` - -For example: - -```bash -python tools/preprocess_data.py \ - --input ./data/mydataset.jsonl.zst \ - --output-prefix ./data/mydataset \ - --vocab ./data/gpt2-vocab.json \ - --merge-file gpt2-merges.txt \ - --dataset-impl mmap \ - --tokenizer-type GPT2BPETokenizer \ - --append-eod -``` - -You would then run training with the following settings added to your configuration file: - -```yaml - "data-path": "data/mydataset/mydataset", -``` - -# Training and Finetuning - -Training is launched using `deepy.py`, a wrapper around DeepSpeed's launcher, which launches the same script in parallel across many GPUs / nodes. - -The general usage pattern is: - -```bash -python ./deepy.py train.py [path/to/config1.yml] [path/to/config2.yml] ... -``` - -You can pass in an arbitrary number of configs which will all be merged at runtime. - -You can also optionally pass in a config prefix, which will assume all your configs are in the same folder and append that prefix to their path. - -E.G: - -```bash -python ./deepy.py train.py -d configs small.yml local_setup.yml -``` - -This will deploy the `train.py` script on all nodes with one process per GPU. The worker nodes and number of GPUs are specified in the `/job/hostfile` file (see [parameter documentation](configs/README.md)), or can simply be passed in as the `num_gpus` arg if running on a single node setup. - -Although this is not strictly necessary, we find it useful to define the model parameters in one config file (e.g `configs/small.yml`) and the data path parameters in another (e.g `configs/local_setup.yml`). - - -## Pretrained Models - -### GPT-NeoX-20B - -GPT-NeoX-20B is a 20 billion parameter autoregressive language model trained on [the Pile](https://arxiv.org/abs/2101.00027). Technical details about GPT-NeoX-20B can be found in [the associated paper](https://arxiv.org/abs/2204.06745). The configuration file for this model is both available at [`./configs/20B.yml`](./configs/20B.yml) and included in the download links below. - -[Slim weights](https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/) - (No optimizer states, for inference or finetuning, 39GB) - -To download from the command line to a folder named `20B_checkpoints`, use the following command: - -```bash -wget --cut-dirs=5 -nH -r --no-parent --reject "index.html*" https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/ -P 20B_checkpoints -``` - -[Full weights](https://the-eye.eu/public/AI/models/GPT-NeoX-20B/full_weights/) - (Including optimizer states, 268GB) - -To download from the command line to a folder named `20B_checkpoints`, use the following command: - -```bash -wget --cut-dirs=5 -nH -r --no-parent --reject "index.html*" https://the-eye.eu/public/AI/models/GPT-NeoX-20B/full_weights/ -P 20B_checkpoints -``` - -Weights can be alternatively be downloaded using a BitTorrent client. Torrent files can be downloaded here: [slim weights](https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights.torrent), [full weights](https://the-eye.eu/public/AI/models/GPT-NeoX-20B/full_weights.torrent). - -We additionally have 150 checkpoints saved throughout training, one every 1,000 steps. 
We are working on figuring out how to best serve these at scale, but in the meanwhile people interested in working with the partially trained checkpoints can email us at contact@eleuther.ai to arrange access. - -### Pythia - -The Pythia Scaling Suite is a suite of models ranging from 19M parameters to 13B parameters trained on [the Pile](pile.eleuther.ai) intended to promote research on interpretability and training dynamics of large language models. Further details about the project and links to the models can be found [here](https://github.com/EleutherAI/pythia). - -### Polyglot - -The Polyglot Project is an effort to train powerful non-English pretrained language models to promote the accessibility of this technology to researchers outside the dominant powerhouses of machine learning. EleutherAI has trained and released 1.3B, 3.8B, and 5.8B parameter Korean language models, the largest of which outpreforms all other publicly available language models on Korean language tasks. Further details about the project and links to the models can be found [here](https://github.com/EleutherAI/polyglot). - -### Fill-in-the-Middle - -EleutherAI's [Carper lab](https://www.carper.ai) has also used this codebase to train models using FIM (fill-in-the-middle), a data transformation proposed in [Bavarian et al. 2022](https://arxiv.org/abs/2207.14255) with a similar technique also used by [Fried et al.](https://arxiv.org/abs/2204.05999) and [Aghajanyan et al. 2022](https://arxiv.org/abs/2201.07520), to enable typically autoregressive left-to-right language models to perform text infilling conditioned on both "left" and "right" context. A 1.3B parameter model trained on [the Pile](pile.eleuther.ai) is available [here](https://huggingface.co/CarperAI/FIM-NeoX-1.3B), with further experiments and and models forthcoming. - -# Inference - -**For most uses we recommend deploying models trained using the GPT-NeoX library via the Hugging Face Transformers library which is better optimized for inference.** - -We support three types of generation from a pretrained model: -1. Unconditional generation -2. Conditional generation based on an input read from a file -3. Interactive generation, which allows for multiple rounds of back-and-forth between a user and the language model via a command line interface - -All three types of text generation can be launched via `python ./deepy.py generate.py -d configs small.yml local_setup.yml text_generation.yml` with the appropriate values set in `configs/text_generation.yml`. - -# Evaluation - -GPT-NeoX supports evaluation on downstream tasks through the [language model evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness). - -To evaluate a trained model on the evaluation harness, simply run: - -```bash -python ./deepy.py evaluate.py -d configs your_configs.yml --eval_tasks task1 task2 ... taskn -``` - -where `--eval_tasks` is a list of evaluation tasks followed by spaces, e.g `--eval_tasks lambada hellaswag piqa sciq`. For details of all tasks available, refer to the [lm-evaluation-harness repo](https://github.com/EleutherAI/lm-evaluation-harness). - -# Exporting to Hugging Face - -GPT-NeoX is optimized heavily for training only, and GPT-NeoX model checkpoints are not compatible out of the box with other deep learning libraries. 
To make models easily loadable and shareable with end users, and for further exporting to various other frameworks, GPT-NeoX supports checkpoint conversion to the [Hugging Face Transformers](https://arxiv.org/abs/1910.03771) GPTNeoXModel format.
-
-To convert a NeoX checkpoint to Hugging Face-loadable format, run:
-```bash
-python ./tools/convert_to_hf.py --input_dir /path/to/model/global_stepXXX --config_file your_config.yml --output_dir hf_model/save/location
-```
-Then to upload a model to [the Hugging Face Hub](https://huggingface.co/), run:
-```
-huggingface-cli login
-python ./tools/upload.py
-```
-and input the requested information, including HF hub user token.
-
-Note, however, that this compatibility is not one-to-one, and only certain configurations from GPT-NeoX are supported in the Hugging Face GPTNeoXModel class. Advanced features such as alternative positional embeddings may require new Transformers modeling code and new conversion script tweaks.
-
-# Monitoring
+- [Overview](概述.md)
+- [Preparing the Training Environment](准备训练环境.md)
+- [Starting Training](开始训练.md)
+- [Training Results](训练结果展示.md)
+- [Release Notes](版本说明.md)
-In addition to storing logs locally, we provide built-in support for two popular experiment monitoring frameworks: [Weights & Biases](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard/)
-

-<h2 id="wandb">Weights & Biases</h2>

-EleutherAI is currently using [Weights & Biases to record our experiments](https://wandb.ai/eleutherai/neox). If you are logged into Weights & Biases on your machine—you can do this by executing `wandb login`—your runs will automatically be recorded. There are two optional fields associated with Weights & Biases: wandb_group allows you to name the run group and wandb_team allows you to assign your runs to an organization or team account.
+# Overview
-## TensorBoard
+## Introduction
-We also support using TensorBoard via the tensorboard-dir field. Dependencies required for TensorBoard monitoring can be found in and installed from `./requirements/requirements-tensorboard.txt`.
+GPT-NeoX-20B is a very large autoregressive language model developed by EleutherAI. It is trained with distributed-parallel training techniques and delivers strong accuracy and efficiency.
-# Running on multi-node
+- Reference implementation:
-If you need to supply a hostfile for use with the MPI-based DeepSpeed launcher, you can set the environment variable `DLTS_HOSTFILE` to point to the hostfile.
+
+  ```
+  url=https://github.com/EleutherAI/gpt-neox/tree/v2.0
+  commit_id=9610391ab319403cef079b438edd016a2443af54
+  ```
-# Administrative Notes
+- Implementation adapted for Ascend AI Processors:
+
+  ```
+  url=https://gitee.com/ascend/ModelZoo-PyTorch/tree/master/PyTorch/built-in/foundation/GPT-NeoX
+  code_path=PyTorch/built-in/foundation/GPT-NeoX
+  ```
+
+
+# Preparing the Training Environment
+
+## Setting Up the Environment
+
+- The firmware and driver, CANN, and PyTorch versions supported by this model are listed in the table below.
+
+  **Table 1** Version compatibility
+
+  | Component         | Version      |
+  | ----------------- | ------------ |
+  | Firmware & driver | 23.0.T13     |
+  | CANN              | 6.1.RC2      |
+  | PyTorch           | PyTorch 1.11 |
+  | Python            | Python 3.7.5 |
+
+- The PyTorch version and known third-party dependencies supported by this model are listed in the table below.
+
+  **Table 2** Supported versions
+
+  | Torch_Version | Third-party dependency |
+  |:-------------:|:----------------------:|
+  | PyTorch 1.11  | deepspeed 0.9.2        |
+
+- Environment setup guide.
+
+  See [Preparing the PyTorch Framework Training Environment](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes).
+
+- Install dependencies.
+
+1. Install the base dependencies.
+
+   Run the following command from the root directory of the model source package to install the dependencies required by the matching PyTorch version.
+   ```
+   pip install -r requirements.txt  # for PyTorch 1.11
+   ```
+2. Install the deepspeed_npu plugin.
+
+   ```
+   # adaptor branch
+   git clone https://gitee.com/ascend/DeepSpeed.git
+   cd DeepSpeed
+   pip3 install ./
+   ```
+3. Install pdsh.
-## Citing GPT-NeoX
+
+   ```
+   Download the pdsh-2.34 source package and extract it
+   cd pdsh-2.34
+   ./configure --with-ssh
+   make && make install
+   ```
+
-If you have found the GPT-NeoX library helpful in your work, you can cite this repository as
-```bibtex
-@software{gpt-neox-library,
-  title = {{GPT-NeoX: Large Scale Autoregressive Language Modeling in PyTorch}},
-  author = {Andonian, Alex and Anthony, Quentin and Biderman, Stella and Black, Sid and Gali, Preetham and Gao, Leo and Hallahan, Eric and Levy-Kramer, Josh and Leahy, Connor and Nestler, Lucas and Parker, Kip and Pieler, Michael and Purohit, Shivanshu and Songz, Tri and Phil, Wang and Weinbach, Samuel},
-  url = {https://www.github.com/eleutherai/gpt-neox},
-  doi = {10.5281/zenodo.5879544},
-  month = {8},
-  year = {2021},
-  version = {0.0.1},
-}
-```
-To cite our 20 billion parameter model, please use
+## Preparing the Dataset
-```bibtex
-@inproceedings{gpt-neox-20b,
-  title={{GPT-NeoX-20B}: An Open-Source Autoregressive Language Model},
-  author={Black, Sid and Biderman, Stella and Hallahan, Eric and Anthony, Quentin and Gao, Leo and Golding, Laurence and He, Horace and Leahy, Connor and McDonell, Kyle and Phang, Jason and Pieler, Michael and Prashanth, USVSN Sai and Purohit, Shivanshu and Reynolds, Laria and Tow, Jonathan and Wang, Ben and Weinbach, Samuel},
-  booktitle={Proceedings of the ACL Workshop on Challenges \& Perspectives in Creating Large Language Models},
-  url={https://arxiv.org/abs/2204.06745},
-  year={2022}
-}
-```
+1. Obtain the dataset.
+   ```
+   Option 1:
+   Raw data (480 GB): https://github.com/EleutherAI/the-pile
+   Download the source code: git clone https://github.com/EleutherAI/the-pile.git
+   Download the dataset: 1. enter the the-pile directory; 2. pip install -e .; 3. python the_pile/pile.py --interleave_output 30 --using pile_reprod
+   The download is 30 files totaling about 480 GB; each file is about 15 GB compressed (about 43 GB after decompression), named 00.jsonl.zst through 29.jsonl.zst
+   Option 2:
+   Dataset: https://opendatalab.com/ - the compressed dataset can be downloaded directly from this site; download the same raw data described in Option 1
+   Decompression tool: https://github.com/facebook/zstd
+   Decompression steps: 1. tar -zxvf zstd-1.5.5.tar.gz; 2. enter the zstd-1.5.5 directory and run make && make install; 3. decompress the .zst files with zstd -d *.zst
+   ```
+
+2. Vocabulary files.
+   ```
+   Vocabulary files:
+   GPT2BPETokenizer: GPT2 tokenizer vocab, and merge files from the following links:
+   Vocab: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
+   Merge: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
+
+   HFTokenizer: use the 20B tokenizer (for which only a single Vocab file is needed):
+   Vocab: https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json
+
+   ```
+3. Preprocess the data (process only the datasets you need).
+   ```
+   Required packages: ujson, lm-dataformat, ftfy
+   source_path="path to the raw data"
+   out_path="output path"
+   word_path="path to the vocabulary files"
+   source_data=" $source_path/00.jsonl,$source_path/01.jsonl,$source_path/02.jsonl,$source_path/03.jsonl,$source_path/04.jsonl,$source_path/05.jsonl,$source_path/06.jsonl,$source_path/07.jsonl,$source_path/08.jsonl, \
+   $source_path/09.jsonl,$source_path/10.jsonl,$source_path/11.jsonl,$source_path/12.jsonl,$source_path/13.jsonl,$source_path/14.jsonl,$source_path/15.jsonl,$source_path/16.jsonl,$source_path/17.jsonl, \
+   $source_path/18.jsonl,$source_path/19.jsonl,$source_path/20.jsonl,$source_path/21.jsonl,$source_path/22.jsonl,$source_path/23.jsonl,$source_path/24.jsonl,$source_path/25.jsonl,$source_path/26.jsonl, \
+   $source_path/27.jsonl,$source_path/28.jsonl,$source_path/29.jsonl"
+   Preprocessing script:
+   python ./tools/preprocess_data.py \
+          --input ${source_data} \
+          --output-prefix ${out_path} \
+          --vocab ${word_path}/gpt2-vocab.json \
+          --merge-file ${word_path}/gpt2-merges.txt \
+          --dataset-impl mmap \
+          --tokenizer-type GPT2BPETokenizer \
+          --append-eod \
+          --workers 150 \
+          > ./../logs/preprocess_pile_data.log 2>&1 &
+   OR
+   python ./tools/preprocess_data.py \
+          --input ${source_data} \
+          --output-prefix ${out_path} \
+          --vocab ${word_path}/20B_tokenizer.json \
+          --dataset-impl mmap \
+          --tokenizer-type HFTokenizer \
+          --append-eod \
+          --workers 150 \
+          > ./../logs/preprocess_pile_data.log 2>&1 &
+   Notes:
+   1. Estimated processing time: about 36 hours.
+   2. The official training run uses HFTokenizer.
+   ```
+
+
+# Starting Training
+
+## Training the Model
+
+1. Go to the root directory of the extracted source package.
+
+   ```
+   cd /${model_folder_name}
+   ```
+
+2. Run the training script.
+
+   The model supports single-node single-card training and single-node 8-card training.
+
+   - Single-node, single-card training
+
+     Start training on one card:
+
+     ```
+     python ./deepy.py train.py -d configs 20B.yml  # edit 20B.yml; card 0 is used by default: "num_gpus": 1, "global_num_gpus": 1,
+     ```
+
+   - Single-node, 8-card training
+
+     Start training on 8 cards:
+
+     ```
+     python ./deepy.py train.py -d configs 20B.yml  # edit 20B.yml accordingly
+     ```
+
+3. The main training-script parameters are described below.
+
+   ```
+   # Tokenizer / checkpoint settings - you will need to change these to the location you have them saved in
+   "vocab-file": "./20B_checkpoints/20B_tokenizer.json",  # vocabulary file matching the configured tokenizer type
+   "save": "./20B_checkpoints",  # checkpoint save path
+   "load": "./20B_checkpoints",  # checkpoint load path
+   "data-path": "./data/pile_20B_tokenizer/pile_20B_tokenizer_text_document",  # dataset path
+   "pipe-parallel-size": 4,  # pipeline parallelism degree
+   "model-parallel-size": 2,  # tensor (model) parallelism degree; data parallelism is derived automatically
+   # model settings
+   "num-layers": 44,
+   "hidden-size": 6144,
+   "num-attention-heads": 64,
+   "seq-length": 2048,
+   "max-position-embeddings": 2048,
+   "norm": "layernorm",
+   "pos-emb": "rotary",
+   "rotary_pct": 0.25,
+   # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
+   "zero_optimization": {
+     "stage": 1,
+     "allgather_partitions": True,
+     "allgather_bucket_size": 1260000000,
+     "overlap_comm": True,
+     "reduce_scatter": True,
+     "reduce_bucket_size": 1260000000,
+     "contiguous_gradients": True,
+   },
+   # batch / data settings (assuming 96 GPUs)
+   "train_micro_batch_size_per_gpu": 4,  # micro batch size per device
+   "gradient_accumulation_steps": 32,  # gradient accumulation steps
+   # activation checkpointing
+   "checkpoint-activations": true,  # activation recomputation switch
+   # misc. training settings
+   "train-iters": 150000,  # total number of training steps
+   "checkpoint-factor": 500,  # this variable previously called `save-interval`; checkpoint save interval
+   "eval-interval": 1000,  # run evaluation every 1000 steps
+   "eval-iters": 10,  # number of evaluation iterations
+   ```
+
+   After training completes, the checkpoint files are saved under the current path, and the training accuracy and performance information is printed.
-Citation instructions for other pretrained models can be found [in the appropriate repository](#pretrained-models).
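+
+The parallel-layout settings above determine how the devices are split and what the effective global batch size is. The following is a minimal illustrative sketch (not part of the training code; it simply assumes the single-node 8-card example and the batch settings from 20B.yml shown above) of how DeepSpeed/Megatron-style bookkeeping relates these values:
+
+```python
+# Illustration only: relationship between the parallel sizes and the global batch
+# size, assuming the 8-card configuration shown above.
+num_devices = 8                 # "num_gpus" for one 8-card node (assumption)
+pipe_parallel_size = 4          # "pipe-parallel-size"
+model_parallel_size = 2         # "model-parallel-size"
+
+# Data parallelism is whatever remains after pipeline and tensor parallelism.
+data_parallel_size = num_devices // (pipe_parallel_size * model_parallel_size)   # -> 1
+
+micro_batch_per_device = 4      # "train_micro_batch_size_per_gpu"
+grad_accum_steps = 32           # "gradient_accumulation_steps"
+
+# Samples consumed per optimizer step across the whole job.
+global_batch_size = micro_batch_per_device * grad_accum_steps * data_parallel_size  # -> 128
+print(data_parallel_size, global_batch_size)
+```
+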
+# Training Results
-## Licensing
+
+**Table 3** Training results
-
-This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright (c) 2021, EleutherAI. Licensed under the Apache License:
-
-    Licensed under the Apache License, Version 2.0 (the "License");
-    you may not use this file except in compliance with the License.
-    You may obtain a copy of the License at
-
-        http://www.apache.org/licenses/LICENSE-2.0
+
+| NAME          | TFLOPS | Iterations | DataType | Torch_Version | Card |
+|:-------------:|:------:|:----------:|:--------:|:-------------:|:----:|
+| GPU-2pp4mp2dp | 100    | 5000       | fp16     | 1.5           | A100 |
+| NPU-2pp4mp2dp | 150    | 5000       | fp16     | 1.5           | 910B |
+
+Note: the table must report both the reference GPU result and the NPU result.
-
-    Unless required by applicable law or agreed to in writing, software
-    distributed under the License is distributed on an "AS IS" BASIS,
-    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-    See the License for the specific language governing permissions and
-    limitations under the License.
-
+# Release Notes
-This repository is based off code written by NVIDIA that is licensed under the Apache License, Version 2.0. In accordance with the Apache License, all files that are modifications of code originally written by NVIDIA maintain a NVIDIA copyright header. All files that do not contain such a header are the exclusive copyright of EleutherAI. When the NVIDIA code has been modified from its original version, that fact is noted in the copyright header. All derivative works of this repository must preserve these headers under the terms of the Apache License.
+
+## Changelog
-
-This repository also contains code written by a number of other authors. Such contributions are marked and the relevant licensing is included where appropriate.
+
+2023.07.07: Initial release.
-
-For full terms, see the `LICENSE` file. If you have any questions, comments, or concerns about licensing please email us at contact@eleuther.ai.
+
+## Known Issues
-
-## Publications
+
+**_Description of issues in the current release._**
-
-The following publications have come out of this project:
+
+None.
-
- - Black, Biderman, Hallahan, Anthony, Gao, Golding, He, Leahy, McDonell, Phang, Pieler, Prashanth, Purohit, Reynolds, Tow, Wang, and Weinbach. "[GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745)." In *Proceedings of the ACL Workshop on Challenges \& Perspectives in Creating Large Language Models*. 2022.
-
-The following publications by other research groups use this library:
-
-- Chi, Fan, Ramadge, and Rudnicky. "[KERPLE: Kernelized Relative Positional Embedding for Length Extrapolation](https://arxiv.org/abs/2205.09921)". _arXiv preprint arXiv:2205.09921_. 2022.
-- Horawalavithana, Ayton, Sharma, Howland, Subramanian, Vasquez, Cosbey, Glenski, and Volkova. "[Foundation Models of Scientific Knowledge for Chemistry: Opportunities, Challenges and Lessons Learned](https://openreview.net/pdf?id=SLX-I2MHUZ9)." In *Proceedings of the ACL Workshop on Challenges \& Perspectives in Creating Large Language Models*. 2022.
-- Kolak, Martins, Le Goues, and Hellendoorn. "[Patch Generation with Language Models: Feasibility and Scaling Behavior](https://openreview.net/forum?id=rHlzJh_b1-5)"." In *Proceedings of the Deep Learning for Code Workshop at ICLR*. 2022.
-- Muennighoff, Niklas. "[SGPT: GPT Sentence Embeddings for Semantic Search](https://arxiv.org/abs/2202.08904)." *arXiv preprint arXiv:2202.08904*. 2022.
-- Xu, Alon, Neubig, and Hellendoorn. "[A Systematic Evaluation of Large Language Models of Code](https://arxiv.org/abs/2202.13169)." In *Proceedings of the ICLR Workshop on Deep Learning For Code*. 2022.
-## Acknowledgements
-
-We run our experiments on a Kubernetes cluster generously provided by [CoreWeave](https://coreweave.com/) and a SLURM cluster provided by [Stability AI](https://stability.ai).
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml b/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml
index 7fcbdab89a64abed1b3aa0b80a36360164460eac..af16fbcdfb211787b010bc9e6f43bdcc430162ff 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml
+++ b/PyTorch/built-in/foundation/GPT-NeoX/configs/20B.yml
@@ -36,6 +36,8 @@
    "output_layer_init_method": "wang_init",
 
    "scaled_masked_softmax_fusion":true,
+   "async_tensor_model_parallel_allreduce":true,
+   "use_triangle_attn":true,
 
    # optimizer settings
    "optimizer": {
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
index d5d428faa1459378878f7fe08629f815f4b66594..78a2e40c36cfdd17245905e4775e85f793a102d7 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
@@ -198,6 +198,12 @@ class ParallelSelfAttention(nn.Module):
         self.fp16 = neox_args.precision == "fp16"
         self.bf16 = neox_args.precision == "bfloat16"
+
+        # Triangle (block-wise causal) attention settings.
+        self.use_triangle_attn = neox_args.use_triangle_attn
+        self.block_size = 512
+        self.mask_tmp_initialed = False
+        self.mask_tmp_groups = []
+
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = neox_args.apply_query_key_layer_scaling
         self.use_cache = use_cache
@@ -313,6 +319,62 @@ class ParallelSelfAttention(nn.Module):
     def attention(
         self, query_layer, key_layer, value_layer, layer_past, attention_mask
     ):
+        # Triangle attention: split the query into fixed-size blocks so that block i
+        # only attends to key/value blocks 0..i. The full [sq, sk] score matrix is
+        # never materialized; this path is used only during training (layer_past is
+        # None) for sequences that are an exact multiple of block_size.
+        if self.use_triangle_attn and layer_past is None and query_layer.size(
+                0) >= self.block_size * 2 and query_layer.size(0) % self.block_size == 0:
+            # context layer shape: [b, np, sq, hn]
+            output_size = (value_layer.size(1),
+                           value_layer.size(2),
+                           query_layer.size(0),
+                           value_layer.size(3))
+            np = value_layer.size(2)
+            sequence_len, bsz, head_num, head_dim = query_layer.shape
+
+            # TODO: handle sequence lengths that are not an exact multiple of block_size
+            sparse_groups = sequence_len // self.block_size
+            context_layer = None
+            alpha = 1.0 / self.norm_factor
+            query_layer *= alpha
+
+            q_tmp_layers_tuple = torch.chunk(query_layer, sparse_groups, 0)
+            k_tmp_layers_tuple = torch.chunk(key_layer, sparse_groups, 0)
+            v_tmp_layers_tuple = torch.chunk(value_layer, sparse_groups, 0)
+
+            context_list_tmp, k_tmp, v_tmp = [], (), ()
+            for i in range(sparse_groups):
+                # compute slice shape of q k v for each loop
+                q_begin = i * self.block_size
+                q_end = (i + 1) * self.block_size
+                kv_begin = 0
+                kv_end = (i + 1) * self.block_size
+
+                # q_tmp: [b, np, q_size, hn]
+                q_tmp = q_tmp_layers_tuple[i].permute(1, 2, 0, 3).contiguous()
+                # accumulate the k and v blocks seen so far
+                if i == 0:
+                    k_tmp = k_tmp_layers_tuple[i].permute(1, 2, 3, 0).contiguous()
+                    v_tmp = v_tmp_layers_tuple[i].permute(1, 2, 0, 3).contiguous()
+                else:
+                    k_tmp = torch.cat((k_tmp, k_tmp_layers_tuple[i].permute(1, 2, 3, 0).contiguous()), -1).contiguous()
+                    v_tmp = torch.cat((v_tmp, v_tmp_layers_tuple[i].permute(1, 2, 0, 3).contiguous()), -2).contiguous()
+                cur_sim = torch.matmul(q_tmp, k_tmp)
+                # [b, np, sq, sk] -> [b, np, q_size, kv_size]
+                if not self.mask_tmp_initialed:
+                    mask_tmp = attention_mask[:, :, q_begin:q_end, kv_begin:kv_end]
+                    self.mask_tmp_groups.append(mask_tmp.contiguous())
+                else:
+                    mask_tmp = self.mask_tmp_groups[i]
+                probs = self.scale_mask_softmax(cur_sim, mask_tmp)
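+                # Dropout below runs under the model-parallel RNG tracker so each
+                # tensor-parallel rank draws an independent dropout mask for its
+                # shard of the attention probabilities.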
+                with mpu.get_cuda_rng_tracker().fork():
+                    probs = self.attention_dropout(probs)
+                # [b * np, q_size, kv_size] * [b * np, kv_size, hn] -> [b * np, q_size, hn]
+                context_layer_tmp = torch.matmul(probs, v_tmp)
+                context_list_tmp.append(context_layer_tmp)
+            self.mask_tmp_initialed = True
+            context_layer = torch.cat(context_list_tmp, 2)
+            # =================
+            # Output. [b, np, sq, hn]
+            # =================
+            return context_layer
         # ===================================
         # Raw attention scores. [b, np, s, s]
         # ===================================
@@ -331,28 +393,13 @@ class ParallelSelfAttention(nn.Module):
         )
         key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
-
-        '''
-        ### Raw attention scores. [b * np, sq, sk]  # orig
-        matmul_result = torch.baddbmm(
-            matmul_result,
-            query_layer.transpose(0, 1),  # [b * np, sq, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            beta=0.0,
-            alpha=(1.0 / self.norm_factor),
-        )
-        '''
         ###bmmm
         alpha = (1.0 / self.norm_factor)
         query_layer *= alpha
         matmul_result = torch.bmm(
-            # matmul_result,
             query_layer.transpose(0, 1),  # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            # beta=0.0,
-            # alpha=(1.0 / self.norm_factor),
         )
-        # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py
index 2c48ff6918ce54dd71be45d4a9850ccc1e78e5ef..7ccea7ddbe2f0006f5aee967e6aa581a8cff884c 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/mpu/layers.py
@@ -21,6 +21,7 @@
 import math
+from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
@@ -29,6 +30,8 @@ from torch.nn.parameter import Parameter
 
 from .initialize import get_model_parallel_rank
 from .initialize import get_model_parallel_world_size
+from .initialize import get_model_parallel_group
+from .initialize import get_fp32_allreduce
 from .mappings import copy_to_model_parallel_region
 from .mappings import gather_from_model_parallel_region
 from .mappings import reduce_from_model_parallel_region
@@ -92,6 +95,111 @@ def _initialize_affine_weight_cpu(
     return None
 
 
+# allreduce_hidden
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
+    """See linear_with_grad_accumulation_and_async_allreduce"""
+
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel,
+    ):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
+        # ctx.sequence_parallel = sequence_parallel
+
+        total_input = input
+
+        output = torch.matmul(total_input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        total_input = input
+        grad_input = grad_output.matmul(weight)
+
+        # Doing gather + slicing during the NeMo forward pass can make this tensor
+        # not be contiguous. PyTorch only checks if the tensor is contiguous, and only
+        # clones it if it's not contiguous:
+        # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
+        grad_output = grad_output.contiguous()
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(
+            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+        )
+        total_input = total_input.view(
+            total_input.shape[0] * total_input.shape[1], total_input.shape[2]
+        )
+
+        if ctx.async_grad_allreduce:
+            # Bf16 convert
+            dt = grad_input.dtype
+            if dt == torch.bfloat16 and get_fp32_allreduce():
+                grad_input = grad_input.float()
+
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                grad_input, group=get_model_parallel_group(), async_op=True
+            )
+
+            # Bf16 convert
+            if dt == torch.bfloat16 and get_fp32_allreduce():
+                grad_input = grad_input.bfloat16()
+            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+            # all-reduce is scheduled before the weight gradient computation
+
+        if ctx.gradient_accumulation_fusion:
+            # Fused weight-gradient accumulation requires the external
+            # fused_weight_gradient_mlp_cuda kernel extension, which is not imported
+            # in this file; ColumnParallelLinear below always passes
+            # gradient_accumulation_fusion=False, so this branch is not exercised here.
+            if weight.main_grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
+                    total_input, grad_output, weight.main_grad
+                )
+            elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
+                    total_input, grad_output, weight.main_grad
+                )
+            else:
+                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+# allreduce_hidden
+def linear_with_grad_accumulation_and_async_allreduce(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    gradient_accumulation_fusion: bool,
+    async_grad_allreduce: bool,
+    sequence_parallel: bool,
+) -> torch.Tensor:
+    args = [
+        input,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel,
+    ]
+    return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+
+
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
@@ -483,6 +591,11 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter("bias", None)
+        # allreduce_hidden: use the custom linear autograd function, and enable the
+        # asynchronous tensor-model-parallel all-reduce only when the model is
+        # actually split across more than one rank.
+        self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+        self.async_tensor_model_parallel_allreduce = (
+            neox_args.async_tensor_model_parallel_allreduce and world_size > 1
+        )
 
     # Copied from Mup
     def width_mult(self):
@@ -539,12 +652,26 @@ class ColumnParallelLinear(torch.nn.Module):
     def forward(self, input_):
         if self.use_mup and self.mup_rescale_parameters:
             input_ /= self.width_mult()
-        # Set up backprop all-reduce.
-        input_parallel = copy_to_model_parallel_region(input_)
 
         # Matrix multiply.
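+        # When async_tensor_model_parallel_allreduce is enabled, the explicit
+        # copy_to_model_parallel_region() above is no longer used: the identity
+        # forward plus the backward all-reduce of grad_input are handled inside
+        # LinearWithGradAccumulationAndAsyncCommunication, which launches the
+        # all-reduce asynchronously so it overlaps with the weight-gradient matmul.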
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        # allreduce_hidden
+        if self.async_tensor_model_parallel_allreduce:
+            # Asynchronous tensor-model-parallel all-reduce path: the input is used
+            # as-is here; the gradient all-reduce is issued inside _forward_impl.
+            input_parallel = input_
+            output_parallel = self._forward_impl(
+                input=input_parallel,
+                weight=self.weight,
+                bias=bias,
+                gradient_accumulation_fusion=False,  # self.gradient_accumulation_fusion is not wired up here
+                async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
+                sequence_parallel=False,
+            )
+        else:
+            # Original path: copy_to_model_parallel_region() sets up the
+            # (synchronous) backprop all-reduce.
+            input_parallel = copy_to_model_parallel_region(input_)
+            output_parallel = F.linear(input_parallel, self.weight, bias)
+
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_model_parallel_region(output_parallel)
diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/neox_arguments/neox_args.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/neox_arguments/neox_args.py
index f9d632cd996e4b9c4ccb8a4a3628a8da925671b7..fac9156cb1fde05424c7d96c6e69df4a32c799de 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/neox_arguments/neox_args.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/neox_arguments/neox_args.py
@@ -245,6 +245,16 @@ class NeoXArgsModel(NeoXArgsTemplate):
     Enable fusion of query_key_value_scaling general masking and softmax.
     """
 
+    async_tensor_model_parallel_allreduce: bool = False
+    """
+    Enable the asynchronous tensor-model-parallel all-reduce in ColumnParallelLinear
+    (the backward all-reduce of the input gradient is overlapped with the
+    weight-gradient computation).
+    """
+
+    use_triangle_attn: bool = False
+    """
+    Enable triangle (block-wise causal) attention in ParallelSelfAttention.
+    """
+
     bias_gelu_fusion: bool = False
     """
     Enable bias and gelu fusion.