diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index a52920d9e870103d621a33863c7b1fa163c87ca4..cc0c46adb119a160d166e9093cc4ff677d8bd4e0 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -21,21 +21,21 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op) \ +#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(func, op) \ template \ - struct Func##Functor { \ + struct func { \ using ELEMENT_TYPE = T; \ inline HOSTDEVICE bool operator()(const T* args) const { \ return args[0] op args[1]; \ } \ }; -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==) -DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThanFunctor, <) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqualFunctor, <=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThanFunctor, >) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqualFunctor, >=) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqualFunctor, ==) +DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqualFunctor, !=) #undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT template @@ -67,10 +67,12 @@ class CompareOpKernel auto functor = Functor(); std::vector ins; std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); - PackTensorsIntoVector(ctx, &ins, &outs); + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, functor); + cuda_ctx, ins, &outs, axis, functor); } }; @@ -79,19 +81,16 @@ class CompareOpKernel #define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \ REGISTER_OP_CUDA_KERNEL( \ - op_type, ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, void>, \ - ops::CompareOpKernel, \ - void>, \ - ops::CompareOpKernel, void>); + op_type, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>, \ + ops::CompareOpKernel, void>); -REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual) -REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual) -REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan) -REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual) -REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan) -REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual) +REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThanFunctor) +REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqualFunctor) +REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThanFunctor) +REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqualFunctor) #undef REGISTER_CUDA_COMPARE_KERNEL diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index aad5303d2e6545939c4ab85bc9a746b06e940f18..aff0cb281642ecf9d9ee62890474ac87841c5e9a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -28,11 +28,11 @@ namespace operators { 1. For Unary Op, the length of input array is 1, e.g. Relu: return args[0] > 0 ? args[0] : 0; 2. For Binary Op, the length of input array is 2, - e.g. Add: return args[0] + args[1]; + e.g. Add: return args[0] expr args[1]; */ template struct CudaAddFunctor { - __device__ __forceinline__ T operator()(const T* args) const { + inline HOSTDEVICE T operator()(const T* args) const { return args[0] + args[1]; } }; @@ -44,9 +44,12 @@ class ElementwiseAddKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - PackTensorsIntoVector(ctx, &ins, &outs); + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - ctx, ins, &outs, CudaAddFunctor()); + cuda_ctx, ins, &outs, axis, CudaAddFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index ec7d036a1a1e0295ec496960069335fb33d3d003..a469ebbaec2edc9fadf0992412ef7d3b23d483e6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -72,12 +72,10 @@ class ElementwiseAddKernel : public framework::OpKernel { auto *z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); if (x->dims() == y->dims()) { - SameDimsElemwiseAdd - LaunchElementwiseCpuKernel; + SameDimsElemwiseAdd LaunchElementwiseCpuKernel; LaunchElementwiseCpuKernel(ctx, x, y, z); } else { - LaunchBroadcastElementwiseCpuKernel(ctx, x, - y, z); + LaunchBroadcastElementwiseCpuKernel(ctx, x, y, z); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index 5d086a1b29febd8e57507eced7683f414ca34e07..483b21d07fab1180ef18eb3a4bfc39591b98d376 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" namespace ops = paddle::operators; +namespace paddle { +namespace operators { + +template +struct CudaMaxFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return (args[0] > args[1] ? args[0] : args[1]); + } +}; + +template +class ElementwiseMaxKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaMaxFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_max, ops::ElementwiseMaxKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index cf93e5a97a3f3110aae907c593f58dbab0f9d090..88faaf257af45b3ae24bd08b562dbaa1ec5d634d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" namespace ops = paddle::operators; +namespace paddle { +namespace operators { + +template +struct CudaMinFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return (args[0] > args[1] ? args[1] : args[0]); + } +}; + +template +class ElementwiseMinKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaMinFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_min, ops::ElementwiseMinKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 8fd4609c3aa8508687540d5424a9e91511a1a3b5..973f2305cc778dd77051309b94575e8dc687f2ae 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -24,37 +25,65 @@ namespace paddle { namespace operators { template -struct SameDimsElemwiseMul { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - MulRangeFunctor functor(x->data(), y->data(), z->data()); - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, - x->numel()); - for_range(functor); +struct CudaMulFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return args[0] * args[1]; } }; -template <> -struct SameDimsElemwiseMul { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - auto size = x->numel(); - dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, - 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); - const half* x2 = - reinterpret_cast(x->data()); - const half* y2 = - reinterpret_cast(y->data()); - half* z2 = reinterpret_cast(z->data()); - SameDimsElemwiseMulCUDAKernel<<< - grid_size, block_size, 0, - ctx.template device_context().stream()>>>( - x2, y2, z2, size); +template +class ElementwiseMulKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int axis = -1; + auto x_var = ctx.InputVar("X"); + PADDLE_ENFORCE_NOT_NULL( + x_var, platform::errors::InvalidArgument( + "Cannot get input Variable X, Variable name = %s.", + ctx.InputName("X"))); + auto* y = ctx.Input("Y"); + + framework::Tensor x, *z; + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + if (x_var->IsType()) { + x = x_var->Get(); + z = ctx.Output("Out"); + axis = PackTensorsIntoVector(ctx, &ins, &outs); + } else if (x_var->IsType()) { + PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true, + platform::errors::InvalidArgument( + "For elementwise_op, if X is Sparse, Y must be " + "scalar. But reveived the size of Y = %s.", + y->dims().size())); + auto& x_sele = x_var->Get(); + auto out_sele = ctx.Output("Out"); + x = x_sele.value(); + out_sele->set_rows(x_sele.rows()); + out_sele->set_height(x_sele.height()); + out_sele->mutable_value()->Resize(x_sele.value().dims()); + out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type()); + z = ctx.Output("Out")->mutable_value(); + z->mutable_data(ctx.GetPlace()); + outs.emplace_back(z); + ins.emplace_back(&x); + ins.emplace_back(y); + + axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; + axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "X's type[%s] is not supported by elementwise_op. X's type should be " + "LoDTensor or SelectedRows.", + framework::ToTypeName(x_var->Type()))); + } + + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaMulFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 10e69491643c92d77f58c487abd122d51def82e5..a734f891a9d9e83592156442e48215a93af3a920 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -126,7 +126,6 @@ class ElementwiseMulKernel : public framework::OpKernel { } } }; - template struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 0612d01b6bf7e6a12b4c7ad9568698a85b8a46df..74216d6a9d4d53b6ba164814abc623b2dc821308 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -465,7 +465,11 @@ void LaunchBroadcastElementwiseCudaKernel( const platform::CUDADeviceContext &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { - static_assert(ET == (ElementwiseType)2, "Only Support binary calculation."); + PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary, + platform::errors::InvalidArgument( + "Currently, only Support binary calculation, " + "but received %d input tensors.\n", + static_cast(ET))); int in_vec_size = 4; framework::Tensor *out = (*outs)[0]; for (auto *in : ins) { @@ -502,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel( template void LaunchElementwiseCudaKernel( - const framework::ExecutionContext &ctx, + const platform::CUDADeviceContext &cuda_ctx, const std::vector &ins, - std::vector *outs, Functor func) { - std::vector dims_size; + std::vector *outs, int axis, Functor func) { bool no_broadcast_flag = true; for (auto *in : ins) { no_broadcast_flag = ins[0]->dims() == in->dims(); - dims_size.emplace_back(in->dims().size()); } - const auto &cuda_ctx = - ctx.template device_context(); + if (no_broadcast_flag) { - LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, outs, func); + LaunchSameDimsElementwiseCudaKernel(cuda_ctx, ins, outs, + func); } else { - int axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; - axis = axis == -1 - ? *std::max_element(dims_size.begin(), dims_size.end()) - - *std::min_element(dims_size.begin(), dims_size.end()) - : axis; LaunchBroadcastElementwiseCudaKernel(cuda_ctx, ins, outs, axis, func); } diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 05b78bcf6ad66c283cd3b02b22f5d72b281b083a..d19c75eaf3de08301cad2e435e8c5a030f8ce253 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -64,20 +64,24 @@ namespace operators { * To pack the input and output tnesors into vector for * LaunchElementwiseCudaKernel */ -template -void PackTensorsIntoVector(const framework::ExecutionContext &ctx, - std::vector *ins, - std::vector *outs) { +template +int PackTensorsIntoVector(const framework::ExecutionContext &ctx, + std::vector *ins, + std::vector *outs) { + int axis = -1; auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - ins->emplace_back(x); + z->mutable_data(ctx.GetPlace()); outs->emplace_back(z); + ins->emplace_back(x); if (y != nullptr) { ins->emplace_back(y); + axis = ctx.HasAttr("axis") ? ctx.Attr("axis") : -1; + axis = axis == -1 ? std::abs(y->dims().size() - x->dims().size()) : axis; } + return axis; } /* diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 320d1e7b38da8e4f77015ef2b7bcc73e5db7675f..5335f274ef126f228694d1bfb23cb15f6da158ee 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -8,10 +8,52 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" namespace ops = paddle::operators; +namespace paddle { +namespace operators { + +template +struct CudaPowFunctor { + inline HOSTDEVICE T operator()(const T args[]) const { + return std::pow(args[0], args[1]); + } +}; + +template +struct CudaPowFunctor< + T, typename std::enable_if::value>::type> { + // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and + // it will return a float number like 2.99... , which floor to 2 + // when cast to int by default and it is wrong. + // Use llrint to cast it to the nearest integer, which is 3. + inline HOSTDEVICE T operator()(const T args[]) const { + return std::llrint(std::pow(args[0], args[1])); + } +}; + +template +class ElementwisePowKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaPowFunctor()); + } +}; + +} // namespace operators +} // namespace paddle + REGISTER_OP_CUDA_KERNEL( elementwise_pow, ops::ElementwisePowKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 19cbbb7bf04287b49e023aaa10c9635b6c4fbda7..da9610243f7c4df3300b3ea8b9137cea84e5c72b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,8 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -24,37 +23,25 @@ namespace paddle { namespace operators { template -struct SameDimsElemwiseSub { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - SubRangeFunctor functor(x->data(), y->data(), z->data()); - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, - x->numel()); - for_range(functor); +struct CudaSubFunctor { + inline HOSTDEVICE T operator()(const T* args) const { + return args[0] - args[1]; } }; -template <> -struct SameDimsElemwiseSub { - void operator()(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z) { - auto size = x->numel(); - dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, - 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); - const half* x2 = - reinterpret_cast(x->data()); - const half* y2 = - reinterpret_cast(y->data()); - half* z2 = reinterpret_cast(z->data()); - SameDimsElemwiseSubCUDAKernel<<< - grid_size, block_size, 0, - ctx.template device_context().stream()>>>( - x2, y2, z2, size); +template +class ElementwiseSubKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::vector ins; + std::vector outs; + const auto& cuda_ctx = + ctx.template device_context(); + + int axis = PackTensorsIntoVector(ctx, &ins, &outs); + LaunchElementwiseCudaKernel( + cuda_ctx, ins, &outs, axis, CudaSubFunctor()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 4171d2eb9e5e53ea2fff9a2ab7521f2e5c4ae438..426093413276092538c67676abb2c1e9b7f637ed 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -11,8 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once + #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"