Feat: add checkpoint loading mechanism#146
Conversation
e8c5dd5 to
0a3deb2
Compare
|
|
||
| named_shard_params_.clear(); | ||
| for (size_t i = 0; i < shard_params_.size(); ++i) { | ||
| named_shard_params_.emplace_back(shard_param_names_[i], shard_params_[i]); |
There was a problem hiding this comment.
这里named_shard_params_有没有可能出现多个相同 name
There was a problem hiding this comment.
这个name不会重复,name对应各层的参数名(模块名加index
0a089ae to
4981cd4
Compare
|
|
||
| // precision | ||
| DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); | ||
| DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving"); |
There was a problem hiding this comment.
https://github.com/NVIDIA/Megatron-LM/blob/main/examples/llama/train_llama3_8b_h100_fp8.sh
这个参数在 megatron 里应该是叫 --save-interval,其余参数也都确认下,
命名和 megatron 对齐吧。
以及现在 main 里的参数有点多了,感觉之后可以考虑类似 megatron 那样把参数按不同类型分个组,先加个 TODO 记一下。
There was a problem hiding this comment.
做了如下替换
save_interval ----- save_steps
load ---- resume_from
save ---- checkpoint_dir
no_save_optim ---- save_optimizer_state // 语义是反的,默认值为false
checkpoint_format 已删除
There was a problem hiding this comment.
确实。。no_save_optim 这个语义有点奇怪了,还是改回正向配置吧
|
建议补一个文档(飞书文档就行,不用放仓库里),介绍一下目前 pth 的格式与 torch/megatron ckpt 格式,解释一下现有格式与主流框架的区别、未来可能的兼容方式。 |
| const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict); | ||
|
|
||
| static std::unordered_map<std::string, std::shared_ptr<Tensor>> | ||
| LoadStateDictBinary(const std::filesystem::path &path); |
There was a problem hiding this comment.
binary 相关的 save/load 保留原有的功能函数形式就行,没必要做到 ckpt 里,ckpt 就只做 pth 的结构化存取就行。
There was a problem hiding this comment.
checkpoint 里不再支持binary格式,相关为了做format适配的flag 以及 options结构体都删除
There was a problem hiding this comment.
现在新增的格式后缀记成了 .ckpt, SaveStateDictBinary 和 LoadStateDictBinary 针对的就是 .ckpt。上次修改把LoadFromLLMC (对旧格式的读取)和 SaveLLMC 移出了 checkpoint机制,并且不再有 SaveLLMC 的调用,也在修改中删了。
There was a problem hiding this comment.
感觉这俩函数命名叫 SaveStateDict/LoadStateDict 更合适,跟下面的 SaveTrainerState/LoadTrainerState 对齐(都不带 binary 后缀)
0b4857d to
8ebe12f
Compare
之前的文档有介绍现在的格式排布,待补充torch/megatron 用到的格式介绍。 |
c499288 to
ebaeadf
Compare
format: use clang-format-16 instead
remove redundent arguments
format files
- Use name-based optimizer state keys instead of index-based to
prevent state corruption from unordered_map traversal order
- Warn on unexpected keys when loading model state dict
- Validate parallel topology (TP/PP/SP) consistency on resume
- Add batch_idx alignment check for distributed data loader
- Default best_loss to infinity instead of zero
8dc64b3 to
23a5e32
Compare
68b7c77 to
702a9bd
Compare
…, with plans to unify into one later.
702a9bd to
e8b1e37
Compare
| const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict); | ||
|
|
||
| static std::unordered_map<std::string, std::shared_ptr<Tensor>> | ||
| LoadStateDictBinary(const std::filesystem::path &path); |
There was a problem hiding this comment.
感觉这俩函数命名叫 SaveStateDict/LoadStateDict 更合适,跟下面的 SaveTrainerState/LoadTrainerState 对齐(都不带 binary 后缀)
| constexpr uint32_t kCkptMagic = 0x54504B43; // CKPT | ||
| constexpr uint32_t kCkptVersion = 1; | ||
|
|
||
| uint32_t PeekMagic(const std::filesystem::path &path) { |
There was a problem hiding this comment.
PeekMagic 这个命名有点容易让人联想到 istream::peek() 的语义(查看但不消费数据),而这里实际上是打开文件并读取 magic number(实际会消费数据)。建议改成 ReadMagic() 或 ReadMagicNumber(),语义会更直接一些。
| return s; | ||
| } | ||
|
|
||
| std::string ExtractStringField(const std::string &content, const std::string &key, const std::string &fallback) { |
| return content.substr(first_quote + 1, second_quote - first_quote - 1); | ||
| } | ||
|
|
||
| template <typename T> T ExtractNumberField(const std::string &content, const std::string &key, T fallback) { |
There was a problem hiding this comment.
这里实际上是在手写一个简易 JSON 解析器。建议加个 TODO 标注一下,后续引入 JSON 库后统一替换,避免长期维护这套字符串解析逻辑。
| } | ||
|
|
||
| SaveTrainerState(checkpoint_dir / "trainer_state.json", state); | ||
| LOG(ERROR) << "[CKPT] Save done: dir=" << checkpoint_dir; |
There was a problem hiding this comment.
应该是 LOG(INFO) 吧?下面 Checkpoint::Load 的几个 LOG(ERROR) 也应该是 INFO。
There was a problem hiding this comment.
这里save/load 写成LOG(ERROR) 好一些,不然不开维测感知不到默认保存了checkpoint,其他分支改成LOG(INFO)
| class Optimizer; | ||
|
|
||
| using OptimizerCreator = std::function<std::shared_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> ¶ms)>; | ||
| using OptimizerCreatorNamed = std::function<std::shared_ptr<Optimizer>( |
There was a problem hiding this comment.
目前框架里似乎没有实际使用 named_parameters 构造 Optimizer 的场景,而当前实现中参数名在构造后也没有被利用,StateDict 仍然是基于下标生成 adam.m.{i}、adam.v.{i} 等 key,与普通 Parameters() 构造方式本质上没有区别。
建议暂时不引入 OptimizerCreatorNamed 及相关 named parameters 构造接口,后续如果需要支持 param group、按参数名配置超参(如 weight decay 分组)、或者对齐 Torch 的 optimizer state_dict 设计,再统一引入会更自然一些。
从 Torch 的实现来看,optimizer state_dict 也是围绕 state + param_groups 组织的,而不是直接基于参数名保存状态,后续如果要支持 param group,接口设计可能还需要进一步调整:
https://docs.pytorch.org/docs/2.12/generated/torch.optim.Optimizer.state_dict.html?utm_source=chatgpt.com#torch-optim-optimizer-state-dict
| "checkpoint_dir": "/data1/ckpt/bf16_resume" | ||
| } | ||
| }, | ||
| { |
There was a problem hiding this comment.
理论上 ckpt 需要特殊处理 lora 的 A/B 权重读取逻辑,lora 用例也应当设置 lora_save_path/lora_load_path,在 #150 修复了,暂时还没合入。
建议先将 lora 相关用例从本 pr 中移除,lora 与 checkpoint 的集成逻辑可以后续单独提 pr 支持,可以和 @chen2021673 讨论确认下集成逻辑。
| COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" | ||
| RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" | ||
| CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" | ||
| CKPT_CLEAN_DIRS=( |
There was a problem hiding this comment.
这里为什么会有两个 checkpoint 路径?另外,建议在 test_config.json 中统一增加一个 CKPT_DIR 配置,后续无论是 test_config.json 里的传参、checkpoint 清理,还是其他依赖 checkpoint 路径的逻辑,都基于该配置拼接生成路径。
这样可以避免路径配置分散在多个地方,后续修改目录结构时出现配置不一致的问题。
There was a problem hiding this comment.
这个文件里的逻辑已经不只是 example 层的辅助函数了,包含了 checkpoint 保存/加载、分布式拓扑校验、模型配置校验以及 checkpoint pruning 等通用能力,建议抽到框架侧,例如infini_train/include/checkpoint/checkpoint_manager.h。(Checkpoint 类的声明/实现也可以一并迁移到 checkpoint 目录下)
|
|
||
| LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms); | ||
|
|
||
| if (!args.prune_step_checkpoints) { |
There was a problem hiding this comment.
prune_step_checkpoints 这个开关有必要吗?感觉保留 max_checkpoint_keep 就足够了,例如约定:
max_checkpoint_keep = 0 -> 不清理
max_checkpoint_keep > 0 -> 保留最近 N 个
另外下面的清理逻辑是直接对 ckpt 字符串路径做排序,如果 ckpt 目录名不做固定宽度补零(例如 checkpoint_step_10、checkpoint_step_100),可能导致清理顺序错误。
建议先留个 FIXME,后续明确 checkpoint 命名规范(统一补零),或者显式解析 step 后按数值排序。这样会更稳妥一些。
1. checkpoint机制
From ArcaLunar:
Checkpoint 读取工具主要参数:
--save训练过程中的保存目录--save_interval每 N 次保存一次,设置为 0 则不保存--max_checkpoint_keep最多保留 K 个 checkpoint--no_save_optim是否保存优化器的状态--load从指定 checkpoint 目录恢复训练Checkpoint 文件可以通过从
/data/shared/....../llmc/gpt2(orllama3) 的原始模型参数训练而来,例子可见仓库中的 REPORT.md(Experiment 实际上也测试了llama3,但是命令只记录了 GPT2 训练),model.bin, optimizer.bin, trainer_state.json都可以从训练中获取.因此不在附件中提供Experiment
CUDA_VISIBLE_DEVICES=5,6,7 ./gpt2 --input_bin ../../data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath ../../data/llmc/gpt2/gpt2_124M.bin --save ../ckpt2/gpt2-noresume/ --num_iteration 100 --save_interval 20 --no_save_optim false --max_checkpoint_keep 10(以上两条训练命令同样用 llama3 也运行了)
运行 compare_loss.py,对于 llama3 模型,由于从 step 40 恢复训练,所以 step 1~40 数据缺失,而其余 60 步的 loss 在 FP32, BF16 下均吻合
对于 GPT2,模型保存的逻辑有误:训练中 lm_head 与 wte 并非真共享,而 LLMC 存取又按“共享”假设处理,resume 后 lm_head 很容易和 no resume 不一致。解决方法是把训练用 checkpoint 从 LLMC 回调路径切到原生 StateDict 二进制路径,并在加载后显式重建权重绑定语义 (
example/gpt2/main.cc).经过修复后,也可以通过.2. 训练对比
精度对比:

性能对比:
