    Support BF16 training for sharding (#46846) · 0b39b244
    Authored by Ghost Screaming
    * Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result
    is wrong.
    
    * Support pure bfloat16.
    
    * Support bf16 linear.
    
    * Update PR to pass CI.
    
    * Tiny fix in where_grad_kernel.cu.
    
    * Support bfloat16 type for reducer and sharding.
    
    * Fix some bugs.
    
    * Polish code.
    
    * Polish code.
    
    * Add bfloat16 datatype in fill_grad kernels.
    Co-authored-by: sneaxiy <sneaxiy@126.com>
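    The first commit bullet refers to a 32-bit index overflow in `reduce_sum`. A minimal Python sketch (not the actual PaddlePaddle kernel) of why an element count above INT32_MAX corrupts offsets held in a 32-bit signed integer:

    ```python
    import ctypes

    INT32_MAX = 2**31 - 1

    # A tensor with more than INT32_MAX elements cannot have its
    # element offsets stored in a 32-bit signed integer: the count
    # wraps around to a negative value.
    numel = INT32_MAX + 1               # one past the int32 range
    wrapped = ctypes.c_int32(numel).value

    print(wrapped)                      # negative after wraparound
    print(ctypes.c_int64(numel).value)  # correct with a 64-bit index
    ```

    A reduction kernel that computes offsets with a 32-bit index type therefore reads the wrong elements once `numel` exceeds `INT32_MAX`; switching the index arithmetic to a 64-bit type is the standard fix.
    
    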
group_sharded_optimizer_stage2.py 22.3 KB