From 0f1549619accaad66ae35e5876e50fff102c6a7f Mon Sep 17 00:00:00 2001
From: limingshu <61349199+JamesLim-sy@users.noreply.github.com>
Date: Wed, 2 Jun 2021 09:22:19 +0800
Subject: [PATCH] Reimplement the comparison binary ops using the new
 optimized CUDA function (#33064)

---
 .../fluid/operators/controlflow/compare_op.cu | 95 ++++++++++++++++---
 .../elementwise/elementwise_add_op.cu         | 17 +---
 .../elementwise/elementwise_op_broadcast.cu.h | 24 +++--
 .../elementwise/elementwise_op_function.h     | 20 ++++
 4 files changed, 120 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu
index a60201f9d0..a52920d9e8 100644
--- a/paddle/fluid/operators/controlflow/compare_op.cu
+++ b/paddle/fluid/operators/controlflow/compare_op.cu
@@ -13,18 +13,85 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/controlflow/compare_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 
-REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
-                        paddle::operators::GreaterThanFunctor);
-REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
-                        paddle::operators::GreaterEqualFunctor);
-REGISTER_COMPARE_KERNEL(greater_than, CUDA,
-                        paddle::operators::GreaterThanFunctor,
-                        paddle::operators::LessThanFunctor);
-REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
-                        paddle::operators::GreaterEqualFunctor,
-                        paddle::operators::LessEqualFunctor);
-REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
-                        paddle::operators::EqualFunctor);
-REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
-                        paddle::operators::NotEqualFunctor);
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+namespace paddle {
+namespace operators {
+
+#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op)  \
+  template <typename T, typename Enable = void>                \
+  struct Func##Functor {                                       \
+    using ELEMENT_TYPE = T;                                    \
+    inline HOSTDEVICE bool operator()(const T* args) const {   \
+      return args[0] op args[1];                               \
+    }                                                          \
+  };
+
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
+#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
+
+template <typename T>
+struct CudaEqualFunctor<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using ELEMENT_TYPE = T;
+  HOSTDEVICE bool operator()(const T* args) const {
+    return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
+  }
+};
+
+template <typename T>
+struct CudaNotEqualFunctor<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using ELEMENT_TYPE = T;
+  HOSTDEVICE bool operator()(const T* args) const {
+    return fabs(static_cast<double>(args[0] - args[1])) > 1e-8;
+  }
+};
+
+template <typename Functor, typename InverseFunctor>
+class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  using InT = typename Functor::ELEMENT_TYPE;
+  using OutT = bool;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto functor = Functor();
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+
+    PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
+        ctx, ins, &outs, functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func)                         \
+  REGISTER_OP_CUDA_KERNEL(                                                  \
+      op_type, ops::CompareOpKernel<plat::CUDADeviceContext,                \
+                                    ops::func##Functor<int>, void>,         \
+      ops::CompareOpKernel<plat::CUDADeviceContext,                         \
+                           ops::func##Functor<int64_t>, void>,              \
+      ops::CompareOpKernel<plat::CUDADeviceContext,                         \
+                           ops::func##Functor<float>, void>,                \
+      ops::CompareOpKernel<plat::CUDADeviceContext,                         \
+                           ops::func##Functor<double>, void>);
+
+REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
+REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
+REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
+REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
+REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
+REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
+#undef REGISTER_CUDA_COMPARE_KERNEL
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 37e5fa5a20..aad5303d2e 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -42,20 +42,11 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::LoDTensor>("X");
-    auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<T>(ctx.GetPlace());
-    int axis = ctx.Attr<int>("axis");
-    axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
-
-    std::vector<const framework::Tensor*> ins = {x, y};
-    std::vector<framework::Tensor*> outs = {z};
-    const auto& cuda_ctx =
-        ctx.template device_context<platform::CUDADeviceContext>();
-
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    PackTensorsIntoVector<T>(ctx, &ins, &outs);
     LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
+        ctx, ins, &outs, CudaAddFunctor<T>());
   }
 };
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 1492fc6294..0612d01b6b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -343,7 +343,6 @@ template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchElementwiseCudaKernel(
-    const platform::CUDADeviceContext &cuda_ctx,
+    const framework::ExecutionContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
-    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
+    std::vector<framework::Tensor *> *outs, Functor func) {
+  std::vector<int> dims_size;
   bool no_broadcast_flag = true;
   for (auto *in : ins) {
     no_broadcast_flag = ins[0]->dims() == in->dims();
+    dims_size.emplace_back(in->dims().size());
   }
-
+  const auto &cuda_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
   if (no_broadcast_flag) {
-    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor>(
+    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(
         cuda_ctx, ins, outs, func);
   } else {
-    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
-                                                        axis, func);
+    int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
+    axis = axis == -1
+               ? *std::max_element(dims_size.begin(), dims_size.end()) -
+                     *std::min_element(dims_size.begin(), dims_size.end())
+               : axis;
+    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
+                                                        axis, func);
   }
 }
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 32e49cf399..05b78bcf6a 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -60,6 +60,26 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 
 namespace paddle {
 namespace operators {
+/*
+* To pack the input and output tensors into vector for
+* LaunchElementwiseCudaKernel
+*/
+template <typename T>
+void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
+                           std::vector<const framework::Tensor *> *ins,
+                           std::vector<framework::Tensor *> *outs) {
+  auto *x = ctx.Input<framework::LoDTensor>("X");
+  auto *y = ctx.Input<framework::LoDTensor>("Y");
+  auto *z = ctx.Output<framework::LoDTensor>("Out");
+  z->mutable_data<T>(ctx.GetPlace());
+  ins->emplace_back(x);
+  outs->emplace_back(z);
+
+  if (y != nullptr) {
+    ins->emplace_back(y);
+  }
+}
 
 /*
  * Out = X ⊙ Y
 * If Y's shape does not match X' shape, they will be reshaped.
--
GitLab
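
Note on the comparison functors introduced in compare_op.cu above: each functor receives a pointer to the packed operands (args[0] and args[1]) rather than two scalar parameters, and the floating-point specializations of CudaEqualFunctor / CudaNotEqualFunctor compare with an absolute tolerance of 1e-8 instead of exact equality. The host-side sketch below is illustrative only and is not part of the patch; EqualSketch is a made-up name and the HOSTDEVICE qualifier is dropped so it builds as plain C++.

#include <cmath>
#include <cstdio>
#include <type_traits>

// Primary template: direct comparison, used for integral types.
template <typename T, typename Enable = void>
struct EqualSketch {
  bool operator()(const T* args) const { return args[0] == args[1]; }
};

// Floating-point specialization: absolute tolerance of 1e-8, mirroring the
// CudaEqualFunctor specialization in the patch.
template <typename T>
struct EqualSketch<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  bool operator()(const T* args) const {
    return std::fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
  }
};

int main() {
  int ints[2] = {3, 3};
  float floats[2] = {1.0f, 1.0f + 1e-9f};
  std::printf("int   equal: %d\n", EqualSketch<int>()(ints));     // prints 1
  std::printf("float equal: %d\n", EqualSketch<float>()(floats)); // prints 1
  return 0;
}

Because the tolerance is absolute rather than relative, any two values whose difference is below 1e-8 compare equal regardless of magnitude, which matches the fabs(...) < 1e-8 check in the patch.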
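Note on the default broadcast axis in elementwise_op_broadcast.cu.h above: when the wrapped op has no usable "axis" attribute (or it is left at -1), the launcher now derives the axis from the input ranks as max(rank) - min(rank), which for the two-input case reduces to the old |x_rank - y_rank| default that elementwise_add_op.cu used to compute by hand. A standalone sketch, not part of the patch (DefaultBroadcastAxis is an illustrative name):

#include <algorithm>
#include <cstdio>
#include <vector>

// Same max-rank minus min-rank rule the launcher applies when the
// "axis" attribute is absent or left at its default of -1.
int DefaultBroadcastAxis(const std::vector<int>& ranks) {
  int max_rank = *std::max_element(ranks.begin(), ranks.end());
  int min_rank = *std::min_element(ranks.begin(), ranks.end());
  return max_rank - min_rank;
}

int main() {
  // e.g. X of shape [2, 3, 4, 5] (rank 4) broadcast with Y of shape [4, 5] (rank 2)
  std::printf("axis = %d\n", DefaultBroadcastAxis({4, 2}));  // prints axis = 2
  return 0;
}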