Committed by Ghost Screaming
* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result is wrong.
* Support pure bfloat16.
* Support bf16 linear.
* Update PR to pass CI.
* Tiny fix in where_grad_kernel.cu.
* Support the bfloat16 type for reducer and sharding.
* Fix some bugs.
* Polish code.
* Add the bfloat16 data type to fill_grad kernels.

Co-authored-by: sneaxiy <sneaxiy@126.com>
0b39b244
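The reduce_sum fix concerns tensors whose element count no longer fits in a signed 32-bit integer. The Python sketch below is an illustration only, not Paddle's kernel code: it assumes the failure mode is an element count or loop bound stored in a 32-bit variable, and the helpers `as_int32` and `elements_visited` are hypothetical names. It shows how truncating `numel` to int32 makes a loop visit zero elements or only a small fraction of them, so the sum comes out silently wrong.

```python
import ctypes

INT32_MAX = 2**31 - 1


def as_int32(value: int) -> int:
    # Truncate to a signed 32-bit integer, mimicking the two's-complement
    # wraparound most compilers produce for a narrowing cast to int32_t.
    # ctypes performs no overflow checking, so it simply keeps the low 32 bits.
    return ctypes.c_int32(value).value


def elements_visited(numel: int, use_int64: bool) -> int:
    # Hypothetical helper: how many iterations `for (i = 0; i < n; ++i)`
    # performs when the element count n is held in a 64-bit variable
    # versus a 32-bit one.
    n = numel if use_int64 else as_int32(numel)
    return max(n, 0)


if __name__ == "__main__":
    for numel in (10, INT32_MAX + 10, 2**32 + 5):
        print(
            numel,
            elements_visited(numel, use_int64=False),
            elements_visited(numel, use_int64=True),
        )
```

With a 32-bit count, `INT32_MAX + 10` elements wraps to a negative bound (no iterations), and `2**32 + 5` wraps to just 5. Widening the count and index types to 64 bits is one common remedy; whether this commit does exactly that is not stated here.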