Commit 9241011b authored by Tomasz Patejko

MKL elementwise add backward: the backward pass now also works for integral types, falling back to the default implementation

Parent fde47aae
@@ -25,6 +25,6 @@ REGISTER_OP_CPU_KERNEL(
 REGISTER_OP_CPU_KERNEL(
     elementwise_add_grad,
     ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>);
-    // ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
-    // ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
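
Note: the int and int64_t gradient kernels were previously commented out because the BLAS VCOPY fast path only handles floating-point data. They can be registered now that integral instantiations fall back to the default implementation, as the next hunks show.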
@@ -85,6 +85,57 @@ struct IdentityGrad {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };
 
+template<typename DeviceContext, typename T>
+void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* x,
+                                  const framework::Tensor* y,
+                                  const framework::Tensor* out,
+                                  const framework::Tensor* dout,
+                                  framework::Tensor* dx,
+                                  framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+
+  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+      IdentityGrad<T>());
+}
+
+template<typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x,
+                     const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout,
+                     framework::Tensor* dx, framework::Tensor* dy) {
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+  if (dx) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dx->mutable_data<T>(ctx.GetPlace()));
+  }
+
+  if (dy) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dy->mutable_data<T>(ctx.GetPlace()));
+  }
+}
+
+template<typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x,
+                     const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout,
+                     framework::Tensor* dx, framework::Tensor* dy) {
+  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddGradKernel : public framework::OpKernel<T> {
  public:
@@ -97,24 +148,12 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
 
     if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
-      auto blas = math::GetBlas<DeviceContext, T>(ctx);
-
-      if (dx) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dx->mutable_data<T>(ctx.GetPlace()));
-      }
-
-      if (dy) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dy->mutable_data<T>(ctx.GetPlace()));
-      }
+      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
-      ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-          IdentityGrad<T>());
+      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
+                                                     dy);
     }
   }
 };
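
The change dispatches at compile time with std::enable_if: for floating-point types on CPUDeviceContext the BLAS VCOPY overload is selected, and for every other (type, device) combination the generic ElemwiseGradCompute fallback is selected. The sketch below is a minimal, self-contained illustration of that dispatch pattern, not code from this commit; FakeCPUContext, FakeGPUContext, and the printed messages are hypothetical stand-ins.

#include <iostream>
#include <type_traits>

struct FakeCPUContext {};  // hypothetical stand-in for platform::CPUDeviceContext
struct FakeGPUContext {};  // hypothetical stand-in for a non-CPU context

// Overload 1: enabled only when T is floating point AND the context is CPU.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, FakeCPUContext>::value>::type
add_grad() {
  std::cout << "BLAS VCOPY path (float/double on CPU)\n";
}

// Overload 2: the exact logical negation of the condition above.
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value ||
    !std::is_same<DeviceContext, FakeCPUContext>::value>::type
add_grad() {
  std::cout << "default ElemwiseGradCompute fallback\n";
}

int main() {
  add_grad<FakeCPUContext, float>();   // VCOPY path
  add_grad<FakeCPUContext, int>();     // fallback: integral type
  add_grad<FakeGPUContext, double>();  // fallback: non-CPU context
  return 0;
}

Because the two enable_if conditions are exact complements of each other (De Morgan's laws), exactly one overload is viable for any instantiation, so overload resolution can never be ambiguous and no (type, device) combination is left without an implementation.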