未验证 提交 0efcae86 编写于 作者: Z Zhang Ting 提交者: GitHub

[part 3]change type of function args (#38887)

* code clean

* [part 3]change type of function args
上级 f1201482
...@@ -22,19 +22,19 @@ limitations under the License. */ ...@@ -22,19 +22,19 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ #define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \ template <typename T> \
struct Bitwise##func##Functor { \ struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \ using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \ HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \
}; \ }; \
\ \
template <> \ template <> \
struct Bitwise##func##Functor<bool> { \ struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \ using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool& a, const bool& b) const { \ HOSTDEVICE bool operator()(const bool a, const bool b) const { \
return a bool_expr b; \ return a bool_expr b; \
} \ } \
}; };
BITWISE_BINARY_FUNCTOR(And, &, &&) BITWISE_BINARY_FUNCTOR(And, &, &&)
...@@ -45,13 +45,13 @@ BITWISE_BINARY_FUNCTOR(Xor, ^, !=) ...@@ -45,13 +45,13 @@ BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
template <typename T> template <typename T>
struct BitwiseNotFunctor { struct BitwiseNotFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T& a) const { return ~a; } HOSTDEVICE T operator()(const T a) const { return ~a; }
}; };
template <> template <>
struct BitwiseNotFunctor<bool> { struct BitwiseNotFunctor<bool> {
using ELEM_TYPE = bool; using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool& a) const { return !a; } HOSTDEVICE bool operator()(const bool a) const { return !a; }
}; };
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
......
...@@ -28,7 +28,7 @@ namespace operators { ...@@ -28,7 +28,7 @@ namespace operators {
template <typename T> template <typename T>
struct EqualReduceFunctor { struct EqualReduceFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) { if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is // This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double. // safe to cast a and b to double.
......
...@@ -25,31 +25,31 @@ namespace operators { ...@@ -25,31 +25,31 @@ namespace operators {
template <typename T> template <typename T>
struct LessThanFunctor { struct LessThanFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; }
}; };
template <typename T> template <typename T>
struct LessEqualFunctor { struct LessEqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; }
}; };
template <typename T> template <typename T>
struct GreaterThanFunctor { struct GreaterThanFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; }
}; };
template <typename T> template <typename T>
struct GreaterEqualFunctor { struct GreaterEqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; }
}; };
template <typename T> template <typename T>
struct EqualFunctor { struct EqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) { if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is // This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double. // safe to cast a and b to double.
...@@ -63,7 +63,7 @@ struct EqualFunctor { ...@@ -63,7 +63,7 @@ struct EqualFunctor {
template <typename T> template <typename T>
struct NotEqualFunctor { struct NotEqualFunctor {
using ELEM_TYPE = T; using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { HOSTDEVICE bool operator()(const T a, const T b) const {
return !EqualFunctor<T>()(a, b); return !EqualFunctor<T>()(a, b);
} }
}; };
......
...@@ -18,26 +18,6 @@ namespace plat = paddle::platform; ...@@ -18,26 +18,6 @@ namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T* args) const { \
return static_cast<bool>(args[0]) op static_cast<bool>(args[1]); \
} \
};
LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(CudaAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(CudaXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR
template <typename T>
struct CudaNotFunctor {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const { return !args[0]; }
};
template <typename Functor> template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor> class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
...@@ -76,8 +56,8 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor> ...@@ -76,8 +56,8 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<float>>, \ ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<float>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<double>>); ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<double>>);
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor) REGISTER_LOGICAL_CUDA_KERNEL(logical_or, LogicalOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor) REGISTER_LOGICAL_CUDA_KERNEL(logical_and, LogicalAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, CudaXorFunctor) REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, LogicalXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, CudaNotFunctor) REGISTER_LOGICAL_CUDA_KERNEL(logical_not, LogicalNotFunctor)
#undef REGISTER_LOGICAL_CUDA_KERNEL #undef REGISTER_LOGICAL_CUDA_KERNEL
...@@ -19,38 +19,32 @@ limitations under the License. */ ...@@ -19,38 +19,32 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> #define LOGICAL_BINARY_FUNCTOR(func_name, op) \
struct LogicalAndFunctor { template <typename T> \
using ELEM_TYPE = T; struct func_name { \
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; } using ELEMENT_TYPE = T; \
}; HOSTDEVICE bool operator()(const T a, const T b) const { \
return static_cast<bool>(a) op static_cast<bool>(b); \
} \
};
template <typename T> LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||)
struct LogicalOrFunctor { LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&)
using ELEM_TYPE = T; LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^)
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; } #undef LOGICAL_BINARY_FUNCTOR
};
template <typename T> template <typename T>
struct LogicalNotFunctor { struct LogicalNotFunctor {
using ELEM_TYPE = T; using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T& a) const { return !a; } HOSTDEVICE bool operator()(const T a) const { return !a; }
};
template <typename T>
struct LogicalXorFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return (a || b) && !(a && b);
}
}; };
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class BinaryLogicalOpKernel class BinaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE; using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X"); auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y"); auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
...@@ -62,10 +56,10 @@ class BinaryLogicalOpKernel ...@@ -62,10 +56,10 @@ class BinaryLogicalOpKernel
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class UnaryLogicalOpKernel class UnaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE; using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X"); auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
Functor unary_func; Functor unary_func;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册