From 01fb2be908a2f05abe72666df770d3fc57e7ddb5 Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Mon, 21 May 2018 05:39:53 +0200
Subject: [PATCH] MKL elementwise add: default implementation used for integral
 types, float16 and/or GPU

---
 paddle/fluid/operators/elementwise_add_op.h | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index 1f8735b7b..d75d86c24 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx,
 }
 
 template <typename DeviceContext, typename T>
-typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
-    const framework::ExecutionContext& ctx, const framework::Tensor* x,
-    const framework::Tensor* y, framework::Tensor* z) {
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
   auto eigen_x = framework::EigenVector<T>::Flatten(*x);
   auto eigen_y = framework::EigenVector<T>::Flatten(*y);
   auto eigen_z = framework::EigenVector<T>::Flatten(*z);
@@ -48,9 +51,12 @@ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
 }
 
 template <typename DeviceContext, typename T>
-typename std::enable_if<std::is_integral<T>::value>::type elementwise_add(
-    const framework::ExecutionContext& ctx, const framework::Tensor* x,
-    const framework::Tensor* y, framework::Tensor* z) {
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add(const framework::ExecutionContext& ctx,
+                const framework::Tensor* x, const framework::Tensor* y,
+                framework::Tensor* z) {
   default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
 }
 
@@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     z->mutable_data<T>(ctx.GetPlace());
 
     auto dims_equal = x->dims() == y->dims();
-    if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) {
+    if (dims_equal) {
       elementwise_add<DeviceContext, T>(ctx, x, y, z);
     } else {
       default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
-- 
GitLab