diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index 1f8735b7b17e9b212889532ca1677d678bfcff2f..d75d86c242d3f06db51fd3623a9d42e944791a96 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx, } template -typename std::enable_if::value>::type elementwise_add( - const framework::ExecutionContext& ctx, const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { +typename std::enable_if< + std::is_floating_point::value && + std::is_same::value>::type +elementwise_add(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { auto eigen_x = framework::EigenVector::Flatten(*x); auto eigen_y = framework::EigenVector::Flatten(*y); auto eigen_z = framework::EigenVector::Flatten(*z); @@ -48,9 +51,12 @@ typename std::enable_if::value>::type elementwise_add( } template -typename std::enable_if::value>::type elementwise_add( - const framework::ExecutionContext& ctx, const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { +typename std::enable_if< + !std::is_floating_point::value || + !std::is_same::value>::type +elementwise_add(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { default_elementwise_add(ctx, x, y, z); } @@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel { z->mutable_data(ctx.GetPlace()); auto dims_equal = x->dims() == y->dims(); - if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) { + if (dims_equal) { elementwise_add(ctx, x, y, z); } else { default_elementwise_add(ctx, x, y, z);