提交 3e876b3e 编写于 作者: T Tomasz Patejko

MKL optimized elementwise add: fix style check

上级 9241011b
...@@ -85,7 +85,7 @@ struct IdentityGrad { ...@@ -85,7 +85,7 @@ struct IdentityGrad {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
}; };
template<typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext& ctx, void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, const framework::Tensor* y,
...@@ -100,16 +100,15 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx, ...@@ -100,16 +100,15 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
IdentityGrad<T>()); IdentityGrad<T>());
} }
template<typename DeviceContext, typename T> template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx, elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* out,
const framework::Tensor* dout, const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dx, framework::Tensor* dy) { framework::Tensor* dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) { if (dx) {
...@@ -123,16 +122,15 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, ...@@ -123,16 +122,15 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
} }
} }
template<typename DeviceContext, typename T> template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value || !std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx, elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* out,
const framework::Tensor* dout, const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dx, framework::Tensor* dy) { framework::Tensor* dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} }
...@@ -152,8 +150,8 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> { ...@@ -152,8 +150,8 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) { if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else { } else {
default_elementwise_add_grad<DeviceContext, T>( default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
ctx, x, y, out, dout, dx, dy); dy);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册