未验证 提交 0efcae86 编写于 作者: Z Zhang Ting 提交者: GitHub

[part 3]change type of function args (#38887)

* code clean

* [part 3]change type of function args
上级 f1201482
......@@ -22,19 +22,19 @@ limitations under the License. */
namespace paddle {
namespace operators {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool& a, const bool& b) const { \
return a bool_expr b; \
} \
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool a, const bool b) const { \
return a bool_expr b; \
} \
};
BITWISE_BINARY_FUNCTOR(And, &, &&)
......@@ -45,13 +45,13 @@ BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
template <typename T>
struct BitwiseNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T& a) const { return ~a; }
HOSTDEVICE T operator()(const T a) const { return ~a; }
};
template <>
struct BitwiseNotFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool& a) const { return !a; }
HOSTDEVICE bool operator()(const bool a) const { return !a; }
};
template <typename DeviceContext, typename Functor>
......
......@@ -28,7 +28,7 @@ namespace operators {
template <typename T>
struct EqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
......
......@@ -25,31 +25,31 @@ namespace operators {
template <typename T>
struct LessThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; }
};
template <typename T>
struct LessEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; }
};
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; }
};
template <typename T>
struct EqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
......@@ -63,7 +63,7 @@ struct EqualFunctor {
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
HOSTDEVICE bool operator()(const T a, const T b) const {
return !EqualFunctor<T>()(a, b);
}
};
......
......@@ -18,26 +18,6 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T* args) const { \
return static_cast<bool>(args[0]) op static_cast<bool>(args[1]); \
} \
};
LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(CudaAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(CudaXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR
template <typename T>
struct CudaNotFunctor {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const { return !args[0]; }
};
template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......@@ -76,8 +56,8 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<float>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<double>>);
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, CudaXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, CudaNotFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, LogicalOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, LogicalAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, LogicalXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, LogicalNotFunctor)
#undef REGISTER_LOGICAL_CUDA_KERNEL
......@@ -19,38 +19,32 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct LogicalAndFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; }
};
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T a, const T b) const { \
return static_cast<bool>(a) op static_cast<bool>(b); \
} \
};
template <typename T>
struct LogicalOrFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; }
};
LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR
template <typename T>
struct LogicalNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a) const { return !a; }
};
template <typename T>
struct LogicalXorFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return (a || b) && !(a && b);
}
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T a) const { return !a; }
};
template <typename DeviceContext, typename Functor>
class BinaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
......@@ -62,10 +56,10 @@ class BinaryLogicalOpKernel
template <typename DeviceContext, typename Functor>
class UnaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
Functor unary_func;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册