mindspeed-llm源码解析(一)preprocess_data

本文涉及的产品
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: mindspeed-llm是昇腾模型套件代码仓,原来叫"modelLink"。这篇文章带大家阅读一下数据处理脚本preprocess_data.py(基于1.0.0分支),数据处理是模型训练的第一步,经常会用到。

mindspeed-llm是昇腾模型套件代码仓,原来叫"modelLink"。这篇文章带大家阅读一下数据处理脚本preprocess_data.py(基于1.0.0分支),数据处理是模型训练的第一步,经常会用到。

文章中贴的源码加了相关注释,同学们可以把源码和注释结合起来看。

首先来看一下main函数

def main():
    # 获取入参,通过后面的代码可以知道有哪些关键参数
    args = get_args()
    # 参数校验
    validate_args(args)
    # 合并已经处理好的数据集
    if args.merge_group_keys is not None:
        merge_datasets(args)
        return
    # 创建splitter,用来把文章段落分割成句子
    splitter = build_splitter(args)
    # 创建tokenizer,用来把句子切分成单个的词
    tokenizer = build_tokenizer(args)
    logger.info("building dataset: %s", args.input)
    # 加载数据,把CSV、JSON、TXT等格式的数据加载到内存
    raw_data = build_dataset(args)
    # 保存到一个文件
    if args.n_subs == 1:
        # 获取处理后的数据句柄
        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
        # 数据落盘
        handler.serialize_to_disk()
    # 保存到多个文件,使用多进程处理,单文件的处理方式和if条件中是一致的
    else:
        target_prefix = args.output_prefix
        target_prefixname = os.path.basename(target_prefix)
        
        num_samples = len(raw_data)
        start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
        subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]
        
        # multiprocessing
        params_list = []
        for k, subset in enumerate(subsets):
            args_ = copy.deepcopy(args)
            args_.output_prefix = target_prefix.replace(target_prefixname, f'{str(k).zfill(3)}_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}')
            params = [args_, subset, tokenizer, splitter]
            params_list.append(params)
        pool = multiprocessing.Pool()
        sub_idx_files = pool.map(handle_subset, params_list)
        pool.close()
        pool.join()
        
        for key in sub_idx_files[0].keys():
            idx_files = [x[key] for x in sub_idx_files]
            idx_files.sort()
            target_idx = idx_files[0].replace(f'000_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}', target_prefixname)
            target_bin = target_idx.replace('.idx', '.bin')
            idx = IndexedDatasetBuilder(target_bin)
            for idx_file in idx_files:
                idx.add_index(idx_file.replace('.idx', ''))
            idx.finalize(target_idx)
            
            for idx_file in idx_files:
                os.remove(idx_file)
                os.remove(idx_file.replace('.idx', '.bin'))

可以看到,main函数处理逻辑主要由这几个函数组成:build_splitter、build_tokenizer、build_dataset、get_dataset_handler、serialize_to_disk。

build_splitter

这个函数的功能是把文字段落分割成单个句子,查看源码,主要使用的是三方库nltk的函数:

def build_splitter(args):
    if nltk and args.split_sentences:
        nltk.download("punkt", quiet=True)
    if args.split_sentences:
        if not nltk:
            logger.error("NLTK is not available to split sentences.")
            raise Exception("nltk is not available")
        splitter = nltk.load("tokenizers/punkt/english.pickle")
        if args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            final_splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            final_splitter = splitter
    else:
        # 自定义splitter
        final_splitter = IdentitySplitter()
    return final_splitter

build_tokenizer

这个函数的主要功能是把句子切分成单个的词,比如说把 "今天是星期几" 切分成 "今天"、"是"、"星期几",然后转成对应的整数。

def build_tokenizer(args):
    """Initialize tokenizer."""
    # 获取huggingface的tokenizer 
    if args.tokenizer_type == "PretrainedFromHF":
        if args.rank == 0:
            print(' > building PretrainFromHF tokenizer. Vocab file is un-used, '
                  'loading tokenizer from pre-trained model', flush=True)
        if args.tokenizer_name_or_path is None:
            raise ValueError("Missing tokenizer_name_or_path while building PretrainFromHF tokenizer.")
        hf_tokenizer_kwargs = dict()
        if hasattr(args, "tokenizer_kwargs") and args.tokenizer_kwargs:
            if len(args.tokenizer_kwargs) % 2 != 0:
                raise ValueError("The token name and token value must be entered in pairs.")
            for i in range(0, len(args.tokenizer_kwargs), 2):
                hf_tokenizer_kwargs[args.tokenizer_kwargs[i]] = \
                    args.tokenizer_kwargs[i + 1]
        # 基于MegatronTokenizer构建的类
        tokenizer = _AutoTokenizer(
            args.tokenizer_name_or_path,
            vocab_extra_ids=args.vocab_extra_ids,
            model_max_length=args.seq_length,
            use_fast=args.tokenizer_not_use_fast,
            **hf_tokenizer_kwargs
        )
        # Add vocab size (if not already set from a checkpoint).
        if getattr(args, "padded_vocab_size", None) is None:
            args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
                                                              args)
    else:
        # 
        tokenizer = TokenizerAdaptor(megatron_build_tokenizer(args))
    # 根据prompt_type完善tokenizer
    if hasattr(args, "prompt_type") and args.prompt_type is not None:
        if ("PreTrainedTokenizerBase" not in str(tokenizer.tokenizer._pad.__func__)):
            tokenizer.tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer.tokenizer)
            tokenizer.tokenizer.padding_side = "right"
        fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip())
    return tokenizer

分成了2类tokenizer,一类是PretrainedFromHF,也就是使用预训练的 HuggingFace 分词器;如果不适用hf的,则则使用 TokenizerAdaptor 类和 megatron_build_tokenizer 函数创建分词器实例 tokenizer。

build_dataset

这个函数的功能是把数据文件加载到内存,返回DatasetDict 或Dataset,也就是一个Python容器。这个函数中调用的load_dataset是huggingface的datasets库的函数。

def build_dataset(args):
    """loading dataset by huggingface"""
    raw_datasets = None
    if args.handler_name == "LlamaFactoryInstructionHandler":
        all_datasets = []
        for dataset_attr in get_dataset_list(args):
            # 加载单个数据集
            all_datasets.append(load_single_dataset(dataset_attr, args))
        # 合并数据集
        raw_datasets = merge_dataset(all_datasets, args)
    else:
        if args.handler_name == "MOSSInstructionHandler" or args.handler_name == "MOSSMultiTurnHandler":
            # for MOSS, streaming is needed.流式加载数据
            args.streaming = True
        if args.hf_datasets_params:
            with open(args.hf_datasets_params, 'r') as fin:
                param_dict = json.load(fin)
            return load_dataset(**param_dict)
        cache_dir = args.cache_dir
        split_flag = "train"
        load_from_local = os.path.exists(args.input)
        # 从本地加载
        if load_from_local:
            # args.input 是一个有效的 Python 脚本路径
            if _has_py_script(args.input):
                logger.info("loading data from a local python script")
                raw_datasets = load_dataset(
                    args.input,
                    data_dir='./' if not args.script_data_dir else args.script_data_dir,
                    split=split_flag,
                    num_proc=None if args.streaming else args.workers,
                    cache_dir=cache_dir,
                    streaming=args.streaming,
                    trust_remote_code=False
                )
            else:
                # args.input 是一个文件或目录路径
                data_files = [args.input] if os.path.isfile(args.input) else \
                    glob.glob(os.path.join(args.input, '*'))
                # 获取文件格式
                ext, data_format = _get_data_format(data_files)
                # 筛选合法的文件格式
                filtered_data_files = list(filter(lambda x: x.split('.')[-1] == ext, data_files))
                if filtered_data_files:
                    logger.info("loading data from local file, format: %s,"
                                " file num: %s", data_format, len(data_files))
                    raw_datasets = load_dataset(
                        data_format,
                        split=split_flag,
                        data_files=filtered_data_files,
                        num_proc=None if args.streaming else args.workers,
                        cache_dir=cache_dir,
                        streaming=args.streaming,
                        trust_remote_code=False
                    )
                else:
                    raise Exception("unknown local data!")
        else:
            logger.info("loading data from remote huggingface")  # 从远程 Hugging Face 数据集加载
            raw_datasets = load_dataset(
                args.input,
                split=split_flag,
                num_proc=None if args.streaming else args.workers,
                cache_dir=cache_dir,
                streaming=args.streaming,
                trust_remote_code=False
            )
        if raw_datasets is None:
            raise Exception("unknown data!")
        if args.handler_name in [
            "AlpacaStyleInstructionHandler",
            "SharegptStyleInstructionHandler",
            "AlpacaStylePairwiseHandler",
            "SharegptStylePairwiseHandler"
        ]:
            handler_dataset_attr = get_handler_dataset_attr(args, raw_datasets)
            return align_dataset(raw_datasets, handler_dataset_attr, args)
    return raw_datasets

get_dataset_handler

这个函数的功能是创建数据集处理实例,_get_handler_cls会根据args.handler_name选择对应的handler。handler的基类和子类都在mindspeed_llm/tasks/preprocess/data_handler.py里面定义了,查看BaseDatasetHandler可以知道,这个类的对外函数有这几个:get_tokenized_data、serialize_to_disk,功能分别是对数据进行令牌化、数据序列化。

serialize_to_disk

接着上面讲,这个函数是handler的类函数,用于将分词后的数据集保存到磁盘。具体来说,它将数据集的每个样本(或句子)序列化为二进制文件,并生成相应的索引文件。代码如下:

def _serialize_to_disk(self, iteration_batch_size=50):
        startup_start = time.time()
        if not self.tokenized_dataset:
            self.tokenized_dataset = self.get_tokenized_data()
        output_bin_files = {}  # 保存数据的文件路径
        output_idx_files = {}  # 保存数据的文件路径
        builders = {}  # 用于构建索引数据集的字典
        level = "document"
        if self.args.split_sentences:
            level = "sentence"
        logger.info("Vocab size: %s", self.tokenizer.vocab_size)
        logger.info("Output prefix: %s", self.args.output_prefix)
        # 字典的key就是文件名,json_keys就是类似"input_ids", "attention_mask", "labels"的string
        for key in self.args.json_keys:
            output_bin_files[key] = f"{self.args.output_prefix}_{key}_{level}.bin"
            output_idx_files[key] = f"{self.args.output_prefix}_{key}_{level}.idx"
            # vocab_size=None : use int32 dtype for -100 will be used in labels
            # 为每个文件创建一个数据字典
            builders[key] = indexed_dataset.IndexedDatasetBuilder(output_bin_files[key])
        self.output_idx_files = output_idx_files
        startup_end = time.time()
        proc_start = time.time()
        total_bytes_processed = 0
        logger.info("Time to startup:%s", startup_end - startup_start)
        skip_num = 0
        # 遍历每个文件的内容
        for i, doc in enumerate(self.tokenized_dataset.iter(batch_size=iteration_batch_size), start=1):
            # In post-training stage, we need to drop the data exceeded set sequence-length
            skip_indices = set()
            # 进行一次筛选
            for key in self.args.json_keys:
                batch = [sentences for sentences in doc[key] if len(sentences) > 0]
                if len(batch) == 0:
                    continue
                for j, sentences in enumerate(batch):
                    for k, sentence in enumerate(sentences):
                        if self.args.seq_length is not None and len(sentence) >= self.args.seq_length:
                            skip_indices.add((j, k))
            # 正式开始处理每个句子
            for key in self.args.json_keys:
                batch = [sentences for sentences in doc[key] if len(sentences) > 0]
                if len(batch) == 0:
                    continue
                for j, sentences in enumerate(batch):
                    for k, sentence in enumerate(sentences):
                        if (j, k) in skip_indices:
                            skip_num = skip_num + 1
                            continue
                        # 记录处理的字节数
                        total_bytes_processed += len(sentence) * np.int32().itemsize
                        # 把合法的句子加到builders里面
                        builders[key].add_item(sentence)
                    builders[key].end_document()
            batch_id = i * iteration_batch_size
            if batch_id % self.args.log_interval == 0:
                current = time.time()
                elapsed = current - proc_start
                mbs = total_bytes_processed / elapsed / 1024 / 1024
                logger.info("Processed %s documents (%s docs/s, %s MB/s).", batch_id, batch_id / elapsed, mbs)
        logger.info("Skip %s sample exceeded seq-length(%s)", skip_num / len(self.args.json_keys), self.args.seq_length)
        for key in self.args.json_keys:
            builders[key].finalize(output_idx_files[key])

以上就是mindspeed-llm处理数据的主要函数了,大家还有什么想了解的呢?欢迎评论区提问!

目录
相关文章
|
2月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
107 2
|
3月前
|
缓存 Java 程序员
Map - LinkedHashSet&Map源码解析
Map - LinkedHashSet&Map源码解析
92 0
|
3月前
|
算法 Java 容器
Map - HashSet & HashMap 源码解析
Map - HashSet & HashMap 源码解析
77 0
|
3月前
|
存储 Java C++
Collection-PriorityQueue源码解析
Collection-PriorityQueue源码解析
79 0
|
26天前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
26天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
26天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
创建型模式的主要关注点是“怎样创建对象?”,它的主要特点是"将对象的创建与使用分离”。这样可以降低系统的耦合度,使用者不需要关注对象的创建细节。创建型模式分为5种:单例模式、工厂方法模式抽象工厂式、原型模式、建造者模式。
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
17天前
|
存储 缓存 人工智能
深度解析CPFS 在 LLM 场景下的高性能存储技术
本文深入探讨了CPFS在大语言模型(LLM)训练中的端到端性能优化策略,涵盖计算端缓存加速、智能网卡加速、数据并行访问及数据流优化等方面。重点分析了大模型对存储系统的挑战,包括计算规模扩大、算力多样性及数据集增长带来的压力。通过分布式P2P读缓存、IO加速、高性能存算通路技术以及智能数据管理等手段,显著提升了存储系统的吞吐量和响应速度,有效提高了GPU利用率,降低了延迟,从而加速了大模型的训练进程。总结了CPFS在AI训练场景中的创新与优化实践,为未来大模型发展提供了有力支持。
|
2月前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
65 12
|
1月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。

热门文章

最新文章

推荐镜像

更多