Skip to content

Feat: add checkpoint loading mechanism#146

Open
JYMiracle305 wants to merge 9 commits into
masterfrom
feature/add_checkpoint
Open

Feat: add checkpoint loading mechanism#146
JYMiracle305 wants to merge 9 commits into
masterfrom
feature/add_checkpoint

Conversation

@JYMiracle305

@JYMiracle305 JYMiracle305 commented Apr 21, 2026

Copy link
Copy Markdown
Contributor

1. checkpoint机制

From ArcaLunar

Checkpoint 读取工具主要参数:

  • --save 训练过程中的保存目录
  • --save_interval 每 N 次保存一次,设置为 0 则不保存
  • --max_checkpoint_keep 最多保留 K 个 checkpoint
  • --no_save_optim 是否保存优化器的状态
  • --load 从指定 checkpoint 目录恢复训练

Checkpoint 文件可以通过从 /data/shared/....../llmc/gpt2 (or llama3) 的原始模型参数训练而来,例子可见仓库中的 REPORT.md(Experiment 实际上也测试了llama3,但是命令只记录了 GPT2 训练),model.bin, optimizer.bin, trainer_state.json 都可以从训练中获取.因此不在附件中提供

Experiment

CUDA_VISIBLE_DEVICES=5,6,7 ./gpt2 --input_bin ../../data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath ../../data/llmc/gpt2/gpt2_124M.bin --save ../ckpt2/gpt2-noresume/ --num_iteration 100 --save_interval 20 --no_save_optim false --max_checkpoint_keep 10
CUDA_VISIBLE_DEVICES=5,6,7 ./gpt2 --input_bin ../../data/llmc/gpt2/tinyshakespeare/tiny_shakespeare_train.bin --llmc_filepath ../../data/llmc/gpt2/gpt2_124M.bin --save ../ckpt2/gpt2-resumefrom40/ --num_iteration 100 --save_interval 20 --no_save_optim false --max_checkpoint_keep 10 --load ../ckpt2/gpt2-noresume/checkpoint_step_000040/ > ../ckpt2/gpt2-resumefrom40/gpt2-resume.log 2>&1

(以上两条训练命令同样用 llama3 也运行了)

运行 compare_loss.py,对于 llama3 模型,由于从 step 40 恢复训练,所以 step 1~40 数据缺失,而其余 60 步的 loss 在 FP32, BF16 下均吻合

  Summary: 60/100 steps matched

==================================================
Overall Summary:
  fp32:    0/1 test cases passed (threshold: 1e-05)
  bfloat16: 0/0 test cases passed (threshold: 1e-02)
  Total:   0/1 test cases passed
==================================================

==================================================
Overall Summary:
  fp32:    0/0 test cases passed (threshold: 1e-05)
  bfloat16: 0/1 test cases passed (threshold: 1e-02)
  Total:   0/1 test cases passed
==================================================

对于 GPT2,模型保存的逻辑有误:训练中 lm_head 与 wte 并非真共享,而 LLMC 存取又按“共享”假设处理,resume 后 lm_head 很容易和 no resume 不一致。解决方法是把训练用 checkpoint 从 LLMC 回调路径切到原生 StateDict 二进制路径,并在加载后显式重建权重绑定语义 (example/gpt2/main.cc).经过修复后,也可以通过.

2. 训练对比

精度对比:
image

性能对比:
image

@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch from e8c5dd5 to 0a3deb2 Compare April 24, 2026 09:22
@JYMiracle305 JYMiracle305 changed the title [WIP] Feat: add checkpoint loading mechanism Feat: add checkpoint loading mechanism Apr 29, 2026
Comment thread example/gpt2/main.cc Outdated
Comment thread infini_train/src/optimizer.cc
Comment thread infini_train/src/nn/modules/module.cc
Comment thread example/common/checkpoint_loader.cc
Comment thread example/common/checkpoint_loader.cc Outdated
Comment thread example/common/utils.cc Outdated
Comment thread infini_train/src/dataloader.cc Outdated
Comment thread infini_train/include/checkpoint.h Outdated
@JYMiracle305 JYMiracle305 requested a review from chen2021673 May 13, 2026 12:21
Comment thread example/llama3/main.cc Outdated
Comment thread example/common/utils.h Outdated
Comment thread example/common/checkpoint_loader.h
Comment thread example/common/checkpoint_loader.cc Outdated
Comment thread example/common/checkpoint_loader.cc Outdated

named_shard_params_.clear();
for (size_t i = 0; i < shard_params_.size(); ++i) {
named_shard_params_.emplace_back(shard_param_names_[i], shard_params_[i]);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里named_shard_params_有没有可能出现多个相同 name

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个name不会重复,name对应各层的参数名(模块名加index

@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch 7 times, most recently from 0a089ae to 4981cd4 Compare May 20, 2026 09:57
Comment thread example/common/checkpoint_loader.cc
Comment thread example/common/checkpoint_loader.h Outdated
Comment thread example/gpt2/config.h Outdated
Comment thread example/gpt2/main.cc Outdated

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving");

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/NVIDIA/Megatron-LM/blob/main/examples/llama/train_llama3_8b_h100_fp8.sh

这个参数在 megatron 里应该是叫 --save-interval,其余参数也都确认下,
命名和 megatron 对齐吧。

以及现在 main 里的参数有点多了,感觉之后可以考虑类似 megatron 那样把参数按不同类型分个组,先加个 TODO 记一下。

@JYMiracle305 JYMiracle305 May 27, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

做了如下替换
save_interval ----- save_steps
load ---- resume_from
save ---- checkpoint_dir
no_save_optim ---- save_optimizer_state // 语义是反的,默认值为false
checkpoint_format 已删除

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实。。no_save_optim 这个语义有点奇怪了,还是改回正向配置吧

Comment thread example/llama3/checkpoint_loader.cc
Comment thread infini_train/include/checkpoint.h Outdated
Comment thread infini_train/include/checkpoint.h Outdated
Comment thread infini_train/src/checkpoint.cc Outdated
Comment thread example/gpt2/main.cc Outdated
Comment thread example/gpt2/main.cc Outdated
@kilinchange

Copy link
Copy Markdown
Collaborator

建议补一个文档(飞书文档就行,不用放仓库里),介绍一下目前 pth 的格式与 torch/megatron ckpt 格式,解释一下现有格式与主流框架的区别、未来可能的兼容方式。

Comment thread infini_train/src/checkpoint.cc
const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);

static std::unordered_map<std::string, std::shared_ptr<Tensor>>
LoadStateDictBinary(const std::filesystem::path &path);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

binary 相关的 save/load 保留原有的功能函数形式就行,没必要做到 ckpt 里,ckpt 就只做 pth 的结构化存取就行。

@JYMiracle305 JYMiracle305 May 29, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkpoint 里不再支持binary格式,相关为了做format适配的flag 以及 options结构体都删除

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在新增的格式后缀记成了 .ckpt, SaveStateDictBinary 和 LoadStateDictBinary 针对的就是 .ckpt。上次修改把LoadFromLLMC (对旧格式的读取)和 SaveLLMC 移出了 checkpoint机制,并且不再有 SaveLLMC 的调用,也在修改中删了。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这俩函数命名叫 SaveStateDict/LoadStateDict 更合适,跟下面的 SaveTrainerState/LoadTrainerState 对齐(都不带 binary 后缀)

Comment thread infini_train/include/checkpoint.h Outdated
Comment thread infini_train/include/checkpoint.h Outdated
@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch 4 times, most recently from 0b4857d to 8ebe12f Compare May 29, 2026 02:01
@JYMiracle305

Copy link
Copy Markdown
Contributor Author

建议补一个文档(飞书文档就行,不用放仓库里),介绍一下目前 pth 的格式与 torch/megatron ckpt 格式,解释一下现有格式与主流框架的区别、未来可能的兼容方式。

之前的文档有介绍现在的格式排布,待补充torch/megatron 用到的格式介绍。

@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch 6 times, most recently from c499288 to ebaeadf Compare June 8, 2026 02:43
ArcaLunar and others added 8 commits June 10, 2026 14:32
format: use clang-format-16 instead
remove redundent arguments
   - Use name-based optimizer state keys instead of index-based to
     prevent state corruption from unordered_map traversal order
   - Warn on unexpected keys when loading model state dict
   - Validate parallel topology (TP/PP/SP) consistency on resume
   - Add batch_idx alignment check for distributed data loader
   - Default best_loss to infinity instead of zero
Comment thread example/gpt2/checkpoint_loader.cc
Comment thread example/gpt2/config.h
Comment thread infini_train/src/nn/modules/module.cc Outdated
Comment thread infini_train/src/nn/modules/module.cc
Comment thread infini_train/include/utils/string_utils.h Outdated
Comment thread tests/optimizer/test_optimizer_step.cc Outdated
Comment thread tests/optimizer/test_optimizer_creation.cc Outdated
Comment thread scripts/__pycache__/compare_utils.cpython-312.pyc Outdated
Comment thread example/gpt2/main.cc Outdated
Comment thread example/gpt2/main.cc
@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch 4 times, most recently from 8dc64b3 to 23a5e32 Compare June 11, 2026 09:52
@JYMiracle305 JYMiracle305 requested a review from kilinchange June 11, 2026 10:16
@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch 5 times, most recently from 68b7c77 to 702a9bd Compare June 12, 2026 03:32
@JYMiracle305 JYMiracle305 force-pushed the feature/add_checkpoint branch from 702a9bd to e8b1e37 Compare June 12, 2026 05:00
const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);

static std::unordered_map<std::string, std::shared_ptr<Tensor>>
LoadStateDictBinary(const std::filesystem::path &path);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这俩函数命名叫 SaveStateDict/LoadStateDict 更合适,跟下面的 SaveTrainerState/LoadTrainerState 对齐(都不带 binary 后缀)

constexpr uint32_t kCkptMagic = 0x54504B43; // CKPT
constexpr uint32_t kCkptVersion = 1;

uint32_t PeekMagic(const std::filesystem::path &path) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PeekMagic 这个命名有点容易让人联想到 istream::peek() 的语义(查看但不消费数据),而这里实际上是打开文件并读取 magic number(实际会消费数据)。建议改成 ReadMagic() 或 ReadMagicNumber(),语义会更直接一些。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没用到这个函数,先删掉

return s;
}

std::string ExtractStringField(const std::string &content, const std::string &key, const std::string &fallback) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数没被用到。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删掉

return content.substr(first_quote + 1, second_quote - first_quote - 1);
}

template <typename T> T ExtractNumberField(const std::string &content, const std::string &key, T fallback) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里实际上是在手写一个简易 JSON 解析器。建议加个 TODO 标注一下,后续引入 JSON 库后统一替换,避免长期维护这套字符串解析逻辑。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

SaveTrainerState(checkpoint_dir / "trainer_state.json", state);
LOG(ERROR) << "[CKPT] Save done: dir=" << checkpoint_dir;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该是 LOG(INFO) 吧?下面 Checkpoint::Load 的几个 LOG(ERROR) 也应该是 INFO。

@JYMiracle305 JYMiracle305 Jun 15, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里save/load 写成LOG(ERROR) 好一些,不然不开维测感知不到默认保存了checkpoint,其他分支改成LOG(INFO)

class Optimizer;

using OptimizerCreator = std::function<std::shared_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> &params)>;
using OptimizerCreatorNamed = std::function<std::shared_ptr<Optimizer>(

@kilinchange kilinchange Jun 14, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前框架里似乎没有实际使用 named_parameters 构造 Optimizer 的场景,而当前实现中参数名在构造后也没有被利用,StateDict 仍然是基于下标生成 adam.m.{i}、adam.v.{i} 等 key,与普通 Parameters() 构造方式本质上没有区别。

建议暂时不引入 OptimizerCreatorNamed 及相关 named parameters 构造接口,后续如果需要支持 param group、按参数名配置超参(如 weight decay 分组)、或者对齐 Torch 的 optimizer state_dict 设计,再统一引入会更自然一些。

从 Torch 的实现来看,optimizer state_dict 也是围绕 state + param_groups 组织的,而不是直接基于参数名保存状态,后续如果要支持 param group,接口设计可能还需要进一步调整:
https://docs.pytorch.org/docs/2.12/generated/torch.optim.Optimizer.state_dict.html?utm_source=chatgpt.com#torch-optim-optimizer-state-dict

Comment thread scripts/test_config.json
"checkpoint_dir": "/data1/ckpt/bf16_resume"
}
},
{

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

理论上 ckpt 需要特殊处理 lora 的 A/B 权重读取逻辑,lora 用例也应当设置 lora_save_path/lora_load_path,在 #150 修复了,暂时还没合入。

建议先将 lora 相关用例从本 pr 中移除,lora 与 checkpoint 的集成逻辑可以后续单独提 pr 支持,可以和 @chen2021673 讨论确认下集成逻辑。

COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}"
RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}"
CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}"
CKPT_CLEAN_DIRS=(

@kilinchange kilinchange Jun 14, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么会有两个 checkpoint 路径?另外,建议在 test_config.json 中统一增加一个 CKPT_DIR 配置,后续无论是 test_config.json 里的传参、checkpoint 清理,还是其他依赖 checkpoint 路径的逻辑,都基于该配置拼接生成路径。

这样可以避免路径配置分散在多个地方,后续修改目录结构时出现配置不一致的问题。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件里的逻辑已经不只是 example 层的辅助函数了,包含了 checkpoint 保存/加载、分布式拓扑校验、模型配置校验以及 checkpoint pruning 等通用能力,建议抽到框架侧,例如infini_train/include/checkpoint/checkpoint_manager.h。(Checkpoint 类的声明/实现也可以一并迁移到 checkpoint 目录下)


LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms);

if (!args.prune_step_checkpoints) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prune_step_checkpoints 这个开关有必要吗?感觉保留 max_checkpoint_keep 就足够了,例如约定:

max_checkpoint_keep = 0  -> 不清理
max_checkpoint_keep > 0  -> 保留最近 N 个

另外下面的清理逻辑是直接对 ckpt 字符串路径做排序,如果 ckpt 目录名不做固定宽度补零(例如 checkpoint_step_10、checkpoint_step_100),可能导致清理顺序错误。
建议先留个 FIXME,后续明确 checkpoint 命名规范(统一补零),或者显式解析 step 后按数值排序。这样会更稳妥一些。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants