diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cc b/paddle/fluid/operators/controlflow/compare_all_op.cc index 9442c7583d98fec2d46a96be136eae9bf4d20634..ede349f737d899e5f04cb5e35d1dbc0c0abc2403 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.cc +++ b/paddle/fluid/operators/controlflow/compare_all_op.cc @@ -30,29 +30,13 @@ class CompareReduceOpKernel auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* z = context.Output("Out"); - bool shape_same = true; - Tensor tmp; - framework::DDim x_dims = x->dims(); - framework::DDim y_dims = y->dims(); - - // judge the two inputs shape is same, if not same, just return false - if (x_dims.size() != y_dims.size()) { - shape_same = false; - } else { - for (auto i = 0; i < x_dims.size(); i++) { - if (x_dims[i] != y_dims[i]) { - shape_same = false; - break; - } - } - } - bool* z_data = z->mutable_data(context.GetPlace()); - if (!shape_same) { + + if (x->dims() != y->dims()) { z_data[0] = false; } else { - tmp.mutable_data(x_dims, context.GetPlace()); + tmp.mutable_data(x->dims(), context.GetPlace()); if (x->numel() == 1 && y->numel() == 1) { bool* z_data = tmp.mutable_data(context.GetPlace()); z_data[0] = Functor()(x->data()[0], y->data()[0]); diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cu b/paddle/fluid/operators/controlflow/compare_all_op.cu index 3753ed6b15f1e369c6f8777f939ffd3d8317fba0..9e22d74d6e2aac97ad23f99ad9d5b6a7f9924bbe 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.cu +++ b/paddle/fluid/operators/controlflow/compare_all_op.cu @@ -14,14 +14,18 @@ limitations under the License. */ #include #include "paddle/fluid/operators/controlflow/compare_all_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + namespace paddle { namespace operators { template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} - HOSTDEVICE inline T operator()(const T& x) const { return x; } }; @@ -33,6 +37,24 @@ struct BitwiseAdd { return a & b; } }; + +template +struct CudaEqualReduceFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T args[]) const { + return (args[0] == args[1]); + } +}; + +template +struct CudaEqualReduceFunctor< + T, typename std::enable_if::value>::type> { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T args[]) const { + return fabs(static_cast(args[0] - args[1])) < 1e-8; + } +}; + template class CompareReduceOpKernel : public framework::OpKernel { @@ -44,32 +66,22 @@ class CompareReduceOpKernel auto* x = context.Input("X"); auto* y = context.Input("Y"); auto* z = context.Output("Out"); - bool shape_same = true; - + bool* z_data = z->mutable_data(context.GetPlace()); Tensor tmp; - framework::DDim x_dims = x->dims(); - framework::DDim y_dims = y->dims(); - if (x_dims.size() != y_dims.size()) { - shape_same = false; - } else { - for (auto i = 0; i < x_dims.size(); i++) { - if (x_dims[i] != y_dims[i]) { - shape_same = false; - break; - } - } - } - - bool* z_data = z->mutable_data(context.GetPlace()); - if (!shape_same) { + if (x->dims() != y->dims()) { thrust::device_ptr z_dev_ptr(z_data); thrust::fill(z_dev_ptr, z_dev_ptr + 1, false); return; } else { - tmp.mutable_data(x_dims, context.GetPlace()); - ElementwiseComputeEx(context, x, y, 0, - Functor(), &tmp); + tmp.mutable_data(x->dims(), context.GetPlace()); + const auto& cuda_ctx = + context.template device_context(); + std::vector ins = {x, y}; + std::vector outs = {&tmp}; + LaunchSameDimsElementwiseCudaKernel( + cuda_ctx, ins, &outs, Functor()); + // Reduce by 'bitwise and' operator std::vector reduce_dims; reduce_dims.resize(tmp.dims().size()); @@ -85,18 +97,17 @@ class CompareReduceOpKernel } // namespace operators } // namespace paddle -#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ - REGISTER_OP_CUDA_KERNEL( \ - op_type, paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>); - -REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, - paddle::operators::EqualReduceFunctor); +#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ + REGISTER_OP_CUDA_KERNEL( \ + op_type, \ + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>, \ + ops::CompareReduceOpKernel>); + +REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor) +#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index 6f3a615edb44bebd4c8427303ea8e3331cff02b7..bf7861a03d8d4da4ff1ae65ff62c761ffab914bd 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -59,7 +59,6 @@ struct CudaNotEqualFunctor< template class CompareOpKernel : public framework::OpKernel { - public: public: using InT = typename Functor::ELEMENT_TYPE; using OutT = bool;