diff --git a/oneflow/core/kernel/model_diff_accumulate_kernel.cpp b/oneflow/core/kernel/model_diff_accumulate_kernel.cpp index 9c2d2a5835754d801ebaba142c90cc8cc1151d3b..80f913b4f2713ec069d16fbdfa5dd5a962d01b06 100644 --- a/oneflow/core/kernel/model_diff_accumulate_kernel.cpp +++ b/oneflow/core/kernel/model_diff_accumulate_kernel.cpp @@ -9,8 +9,7 @@ void MdDiffAccKernel::Forward( Blob* in_blob = BnInOp2BlobPtr("model_diff"); Blob* out_blob = BnInOp2BlobPtr("model_diff_acc"); KernelUtil::BlasAxpy( - ctx, in_blob->shape().elem_cnt() * sizeof(FloatingPointType), - static_cast(1.0), + ctx, in_blob->shape().elem_cnt(), static_cast(1.0), static_cast(in_blob->dptr()), 1, static_cast(out_blob->mut_dptr()), 1); }