未验证 提交 def81b4f 编写于 作者: Z Zhang Ting 提交者: GitHub

unify compare functor (#39024)

上级 46823104
......@@ -22,49 +22,40 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct LessThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; }
};
template <typename T>
struct LessEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; }
};
template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; }
};
template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; }
};
template <typename T>
#define COMPARE_FUNCTOR(func_name, op) \
template <typename InT, typename OutT = bool> \
struct func_name { \
using ELEM_TYPE = InT; \
HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
return static_cast<OutT>(a op b); \
} \
};
COMPARE_FUNCTOR(LessThanFunctor, <)
COMPARE_FUNCTOR(LessEqualFunctor, <=)
COMPARE_FUNCTOR(GreaterThanFunctor, >)
COMPARE_FUNCTOR(GreaterEqualFunctor, >=)
#undef COMPARE_FUNCTOR
template <typename InT, typename OutT = bool>
struct EqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) {
using ELEM_TYPE = InT;
HOSTDEVICE OutT operator()(const InT a, const InT b) const {
if (std::is_floating_point<InT>::value) {
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
return fabs(static_cast<double>(a - b)) < 1e-8;
return static_cast<OutT>(fabs(static_cast<double>(a - b)) < 1e-8);
} else {
return (a == b);
return static_cast<OutT>(a == b);
}
}
};
template <typename T>
template <typename InT, typename OutT = bool>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const {
return !EqualFunctor<T>()(a, b);
using ELEM_TYPE = InT;
HOSTDEVICE bool operator()(const InT a, const InT b) const {
return !EqualFunctor<InT, OutT>()(a, b);
}
};
......
......@@ -219,18 +219,20 @@ class MatrixRankCPUKernel : public framework::OpKernel<T> {
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
Tensor compare_result;
compare_result.mutable_data<int>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
int axis = -1;
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CPUDeviceContext, T,
int>(context, &eigenvalue_tensor, &tol_tensor, axis,
GreaterThanFunctor<T>(), &compare_result);
ElementwiseComputeEx<GreaterThanFunctor<T, int64_t>,
platform::CPUDeviceContext, T, int>(
context, &eigenvalue_tensor, &tol_tensor, axis,
GreaterThanFunctor<T, int64_t>(), &compare_result);
} else {
ElementwiseComputeEx<LessThanFunctor<T>, platform::CPUDeviceContext, T,
int>(context, &eigenvalue_tensor, &tol_tensor, axis,
LessThanFunctor<T>(), &compare_result);
ElementwiseComputeEx<LessThanFunctor<T, int64_t>,
platform::CPUDeviceContext, T, int>(
context, &eigenvalue_tensor, &tol_tensor, axis,
LessThanFunctor<T, int64_t>(), &compare_result);
}
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext,
......
......@@ -129,10 +129,10 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
int axis = -1;
ElementwiseComputeEx<GreaterThanFunctor<T>, platform::CUDADeviceContext, T,
int64_t>(context, &eigenvalue_tensor, &tol_tensor,
axis, GreaterThanFunctor<T>(),
&compare_result);
ElementwiseComputeEx<GreaterThanFunctor<T, int64_t>,
platform::CUDADeviceContext, T, int64_t>(
context, &eigenvalue_tensor, &tol_tensor, axis,
GreaterThanFunctor<T, int64_t>(), &compare_result);
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
int64_t>(context);
......
......@@ -16,6 +16,7 @@
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/controlflow/compare_op.h"
namespace paddle {
namespace operators {
......@@ -46,16 +47,6 @@ static DDim RemoveLastDim(const DDim& dim) {
}
} // namespace detail
template <typename T>
struct GreaterThanFunctor {
HOSTDEVICE int operator()(const T a, const T b) const { return a > b; }
};
template <typename T>
struct LessThanFunctor {
HOSTDEVICE int operator()(const T a, const T b) const { return a < b; }
};
template <typename T>
struct GreaterElementFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
......
......@@ -72,7 +72,8 @@ struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
}
};
template <template <typename T> typename CompareFunctor, typename T>
template <template <typename InT, typename OutT> typename CompareFunctor,
typename T>
struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* mask) {
......@@ -81,7 +82,7 @@ struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<
ElementwiseType::kBinary, int64_t, T>(dev_ctx, ins, &outs,
CompareFunctor<int64_t>());
CompareFunctor<int64_t, T>());
}
};
......
......@@ -112,12 +112,13 @@ void SameDimsBinaryOP(const Tensor& lhs, const Tensor& rhs, Tensor* out) {
}
}
template <typename DeviceContext, template <typename T> typename CompareFunctor,
template <typename DeviceContext,
template <typename InT, typename OutT> typename CompareFunctor,
typename T>
struct GetMask {
void operator()(const framework::ExecutionContext& ctx, const Tensor& lhs,
const Tensor& rhs, Tensor* mask) {
SameDimsBinaryOP<int64_t, CompareFunctor<int64_t>, T>(lhs, rhs, mask);
SameDimsBinaryOP<int64_t, CompareFunctor<int64_t, T>, T>(lhs, rhs, mask);
}
};
......
......@@ -140,6 +140,19 @@ def create_paddle_case(op_type, callback):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_not_equal(self):
if self.op_type == "not_equal":
paddle.disable_static()
x = paddle.to_tensor(
np.array([1.2e-8, 2, 2, 1]), dtype="float32")
y = paddle.to_tensor(
np.array([1.1e-8, 2, 2, 1]), dtype="float32")
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = np.array([0, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_assert(self):
def test_dynamic_api_string(self):
if self.op_type == "equal":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册