From 3e876b3e497c0aeef13a103d317fdb47eb6c3fc7 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Thu, 24 May 2018 16:35:00 +0200 Subject: [PATCH] MKL optimized elementwise add: fix style check --- paddle/fluid/operators/elementwise_add_op.h | 24 ++++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index d85f785283c..baf04c30b17 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -85,7 +85,7 @@ struct IdentityGrad { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } }; -template +template void default_elementwise_add_grad(const framework::ExecutionContext& ctx, const framework::Tensor* x, const framework::Tensor* y, @@ -100,16 +100,15 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx, IdentityGrad()); } -template +template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, + const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor* out, - const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { auto blas = math::GetBlas(ctx); if (dx) { @@ -123,16 +122,15 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, } } -template +template typename std::enable_if< !std::is_floating_point::value || !std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, + const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor* out, - const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } @@ -152,8 +150,8 @@ class ElementwiseAddGradKernel : public framework::OpKernel { if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) { elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { - default_elementwise_add_grad( - ctx, x, y, out, dout, dx, dy); + default_elementwise_add_grad(ctx, x, y, out, dout, dx, + dy); } } }; -- GitLab