Support BF16 training for sharding (#46846)
* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result is wrong (see the repro sketch below).
* Support pure bfloat16 training.
* Support bf16 linear.
* Update the PR to pass CI.
* Tiny fix in where_grad_kernel.cu.
* Support bfloat16 type for the reducer and sharding (see the usage sketch after this list).
* Fix some bugs.
* Polish code.
* Add bfloat16 data type support in fill_grad kernels.
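
For reference, a minimal repro sketch of the reduce_sum overflow, assuming a CUDA device with enough memory and that the pre-fix kernel indexed elements with a 32-bit integer; the tensor size here is illustrative, not taken from this PR:

```python
import paddle

# Hypothetical repro: a tensor whose element count exceeds INT32_MAX
# (2**31 - 1). Before this fix, reduce_sum could return a wrong result
# for such inputs on GPU due to 32-bit index overflow.
n = 2**31 + 1                            # numel > INT32_MAX
x = paddle.ones([n], dtype='bfloat16')   # roughly 4 GiB of device memory
out = paddle.sum(x)
print(out)  # expected ~n (up to bf16 rounding); was wrong before the fix
```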
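And a usage sketch of the bf16 sharded training this PR enables, assuming the Paddle 2.4-era `paddle.amp.decorate` / `group_sharded_parallel` APIs; the model, sizes, and launch command are placeholders:

```python
# Run with e.g.: python -m paddle.distributed.launch --gpus "0,1" train.py
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel

fleet.init(is_collective=True)

model = paddle.nn.Linear(1024, 1024)  # placeholder model
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# Cast parameters to bfloat16 for pure-bf16 (O2) training.
model, opt = paddle.amp.decorate(
    models=model, optimizers=opt, level='O2', dtype='bfloat16')

# Shard optimizer states and gradients across ranks ('os_g'); with this
# PR the reducer/sharding path handles bfloat16 gradients.
model, opt, _ = group_sharded_parallel(model, opt, level='os_g')

x = paddle.randn([8, 1024])
with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()
```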
Co-authored-by: sneaxiy <sneaxiy@126.com>