diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index d85f785283cc0e98a14f9ab4292fff5f8130c03f..baf04c30b17cb333fc8a6544afd6c479442f835b 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -85,7 +85,7 @@ struct IdentityGrad { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } }; -template +template void default_elementwise_add_grad(const framework::ExecutionContext& ctx, const framework::Tensor* x, const framework::Tensor* y, @@ -100,16 +100,15 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx, IdentityGrad()); } -template +template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, + const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor* out, - const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { auto blas = math::GetBlas(ctx); if (dx) { @@ -123,16 +122,15 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, } } -template +template typename std::enable_if< !std::is_floating_point::value || !std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, + const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor* out, - const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } @@ -152,8 +150,8 @@ class ElementwiseAddGradKernel : public framework::OpKernel { if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) { elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { - default_elementwise_add_grad( - ctx, x, y, out, dout, dx, dy); + default_elementwise_add_grad(ctx, x, y, out, dout, dx, + dy); } } };