未验证 提交 78eff521 编写于 作者: W wawltor 提交者: GitHub

[BUG FIX] When x.dim < y.dim, the result of compare_op is inverted (#32470)

* Fix bug: when x.dim < y.dim, the result of compare_op is the inverse of the expected result

* Apply the compare-op broadcast fix to the CUDA kernels as well
上级 f272e59a
......@@ -23,29 +23,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
// CPU specialization of the comparison kernel. Reads inputs "X" and "Y",
// applies Functor to them and writes the boolean result into output "Out".
template <typename Functor>
class CompareOpKernel<platform::CPUDeviceContext, Functor>
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  // Runs the comparison for one operator invocation described by `context`.
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;
    using Tensor = framework::Tensor;

    auto* lhs = context.Input<Tensor>("X");
    auto* rhs = context.Input<Tensor>("Y");
    auto* out = context.Output<Tensor>("Out");
    int axis = context.Attr<int>("axis");

    const bool both_single_element = lhs->numel() == 1 && rhs->numel() == 1;
    if (!both_single_element) {
      // General case: delegate to the shared elementwise helper.
      ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
          context, lhs, rhs, axis, Functor(), out);
      return;
    }

    // Both inputs hold exactly one element: compare them directly.
    bool* out_data = out->mutable_data<bool>(context.GetPlace());
    out_data[0] = Functor()(lhs->data<T>()[0], rhs->data<T>()[0]);
  }
};
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
......@@ -153,16 +130,22 @@ class CompareOp : public framework::OperatorWithKernel {
REGISTER_COMPARE_OP_VERSION(op_type);
// CPU kernel registrations for the comparison operators.
//
// The last argument of REGISTER_COMPARE_KERNEL is the functor CompareOpKernel
// uses instead of the primary one when X has fewer dimensions than Y.
// NOTE(review): this assumes ElementwiseComputeEx applies the functor with its
// operands swapped in that case (consistent with equal / not_equal being
// paired with themselves) -- confirm against elementwise_op_function.h.
// Under that assumption the fallback must be the operand-swapped counterpart
// (x < y  <=>  y > x), not the logical negation; pairing less_than with
// GreaterEqualFunctor would return true for equal elements.
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor,
                        paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor,
                        paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU,
                        paddle::operators::GreaterThanFunctor,
                        paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
                        paddle::operators::GreaterEqualFunctor,
                        paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor,
                        paddle::operators::EqualFunctor);
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor,
                        paddle::operators::NotEqualFunctor);
......@@ -14,11 +14,17 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
// CUDA kernel registrations for the comparison operators.
//
// The last argument of REGISTER_COMPARE_KERNEL is the functor CompareOpKernel
// uses instead of the primary one when X has fewer dimensions than Y.
// NOTE(review): this assumes ElementwiseComputeEx applies the functor with its
// operands swapped in that case (consistent with equal / not_equal being
// paired with themselves) -- confirm against elementwise_op_function.h.
// Under that assumption the fallback must be the operand-swapped counterpart
// (x < y  <=>  y > x), not the logical negation.
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
                        paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
                        paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
                        paddle::operators::GreaterThanFunctor,
                        paddle::operators::LessThanFunctor);
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
                        paddle::operators::GreaterEqualFunctor,
                        paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
                        paddle::operators::EqualFunctor);
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
                        paddle::operators::NotEqualFunctor);
......@@ -68,7 +68,7 @@ struct NotEqualFunctor {
}
};
template <typename DeviceContext, typename Functor>
template <typename DeviceContext, typename Functor, typename InverseFunctor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
......@@ -80,21 +80,33 @@ class CompareOpKernel
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
int axis = context.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
Functor(), z);
} else {
ElementwiseComputeEx<InverseFunctor, DeviceContext, T, bool>(
context, x, y, axis, InverseFunctor(), z);
}
}
};
} // namespace operators
} // namespace paddle
// Registers CompareOpKernel for operator `op_type` on device `dev` (the token
// is pasted into REGISTER_OP_<dev>_KERNEL and <dev>DeviceContext) for the
// element types int, int64_t, float and double.
//   functor          - functor template used as the kernel's Functor.
//   inverse_functor  - functor template used as the kernel's InverseFunctor,
//                      i.e. the one applied when X has fewer dimensions than Y.
#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \
  REGISTER_OP_##dev##_KERNEL(op_type,                                   \
                             ::paddle::operators::CompareOpKernel<      \
                                 ::paddle::platform::dev##DeviceContext, \
                                 functor<int>, inverse_functor<int>>,   \
                             ::paddle::operators::CompareOpKernel<      \
                                 ::paddle::platform::dev##DeviceContext, \
                                 functor<int64_t>,                      \
                                 inverse_functor<int64_t>>,             \
                             ::paddle::operators::CompareOpKernel<      \
                                 ::paddle::platform::dev##DeviceContext, \
                                 functor<float>, inverse_functor<float>>, \
                             ::paddle::operators::CompareOpKernel<      \
                                 ::paddle::platform::dev##DeviceContext, \
                                 functor<double>, inverse_functor<double>>);
......@@ -122,6 +122,23 @@ def create_paddle_case(op_type, callback):
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_broadcast_api_2(self):
    # Static-graph broadcast check where x (rank 3) has fewer dimensions
    # than y (rank 4); the op output is compared with `callback` applied to
    # the same numpy inputs.
    paddle.enable_static()
    with program_guard(Program(), Program()):
        x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='int32')
        y = paddle.static.data(name='y', shape=[1, 2, 1, 3], dtype='int32')
        compare_fn = eval("paddle.%s" % (self.op_type))
        out = compare_fn(x, y)
        executor = paddle.static.Executor(self.place)
        # NOTE(review): assuming numpy-style broadcasting, no compared pair
        # in these inputs is equal (x holds 3b+c, y holds 3a+c+1), so this
        # data cannot tell a strict comparison from a non-strict one.
        x_np = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32)
        y_np = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32)
        expected = callback(x_np, y_np)
        feed = {"x": x_np, "y": y_np}
        res, = executor.run(feed=feed, fetch_list=[out])
        self.assertEqual((res == expected).all(), True)
def test_attr_name(self):
paddle.enable_static()
with program_guard(Program(), Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册