提交 5a622c29 编写于 作者: T Tomasz Patejko

MKL elementwise add backward: Initial implementation with vector copy

上级 01fb2be9
...@@ -25,6 +25,6 @@ REGISTER_OP_CPU_KERNEL( ...@@ -25,6 +25,6 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add_grad, elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>, ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>, // ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>); // ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -98,10 +98,23 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> { ...@@ -98,10 +98,23 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx)
dx->mutable_data<T>(ctx.GetPlace());
if (dy)
dy->mutable_data<T>(ctx.GetPlace());
blas.VCOPY(dout->numel(), dout->data<T>(), dx->data<T>());
blas.VCOPY(dout->numel(), dout->data<T>(), dy->data<T>());
} else {
ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>( ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(), ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
IdentityGrad<T>()); IdentityGrad<T>());
} }
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册