Unverified commit def81b4f authored by Zhang Ting, committed by GitHub

unify compare functor (#39024)

Parent 46823104
@@ -22,49 +22,40 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T>
-struct LessThanFunctor {
-  using ELEM_TYPE = T;
-  HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; }
-};
-
-template <typename T>
-struct LessEqualFunctor {
-  using ELEM_TYPE = T;
-  HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; }
-};
-
-template <typename T>
-struct GreaterThanFunctor {
-  using ELEM_TYPE = T;
-  HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; }
-};
-
-template <typename T>
-struct GreaterEqualFunctor {
-  using ELEM_TYPE = T;
-  HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; }
-};
+#define COMPARE_FUNCTOR(func_name, op)                           \
+  template <typename InT, typename OutT = bool>                  \
+  struct func_name {                                             \
+    using ELEM_TYPE = InT;                                       \
+    HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
+      return static_cast<OutT>(a op b);                          \
+    }                                                            \
+  };
+
+COMPARE_FUNCTOR(LessThanFunctor, <)
+COMPARE_FUNCTOR(LessEqualFunctor, <=)
+COMPARE_FUNCTOR(GreaterThanFunctor, >)
+COMPARE_FUNCTOR(GreaterEqualFunctor, >=)
+#undef COMPARE_FUNCTOR
 
-template <typename T>
+template <typename InT, typename OutT = bool>
 struct EqualFunctor {
-  using ELEM_TYPE = T;
-  HOSTDEVICE bool operator()(const T a, const T b) const {
-    if (std::is_floating_point<T>::value) {
+  using ELEM_TYPE = InT;
+  HOSTDEVICE OutT operator()(const InT a, const InT b) const {
+    if (std::is_floating_point<InT>::value) {
       // This branch will be optimized while compiling if T is integer. It is
       // safe to cast a and b to double.
-      return fabs(static_cast<double>(a - b)) < 1e-8;
+      return static_cast<OutT>(fabs(static_cast<double>(a - b)) < 1e-8);
     } else {
-      return (a == b);
+      return static_cast<OutT>(a == b);
     }
   }
 };
 
-template <typename T>
+template <typename InT, typename OutT = bool>
 struct NotEqualFunctor {
-  using ELEM_TYPE = T;
-  HOSTDEVICE bool operator()(const T a, const T b) const {
-    return !EqualFunctor<T>()(a, b);
+  using ELEM_TYPE = InT;
+  HOSTDEVICE bool operator()(const InT a, const InT b) const {
+    return !EqualFunctor<InT, OutT>()(a, b);
   }
 };
...
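Note: the macro collapses four boilerplate functors into one definition and adds a configurable output type (OutT, defaulting to bool), so callers can write the comparison result straight into an integer mask tensor. EqualFunctor keeps its absolute tolerance of 1e-8 for floating-point inputs, which the new test_not_equal case further down exercises. A minimal host-only sketch of how the unified functors behave (HOSTDEVICE is a Paddle macro that expands to __host__ __device__ in CUDA builds; it is stubbed to nothing here so the snippet compiles standalone):

#include <cmath>
#include <cstdint>
#include <iostream>
#include <type_traits>

#define HOSTDEVICE  // stub; __host__ __device__ in real CUDA builds

#define COMPARE_FUNCTOR(func_name, op)                           \
  template <typename InT, typename OutT = bool>                  \
  struct func_name {                                             \
    using ELEM_TYPE = InT;                                       \
    HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
      return static_cast<OutT>(a op b);                          \
    }                                                            \
  };

COMPARE_FUNCTOR(GreaterThanFunctor, >)
#undef COMPARE_FUNCTOR

template <typename InT, typename OutT = bool>
struct EqualFunctor {
  using ELEM_TYPE = InT;
  HOSTDEVICE OutT operator()(const InT a, const InT b) const {
    if (std::is_floating_point<InT>::value) {
      // Floats within an absolute tolerance of 1e-8 compare equal.
      return static_cast<OutT>(fabs(static_cast<double>(a - b)) < 1e-8);
    } else {
      return static_cast<OutT>(a == b);
    }
  }
};

int main() {
  GreaterThanFunctor<float> gt;               // OutT defaults to bool
  GreaterThanFunctor<float, int64_t> gt_i64;  // OutT widened for mask tensors
  std::cout << gt(2.0f, 1.0f) << " " << gt_i64(1.0f, 2.0f) << "\n";  // 1 0
  // 1.2e-8 and 1.1e-8 differ by less than 1e-8, so they compare equal:
  EqualFunctor<float> eq;
  std::cout << eq(1.2e-8f, 1.1e-8f) << "\n";  // 1
}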
@@ -219,18 +219,20 @@ class MatrixRankCPUKernel : public framework::OpKernel<T> {
     tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
     Tensor compare_result;
-    compare_result.mutable_data<int>(detail::NewAxisDim(dim_out, k),
-                                     context.GetPlace());
+    compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
+                                         context.GetPlace());
     int axis = -1;
     if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
-      ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CPUDeviceContext, T,
-                           int>(context, &eigenvalue_tensor, &tol_tensor, axis,
-                                GreaterThanFunctor<T>(), &compare_result);
+      ElementwiseComputeEx<GreaterThanFunctor<T, int64_t>,
+                           platform::CPUDeviceContext, T, int>(
+          context, &eigenvalue_tensor, &tol_tensor, axis,
+          GreaterThanFunctor<T, int64_t>(), &compare_result);
     } else {
-      ElementwiseComputeEx<LessThanFunctor<T>, platform::CPUDeviceContext, T,
-                           int>(context, &eigenvalue_tensor, &tol_tensor, axis,
-                                LessThanFunctor<T>(), &compare_result);
+      ElementwiseComputeEx<LessThanFunctor<T, int64_t>,
+                           platform::CPUDeviceContext, T, int>(
+          context, &eigenvalue_tensor, &tol_tensor, axis,
+          LessThanFunctor<T, int64_t>(), &compare_result);
     }
     auto dito_int =
         math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext,
...
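Note: compare_result now stores int64_t so the CPU path matches the GPU kernel below, and the functor's OutT is set to int64_t to match. A rough same-shape sketch of what ElementwiseComputeEx computes here (broadcasting, axis handling, and device placement omitted; Apply and the local functor are illustrative stand-ins, not Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for GreaterThanFunctor<T, int64_t> from compare_op.h.
template <typename InT, typename OutT = bool>
struct GreaterThanFunctor {
  OutT operator()(const InT a, const InT b) const {
    return static_cast<OutT>(a > b);
  }
};

// Same-shape elementwise apply; the real ElementwiseComputeEx also handles
// broadcasting along `axis` and device dispatch.
template <typename Functor, typename T, typename OutT>
void Apply(const std::vector<T>& x, const std::vector<T>& y,
           std::vector<OutT>* out, Functor func) {
  out->resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) (*out)[i] = func(x[i], y[i]);
}

int main() {
  std::vector<float> eigenvalues = {0.5f, 1e-9f, 2.0f};
  std::vector<float> tol = {0.1f, 0.1f, 3.0f};
  std::vector<int64_t> mask;  // the int64_t mask, like compare_result
  Apply(eigenvalues, tol, &mask, GreaterThanFunctor<float, int64_t>());
  for (auto v : mask) std::cout << v << " ";  // 1 0 0
  std::cout << "\n";
}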
@@ -129,10 +129,10 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
     compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
                                          context.GetPlace());
     int axis = -1;
-    ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CUDADeviceContext, T,
-                         int64_t>(context, &eigenvalue_tensor, &tol_tensor,
-                                  axis, GreaterThanFunctor<T>(),
-                                  &compare_result);
+    ElementwiseComputeEx<GreaterThanFunctor<T, int64_t>,
+                         platform::CUDADeviceContext, T, int64_t>(
+        context, &eigenvalue_tensor, &tol_tensor, axis,
+        GreaterThanFunctor<T, int64_t>(), &compare_result);
     auto dito_int =
         math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
                                                  int64_t>(context);
...
@@ -16,6 +16,7 @@
 #include <vector>
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/controlflow/compare_op.h"
 
 namespace paddle {
 namespace operators {
@@ -46,16 +47,6 @@ static DDim RemoveLastDim(const DDim& dim) {
 }
 }  // namespace detail
 
-template <typename T>
-struct GreaterThanFunctor {
-  HOSTDEVICE int operator()(const T a, const T b) const { return a > b; }
-};
-
-template <typename T>
-struct LessThanFunctor {
-  HOSTDEVICE int operator()(const T a, const T b) const { return a < b; }
-};
-
 template <typename T>
 struct GreaterElementFunctor {
   HOSTDEVICE T operator()(const T a, const T b) const {
...
@@ -72,7 +72,8 @@ struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
   }
 };
 
-template <template <typename T> typename CompareFunctor, typename T>
+template <template <typename InT, typename OutT> typename CompareFunctor,
+          typename T>
 struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
   void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
                   const Tensor& rhs, Tensor* mask) {
@@ -81,7 +82,7 @@ struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     paddle::operators::LaunchSameDimsElementwiseCudaKernel<
         ElementwiseType::kBinary, int64_t, T>(dev_ctx, ins, &outs,
-                                              CompareFunctor<int64_t>());
+                                              CompareFunctor<int64_t, T>());
   }
 };
...
@@ -112,12 +112,13 @@ void SameDimsBinaryOP(const Tensor& lhs, const Tensor& rhs, Tensor* out) {
   }
 }
 
-template <typename DeviceContext, template <typename T> typename CompareFunctor,
+template <typename DeviceContext,
+          template <typename InT, typename OutT> typename CompareFunctor,
           typename T>
 struct GetMask {
   void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
                   const Tensor& rhs, Tensor* mask) {
-    SameDimsBinaryOP<int64_t, CompareFunctor<int64_t>, T>(lhs, rhs, mask);
+    SameDimsBinaryOP<int64_t, CompareFunctor<int64_t, T>, T>(lhs, rhs, mask);
   }
 };
...
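Note: the template-template parameter had to grow to two parameters because the unified functors are now declared as template <typename InT, typename OutT = bool>; under pre-C++17 matching rules a template <typename T> typename CompareFunctor parameter would not accept them despite the default argument. A small standalone sketch of the pattern (C++17; Mask is a made-up wrapper standing in for GetMask):

#include <cstdint>
#include <iostream>

// Two-parameter functor, as in compare_op.h after this change.
template <typename InT, typename OutT = bool>
struct LessThanFunctor {
  OutT operator()(const InT a, const InT b) const {
    return static_cast<OutT>(a < b);
  }
};

// Mirrors the updated GetMask signature: the template-template parameter
// declares both InT and OutT so LessThanFunctor matches it.
template <template <typename InT, typename OutT> typename CompareFunctor,
          typename T>
T Mask(int64_t a, int64_t b) {
  // Instantiated as CompareFunctor<int64_t, T>, as in GetMask above.
  return CompareFunctor<int64_t, T>()(a, b);
}

int main() {
  std::cout << Mask<LessThanFunctor, int64_t>(1, 2) << "\n";  // 1
}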
@@ -140,6 +140,19 @@ def create_paddle_case(op_type, callback):
             self.assertEqual((out.numpy() == self.real_result).all(), True)
             paddle.enable_static()
 
+        def test_not_equal(self):
+            if self.op_type == "not_equal":
+                paddle.disable_static()
+                x = paddle.to_tensor(
+                    np.array([1.2e-8, 2, 2, 1]), dtype="float32")
+                y = paddle.to_tensor(
+                    np.array([1.1e-8, 2, 2, 1]), dtype="float32")
+                op = eval("paddle.%s" % (self.op_type))
+                out = op(x, y)
+                self.real_result = np.array([0, 0, 0, 0]).astype(np.int64)
+                self.assertEqual((out.numpy() == self.real_result).all(), True)
+                paddle.enable_static()
+
         def test_assert(self):
             def test_dynamic_api_string(self):
                 if self.op_type == "equal":
...