Unverified commit 78eff521, authored by wawltor, committed by GitHub

[BUG FIX] When x.dim < y.dim, the result of compare_op is inverted (#32470)

* Fix bug: when x.dim < y.dim, the result of compare_op is the inverse of the expected result

* Support CUDA in the fix for the compare broadcast bug
Parent commit: f272e59a
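As context for the hunks below, here is a minimal, self-contained C++ sketch of the dispatch this commit adds (no Paddle headers; Tensor, PickFunctor, and the functors here are illustrative stand-ins, not Paddle APIs). Per the commit message, the elementwise broadcast path misbehaves when X has fewer dimensions than Y, so the patched kernel compares the operand ranks and hands a paired inverse functor to the elementwise routine in that case.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in tensor: only the rank is needed for the dispatch shown here.
struct Tensor {
  std::vector<int64_t> dims;
};

// Mirrors the new branch in CompareOpKernel::Compute. In the real kernel both
// branches delegate the broadcasted evaluation to ElementwiseComputeEx; here
// we only report which functor would be handed to it.
template <typename Functor, typename InverseFunctor>
const char* PickFunctor(const Tensor& x, const Tensor& y) {
  if (x.dims.size() >= y.dims.size()) {
    return "Functor";       // unchanged fast path
  }
  return "InverseFunctor";  // new path that compensates for x.dim < y.dim
}

// Illustrative functor pair; the registrations below pair LessThanFunctor
// with GreaterEqualFunctor in the same way.
template <typename T>
struct LessThanFunctor {
  bool operator()(const T& a, const T& b) const { return a < b; }
};
template <typename T>
struct GreaterEqualFunctor {
  bool operator()(const T& a, const T& b) const { return a >= b; }
};

int main() {
  Tensor x{{1, 2, 3}};     // 3-D, as in the new unit test below
  Tensor y{{1, 2, 1, 3}};  // 4-D, so x has fewer dims than y
  std::cout << PickFunctor<LessThanFunctor<int>, GreaterEqualFunctor<int>>(x, y)
            << "\n";       // prints "InverseFunctor"
}

Note how the registrations in the diff pair each comparison with an inverse partner (less_than with GreaterEqualFunctor, greater_than with LessEqualFunctor, and so on), while equal and not_equal are their own partners.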
paddle/fluid/operators/controlflow/compare_op.cc
@@ -23,29 +23,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename Functor>
-class CompareOpKernel<platform::CPUDeviceContext, Functor>
-    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    using T = typename Functor::ELEM_TYPE;
-    using Tensor = framework::Tensor;
-
-    auto* x = context.Input<Tensor>("X");
-    auto* y = context.Input<Tensor>("Y");
-    auto* z = context.Output<Tensor>("Out");
-    int axis = context.Attr<int>("axis");
-
-    if (x->numel() == 1 && y->numel() == 1) {
-      bool* z_data = z->mutable_data<bool>(context.GetPlace());
-      z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
-    } else {
-      ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
-          context, x, y, axis, Functor(), z);
-    }
-  }
-};
-
 template <typename OpComment>
 class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
@@ -153,16 +130,22 @@ class CompareOp : public framework::OperatorWithKernel {
   REGISTER_COMPARE_OP_VERSION(op_type);
 
 REGISTER_COMPARE_OP(less_than, "Out = X < Y");
-REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
+REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor,
+                        paddle::operators::GreaterEqualFunctor);
 REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
-REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
+REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor,
+                        paddle::operators::GreaterThanFunctor);
 REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
 REGISTER_COMPARE_KERNEL(greater_than, CPU,
-                        paddle::operators::GreaterThanFunctor);
+                        paddle::operators::GreaterThanFunctor,
+                        paddle::operators::LessEqualFunctor);
 REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
 REGISTER_COMPARE_KERNEL(greater_equal, CPU,
-                        paddle::operators::GreaterEqualFunctor);
+                        paddle::operators::GreaterEqualFunctor,
+                        paddle::operators::LessThanFunctor);
 REGISTER_COMPARE_OP(equal, "Out = X == Y");
-REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
+REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor,
+                        paddle::operators::EqualFunctor);
 REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
-REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
+REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor,
+                        paddle::operators::NotEqualFunctor);
paddle/fluid/operators/controlflow/compare_op.cu
@@ -14,11 +14,17 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/controlflow/compare_op.h"
 
-REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
-REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
-REGISTER_COMPARE_KERNEL(greater_than, CUDA,
-                        paddle::operators::GreaterThanFunctor);
-REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
-                        paddle::operators::GreaterEqualFunctor);
-REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
-REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
+REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
+                        paddle::operators::GreaterEqualFunctor);
+REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
+                        paddle::operators::GreaterThanFunctor);
+REGISTER_COMPARE_KERNEL(greater_than, CUDA,
+                        paddle::operators::GreaterThanFunctor,
+                        paddle::operators::LessEqualFunctor);
+REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
+                        paddle::operators::GreaterEqualFunctor,
+                        paddle::operators::LessThanFunctor);
+REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
+                        paddle::operators::EqualFunctor);
+REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
+                        paddle::operators::NotEqualFunctor);
paddle/fluid/operators/controlflow/compare_op.h
@@ -68,7 +68,7 @@ struct NotEqualFunctor {
   }
 };
 
-template <typename DeviceContext, typename Functor>
+template <typename DeviceContext, typename Functor, typename InverseFunctor>
 class CompareOpKernel
     : public framework::OpKernel<typename Functor::ELEM_TYPE> {
  public:
@@ -80,21 +80,33 @@ class CompareOpKernel
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
     int axis = context.Attr<int>("axis");
-    ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
-                                                          Functor(), z);
+
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    if (x_dims.size() >= y_dims.size()) {
+      ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
+                                                            Functor(), z);
+    } else {
+      ElementwiseComputeEx<InverseFunctor, DeviceContext, T, bool>(
+          context, x, y, axis, InverseFunctor(), z);
+    }
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
 
-#define REGISTER_COMPARE_KERNEL(op_type, dev, functor)                    \
-  REGISTER_OP_##dev##_KERNEL(                                             \
-      op_type, ::paddle::operators::CompareOpKernel<                      \
-                   ::paddle::platform::dev##DeviceContext, functor<int>>, \
-      ::paddle::operators::CompareOpKernel<                               \
-          ::paddle::platform::dev##DeviceContext, functor<int64_t>>,      \
-      ::paddle::operators::CompareOpKernel<                               \
-          ::paddle::platform::dev##DeviceContext, functor<float>>,        \
-      ::paddle::operators::CompareOpKernel<                               \
-          ::paddle::platform::dev##DeviceContext, functor<double>>);
+#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor)       \
+  REGISTER_OP_##dev##_KERNEL(op_type,                                         \
+                             ::paddle::operators::CompareOpKernel<            \
+                                 ::paddle::platform::dev##DeviceContext,      \
+                                 functor<int>, inverse_functor<int>>,         \
+                             ::paddle::operators::CompareOpKernel<            \
+                                 ::paddle::platform::dev##DeviceContext,      \
+                                 functor<int64_t>, inverse_functor<int64_t>>, \
+                             ::paddle::operators::CompareOpKernel<            \
+                                 ::paddle::platform::dev##DeviceContext,      \
+                                 functor<float>, inverse_functor<float>>,     \
+                             ::paddle::operators::CompareOpKernel<            \
+                                 ::paddle::platform::dev##DeviceContext,      \
+                                 functor<double>, inverse_functor<double>>);
python/paddle/fluid/tests/unittests/test_compare_op.py
@@ -122,6 +122,23 @@ def create_paddle_case(op_type, callback):
                                fetch_list=[out])
                 self.assertEqual((res == real_result).all(), True)
 
+        def test_broadcast_api_2(self):
+            paddle.enable_static()
+            with program_guard(Program(), Program()):
+                x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='int32')
+                y = paddle.static.data(
+                    name='y', shape=[1, 2, 1, 3], dtype='int32')
+                op = eval("paddle.%s" % (self.op_type))
+                out = op(x, y)
+                exe = paddle.static.Executor(self.place)
+                input_x = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32)
+                input_y = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32)
+                real_result = callback(input_x, input_y)
+                res, = exe.run(feed={"x": input_x,
+                                     "y": input_y},
+                               fetch_list=[out])
+                self.assertEqual((res == real_result).all(), True)
+
         def test_attr_name(self):
             paddle.enable_static()
             with program_guard(Program(), Program()):
......