提交 fde47aae 编写于 作者: T Tomasz Patejko

MKL elementwise add backward: grad inputs copied when they are not null

上级 996d12f1
...@@ -102,13 +102,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> { ...@@ -102,13 +102,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) { if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) if (dx) {
dx->mutable_data<T>(ctx.GetPlace()); blas.VCOPY(dout->numel(), dout->data<T>(),
if (dy) dx->mutable_data<T>(ctx.GetPlace()));
dy->mutable_data<T>(ctx.GetPlace()); }
blas.VCOPY(dout->numel(), dout->data<T>(), dx->data<T>()); if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(), dy->data<T>()); blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace()));
}
} else { } else {
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>( ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(), ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册