提交 97134844 编写于 作者: L Liu Guo 提交者: chengtbf

model diff acc kernel (#197)

* refine copy actors and model diff acc actor

* use set_num_of_not_eord

* fix element count: pass elem_cnt() to BlasAxpy instead of elem_cnt() * sizeof(FloatingPointType)
上级 dccd1cee
......@@ -9,8 +9,7 @@ void MdDiffAccKernel<device_type, FloatingPointType>::Forward(
Blob* in_blob = BnInOp2BlobPtr("model_diff");
Blob* out_blob = BnInOp2BlobPtr("model_diff_acc");
KernelUtil<device_type, FloatingPointType>::BlasAxpy(
ctx, in_blob->shape().elem_cnt() * sizeof(FloatingPointType),
static_cast<FloatingPointType>(1.0),
ctx, in_blob->shape().elem_cnt(), static_cast<FloatingPointType>(1.0),
static_cast<const FloatingPointType*>(in_blob->dptr()), 1,
static_cast<FloatingPointType*>(out_blob->mut_dptr()), 1);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册