Unverified · Commit b1365d25 · Authored by Yiqun Liu, committed by GitHub

Unify the functor of elementwise and logical ops. (#35767)

Parent dfa242e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -12,9 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
namespace paddle {
namespace operators {
@@ -22,9 +22,10 @@ template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
  using InT = typename Functor::ELEMENT_TYPE;
  using OutT = bool;
+  void Compute(const framework::ExecutionContext& ctx) const override {
    auto functor = Functor();
    std::vector<const framework::Tensor*> ins;
    std::vector<framework::Tensor*> outs;
@@ -45,6 +46,9 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
}  // namespace operators
}  // namespace paddle
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \
  REGISTER_OP_CUDA_KERNEL(                          \
      op_name,                                      \
......
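Note on the hunks above: after this change a single functor type drives both the kernel template and the registration macro, keyed by the functor's ELEMENT_TYPE alias (the macro body itself is collapsed above). A minimal illustrative sketch of a functor in that style, with a stub HOSTDEVICE so it compiles as plain C++ — names here are a sketch, not the verbatim Paddle source:

#ifndef HOSTDEVICE
#define HOSTDEVICE  // in Paddle this expands to __host__ __device__ under NVCC
#endif

// Illustrative functor in the logical_op.h style: ELEMENT_TYPE exposes the
// input type so BinaryLogicalOpKernel<..., Functor> can deduce it, while the
// call operator always yields bool.
template <typename T>
struct LogicalAndFunctor {
  using ELEMENT_TYPE = T;
  HOSTDEVICE bool operator()(const T a, const T b) const {
    return static_cast<bool>(a) && static_cast<bool>(b);
  }
};

int main() { return LogicalAndFunctor<int>()(3, 0) ? 1 : 0; }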
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
@@ -22,6 +22,7 @@ namespace paddle {
namespace operators {
// Define the binary functors used in elementwise ops.
+// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.
// Add
template <typename T>
......
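The note added above is worth unpacking: on CPU, ElementwiseComputeEx expects the first tensor to have rank greater than or equal to the second's; when it does not, callers keep the (x, y) argument order but pass an "inverse" functor that evaluates b op a, which matters for non-commutative ops. A self-contained sketch of the idea, mirroring the removed elementwise_mod_fp dispatch shown later in this diff (SubFunctor/InverseSubFunctor are illustrative stand-ins):

#include <cassert>

// Stand-ins for Paddle's functor pairs: Sub computes a - b, InverseSub b - a.
template <typename T>
struct SubFunctor {
  T operator()(const T a, const T b) const { return a - b; }
};
template <typename T>
struct InverseSubFunctor {
  T operator()(const T a, const T b) const { return b - a; }
};

int main() {
  // If a kernel must receive the higher-rank tensor first, computing x - y
  // after swapping the operands to (y, x) requires the inverse functor.
  int x = 7, y = 2;
  assert(SubFunctor<int>()(x, y) == InverseSubFunctor<int>()(y, x));
  return 0;
}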
@@ -66,8 +66,8 @@ REGISTER_OP_CPU_KERNEL(
    elementwise_mod,
    ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(elementwise_mod)
    .AddCheckpoint(
......
@@ -14,9 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
namespace paddle {
namespace operators {
@@ -38,6 +35,9 @@ class ElementwiseModKernel<platform::CUDADeviceContext, T>
}  // namespace operators
}  // namespace paddle
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>,
......
@@ -44,9 +44,9 @@ struct ModFunctor<T,
  }
};
-template <typename T>
+template <typename T, typename Enable = void>
struct InverseModFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const {
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
    T res = b % a;
    if ((res != 0) && ((res < 0) != (a < 0))) res += a;
    return res;
@@ -54,8 +54,9 @@ struct InverseModFunctor {
};
template <typename T>
-struct InverseModFunctorFP {
-  inline HOSTDEVICE T operator()(T a, T b) const {
+struct InverseModFunctor<
+    T, typename std::enable_if_t<std::is_floating_point<T>::value>> {
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
    T res = fmod(b, a);
    if ((res != 0) && ((a < 0) != (res < 0))) res += a;
    return res;
@@ -78,22 +79,6 @@ void elementwise_mod(const framework::ExecutionContext &ctx,
  }
}
-template <typename DeviceContext, typename T>
-void elementwise_mod_fp(const framework::ExecutionContext &ctx,
-                        const framework::Tensor *x, const framework::Tensor *y,
-                        framework::Tensor *z) {
-  int axis = ctx.Attr<int>("axis");
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  if (x_dims.size() >= y_dims.size()) {
-    ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          ModFunctor<T>(), z);
-  } else {
-    ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
-        ctx, x, y, axis, InverseModFunctorFP<T>(), z);
-  }
-}
template <typename DeviceContext, typename T>
class ElementwiseModKernel : public framework::OpKernel<T> {
 public:
@@ -109,20 +94,5 @@ class ElementwiseModKernel : public framework::OpKernel<T> {
  }
};
-template <typename DeviceContext, typename T>
-class ElementwiseModFPKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *x = ctx.Input<framework::LoDTensor>("X");
-    auto *y = ctx.Input<framework::LoDTensor>("Y");
-    auto *z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<T>(ctx.GetPlace());
-    // dtype of x and y is float or double
-    elementwise_mod_fp<DeviceContext, T>(ctx, x, y, z);
-  }
-};
}  // namespace operators
}  // namespace paddle
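The net effect of this file's changes: ElementwiseModFPKernel and InverseModFunctorFP disappear, and a single InverseModFunctor selects an integral or floating-point body through a defaulted Enable parameter plus a std::enable_if_t partial specialization. A self-contained sketch of that pattern (simplified names and the operands unswapped relative to the inverse functor above, plain C++ so it compiles outside Paddle):

#include <cmath>
#include <iostream>
#include <type_traits>

// Primary template: integral modulo whose result takes the divisor's sign.
template <typename T, typename Enable = void>
struct Mod {
  T operator()(const T a, const T b) const {
    T res = a % b;
    if ((res != 0) && ((res < 0) != (b < 0))) res += b;
    return res;
  }
};

// Partial specialization chosen only when T is floating point: the
// enable_if_t expression is void exactly when the trait holds, matching
// the defaulted Enable parameter of the primary template.
template <typename T>
struct Mod<T, std::enable_if_t<std::is_floating_point<T>::value>> {
  T operator()(const T a, const T b) const {
    T res = std::fmod(a, b);
    if ((res != 0) && ((res < 0) != (b < 0))) res += b;
    return res;
  }
};

int main() {
  std::cout << Mod<int>()(-7, 3) << "\n";         // prints 2
  std::cout << Mod<double>()(-7.0, 3.0) << "\n";  // prints 2
}

This is also what allows REGISTER_OP_CPU_KERNEL above to register one kernel name for int, int64_t, float, and double alike.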
@@ -199,10 +199,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                          const framework::Tensor *x,
                          const framework::Tensor *y, int axis, Functor func,
                          framework::Tensor *z) {
-  z->mutable_data<OutType>(ctx.GetPlace());
-  auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-  auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
-  auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
  if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
    std::vector<const framework::Tensor *> ins = {x, y};
@@ -217,6 +213,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
    return;
  }
+  z->mutable_data<OutType>(ctx.GetPlace());
+  auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
+  auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
+  auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
  const auto &dev_ctx =
      ctx.template device_context<platform::CPUDeviceContext>();
  pten::ElementwiseCompute<Functor, T, OutType>(
......
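These two hunks only move code: output allocation and the MakePtenDenseTensor wrappers used to run before the GPU early-return, even though only the CPU branch consumes the wrappers. Hoisting them below the branch spares the CUDA path that setup. A tiny standalone illustration of the same early-return ordering, with stand-in types rather than Paddle's:

#include <iostream>
#include <memory>

struct CpuOnlyWrapper {  // stand-in for the MakePtenDenseTensor results
  CpuOnlyWrapper() { std::cout << "built CPU-only wrapper\n"; }
};

void compute(bool on_gpu) {
  if (on_gpu) {
    std::cout << "launch GPU kernel\n";
    return;  // returns before any CPU-only setup runs
  }
  auto wrapper = std::make_unique<CpuOnlyWrapper>();  // CPU path only
  std::cout << "run CPU kernel\n";
}

int main() {
  compute(true);   // no wrapper constructed
  compute(false);  // wrapper constructed exactly once
}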
@@ -16,9 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/platform/aligned_vector.h"
-#include "paddle/fluid/platform/function_traits.h"
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
@@ -27,8 +24,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
-namespace kps = paddle::operators::kernel_primitives;
using ElementwiseType = pten::ElementwiseType;
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -21,7 +24,7 @@ namespace operators {
template <typename T>
struct PowFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const {
+  inline HOSTDEVICE T operator()(const T a, const T b) const {
    // TODO(wujionghao): A potential speed improvement is supporting different
    // types in C++.
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
......
@@ -26,7 +26,6 @@
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
-#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
......
@@ -22,6 +22,7 @@ namespace pten {
namespace funcs {
// Define the binary functors used in elementwise ops.
+// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.
// Add
template <typename T>
@@ -48,10 +49,22 @@ template <typename T>
struct MultiplyFunctor {
  inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
};
+template <>
+struct MultiplyFunctor<bool> {
+  inline HOSTDEVICE bool operator()(const bool a, const bool b) const {
+    return a && b;
+  }
+};
template <typename T>
struct InverseMultiplyFunctor {
  inline HOSTDEVICE T operator()(const T a, const T b) const { return b * a; }
};
+template <>
+struct InverseMultiplyFunctor<bool> {
+  inline HOSTDEVICE bool operator()(const bool a, const bool b) const {
+    return b && a;
+  }
+};
// Divide
#define DIV_ERROR_INFO \
......
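The bool specializations added above are what let logical ops share the elementwise multiply machinery: for booleans, "multiplication" is defined directly as logical AND instead of going through integer promotion. A quick self-contained check of that equivalence, mirroring the specialization pattern in the hunk (HOSTDEVICE dropped so it compiles as plain C++):

#include <cassert>

template <typename T>
struct MultiplyFunctor {
  T operator()(const T a, const T b) const { return a * b; }
};

// Same shape as the specialization added in this diff: bool multiply is AND.
template <>
struct MultiplyFunctor<bool> {
  bool operator()(const bool a, const bool b) const { return a && b; }
};

int main() {
  MultiplyFunctor<bool> mul;
  // For bool, multiplication coincides with logical AND on every input pair.
  for (bool a : {false, true})
    for (bool b : {false, true}) assert(mul(a, b) == (a && b));
  return 0;
}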