From b432d0249557e50ac3ccaa7c0986bfaf4aa1f3fe Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Wed, 2 Jun 2021 22:23:19 +0800 Subject: [PATCH] Support Add Sub Mul Max Min Pow binary functors in elementwise system (#33050) --- .../fluid/operators/controlflow/compare_op.cu | 47 +++++----- .../elementwise/elementwise_add_op.cu | 11 ++- .../elementwise/elementwise_add_op.h | 6 +- .../elementwise/elementwise_max_op.cu | 31 +++++++ .../elementwise/elementwise_min_op.cu | 31 +++++++ .../elementwise/elementwise_mul_op.cu | 85 +++++++++++++------ .../elementwise/elementwise_mul_op.h | 1 - .../elementwise/elementwise_op_broadcast.cu.h | 24 +++--- .../elementwise/elementwise_op_function.h | 16 ++-- .../elementwise/elementwise_pow_op.cu | 42 +++++++++ .../elementwise/elementwise_sub_op.cu | 47 ++++------ .../elementwise/elementwise_sub_op.h | 2 +- 12 files changed, 231 insertions(+), 112 deletions(-) diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index a52920d9e87..cc0c46adb11 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -21,21 +21,21 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \ +#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \ template \ - struct Func##Functor { \ + struct func { \ using ELEMENT_TYPE = T; \ inline HOSTDEVICE bool operator()(const T* args) const { \ return args[0] op args[1]; \ } \ }; -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=) #undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT template @@ -67,10 +67,12 @@ class CompareOpKernel auto functor = Functor(); std::vector ins; std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); - PackTensorsIntoVector(ctx, &ins, &outs); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, functor); + cuda_ctx, ins, &outs, axis, functor); } }; @@ -79,19 +81,16 @@ class CompareOpKernel #define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \ REGISTER_OP_CUDA_KERNEL( \ - op_type, ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, \ - void>, \ - ops::CompareOpKernel, void>); + op_type, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>); -REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual) -REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual) -REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan) -REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual) -REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan) -REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual) +REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor) +REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor) +REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor) #undef REGISTER_CUDA_COMPARE_KERNEL diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index aad5303d2e6..aff0cb28164 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -28,11 +28,11 @@ namespace operators { 1. For Unary Op, the length of input array is 1, e.g. Relu: return args[0] > 0 ? args[0] : 0; 2. For Binary Op, the length of input array is 2, - e.g. Add: return args[0] + args[1]; + e.g. Add: return args[0] expr args[1]; */ template struct CudaAddFunctor { - __device__ __forceinline__ T operator()(const T* args) const { + inline HOSTDEVICE T operator()(const T* args) const { return args[0] + args[1]; } }; @@ -44,9 +44,12 @@ class ElementwiseAddKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - PackTensorsIntoVector(ctx, &ins, &outs); + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaAddFunctor()); + cuda_ctx, ins, &outs, axis, CudaAddFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index ec7d036a1a1..a469ebbaec2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -72,12 +72,10 @@ class ElementwiseAddKernel : public framework::OpKernel { auto *z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); if (x->dims() == y->dims()) { - SameDimsElemwiseAdd - LaunchElementwiseCpuKernel; + SameDimsElemwiseAdd LaunchElementwiseCpuKernel; LaunchElementwiseCpuKernel(ctx, x, y, z); } else { - LaunchBroadcastElementwiseCpuKernel(ctx, x, - y, z); + LaunchBroadcastElementwiseCpuKernel(ctx, x, y, z); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index 5d086a1b29f..483b21d07fa 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" namespace ops = paddle::operators; +namespace paddle { +namespace operators { + +template +struct CudaMaxFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return (args[0] > args[1] ? args[0] : args[1]); + } +}; + +template +class ElementwiseMaxKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaMaxFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_max, ops::ElementwiseMaxKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index cf93e5a97a3..88faaf257af 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" namespace ops = paddle::operators; +namespace paddle { +namespace operators { + +template +struct CudaMinFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return (args[0] > args[1] ? args[1] : args[0]); + } +}; + +template +class ElementwiseMinKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaMinFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_min, ops::ElementwiseMinKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 8fd4609c3aa..973f2305cc7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -24,37 +25,65 @@ namespace paddle { namespace operators { template -struct SameDimsElemwiseMul { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - MulRangeFunctor functor(x->data(), y->data(), z->data()); - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, - x->numel()); - for_range(functor); +struct CudaMulFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return args[0] * args[1]; } }; -template <> -struct SameDimsElemwiseMul { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - auto size = x->numel(); - dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, - 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); - const half* x2 = - reinterpret_cast(x->data()); - const half* y2 = - reinterpret_cast(y->data()); - half* z2 = reinterpret_cast(z->data()); - SameDimsElemwiseMulCUDAKernel<<< - grid_size, block_size, 0, - ctx.template device_context().stream()>>>( - x2, y2, z2, size); +template +class ElementwiseMulKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int axis = -1; + auto x_var = ctx.InputVar("X"); + PADDLE_ENFORCE_NOT_NULL( + x_var, platform::errors::InvalidArgument( + "Cannot get input Variable X, Variable name = %s.", + ctx.InputName("X"))); + auto* y = ctx.Input("Y"); + + framework::Tensor x, *z; + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + if (x_var->IsType()) { + x = x_var->Get(); + z = ctx.Output("Out"); + axis = PackTensorsIntoVector(ctx, &ins, &outs); + } else if (x_var->IsType()) { + PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true, + platform::errors::InvalidArgument( + "For elementwise_op, if X is Sparse, Y must be " + "scalar. But reveived the size of Y = %s.", + y->dims().size())); + auto& x_sele = x_var->Get(); + auto out_sele = ctx.Output("Out"); + x = x_sele.value(); + out_sele->set_rows(x_sele.rows()); + out_sele->set_height(x_sele.height()); + out_sele->mutable_value()->Resize(x_sele.value().dims()); + out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type()); + z = ctx.Output("Out")->mutable_value(); + z->mutable_data(ctx.GetPlace()); + outs.emplace_back(z); + ins.emplace_back(&x); + ins.emplace_back(y); + + axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; + axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "X's type[%s] is not supported by elementwise_op. X's type should be " + "LoDTensor or SelectedRows.", + framework::ToTypeName(x_var->Type()))); + } + + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaMulFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 10e69491643..a734f891a9d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -126,7 +126,6 @@ class ElementwiseMulKernel : public framework::OpKernel { } } }; - template struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 0612d01b6bf..74216d6a9d4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -465,7 +465,11 @@ void LaunchBroadcastElementwiseCudaKernel( const platform::CUDADeviceContext &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { - static_assert(ET == (ElementwiseType)2, "Only Support binary calculation."); + PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary, + platform::errors::InvalidArgument( + "Currently, only Support binary calculation, " + "but received %d input tensors.\n", + static_cast(ET))); int in_vec_size = 4; framework::Tensor *out = (*outs)[0]; for (auto *in : ins) { @@ -502,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel( template void LaunchElementwiseCudaKernel( - const framework::ExecutionContext &ctx, + const platform::CUDADeviceContext &cuda_ctx, const std::vector &ins, - std::vector *outs, Functor func) { - std::vector dims_size; + std::vector *outs, int axis, Functor func) { bool no_broadcast_flag = true; for (auto *in : ins) { no_broadcast_flag = ins[0]->dims() == in->dims(); - dims_size.emplace_back(in->dims().size()); } - const auto &cuda_ctx = - ctx.template device_context(); + if (no_broadcast_flag) { - LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, outs, func); + LaunchSameDimsElementwiseCudaKernel(cuda_ctx, ins, outs, + func); } else { - int axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; - axis = axis == -1 - ? *std::max_element(dims_size.begin(), dims_size.end()) - - *std::min_element(dims_size.begin(), dims_size.end()) - : axis; LaunchBroadcastElementwiseCudaKernel(cuda_ctx, ins, outs, axis, func); } diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 05b78bcf6ad..d19c75eaf3d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -64,20 +64,24 @@ namespace operators { * To pack the input and output tnesors into vector for * LaunchElementwiseCudaKernel */ -template -void PackTensorsIntoVector(const framework::ExecutionContext &ctx, - std::vector *ins, - std::vector *outs) { +template +int PackTensorsIntoVector(const framework::ExecutionContext &ctx, + std::vector *ins, + std::vector *outs) { + int axis = -1; auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - ins->emplace_back(x); + z->mutable_data(ctx.GetPlace()); outs->emplace_back(z); + ins->emplace_back(x); if (y != nullptr) { ins->emplace_back(y); + axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; + axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis; } + return axis; } /* diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 320d1e7b38d..5335f274ef1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -8,10 +8,52 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" namespace ops = paddle::operators; +namespace paddle { +namespace operators { + +template +struct CudaPowFunctor { + inline HOSTDEVICE T operator()(const T args[]) const { + return std::pow(args[0], args[1]); + } +}; + +template +struct CudaPowFunctor< + T, typename std::enable_if::value>::type> { + // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and + // it will return a float number like 2.99... , which floor to 2 + // when cast to int by default and it is wrong. + // Use llrint to cast it to the nearest integer, which is 3. + inline HOSTDEVICE T operator()(const T args[]) const { + return std::llrint(std::pow(args[0], args[1])); + } +}; + +template +class ElementwisePowKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaPowFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_pow, ops::ElementwisePowKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 19cbbb7bf04..da9610243f7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,8 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -24,37 +23,25 @@ namespace paddle { namespace operators { template -struct SameDimsElemwiseSub { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - SubRangeFunctor functor(x->data(), y->data(), z->data()); - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, - x->numel()); - for_range(functor); +struct CudaSubFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return args[0] - args[1]; } }; -template <> -struct SameDimsElemwiseSub { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - auto size = x->numel(); - dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, - 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); - const half* x2 = - reinterpret_cast(x->data()); - const half* y2 = - reinterpret_cast(y->data()); - half* z2 = reinterpret_cast(z->data()); - SameDimsElemwiseSubCUDAKernel<<< - grid_size, block_size, 0, - ctx.template device_context().stream()>>>( - x2, y2, z2, size); +template +class ElementwiseSubKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaSubFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 4171d2eb9e5..42609341327 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -11,8 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once + #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -- GitLab