diff --git a/paddle/fluid/operators/controlflow/compare_op_npu.cc b/paddle/fluid/operators/controlflow/compare_op_npu.cc
index b1d4d1e7022a325c8f2c4498aeb4d0b2de79b2d2..235d44b92f91954b3c7a117bd9bf562d1b79a420 100644
--- a/paddle/fluid/operators/controlflow/compare_op_npu.cc
+++ b/paddle/fluid/operators/controlflow/compare_op_npu.cc
@@ -11,21 +11,15 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/controlflow/compare_op.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
 
-#ifdef PADDLE_WITH_ASCEND_CL
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class EqualNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -42,16 +36,33 @@ class EqualNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class NotEqualNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("NotEqual", {*x, *y}, {*out}, {});
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
 template <typename DeviceContext, typename T>
 class LessThanNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    // int axis = context.Attr<int>("axis");
-    z->mutable_data<bool>(ctx.GetPlace());  // allocate
-    const auto& runner = NpuOpRunner("Less", {*x, *y}, {*z});
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("Less", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -65,9 +76,10 @@ class LessEqualNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<bool>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*z});
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+
+    const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -81,10 +93,10 @@ class GreaterThanNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
 
-    z->mutable_data<bool>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*z});
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
@@ -98,10 +110,10 @@ class GreaterEqualNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
 
-    z->mutable_data<bool>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*z});
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
@@ -115,32 +127,64 @@ class GreaterEqualNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel<float>,
-                       ops::EqualNPUKernel<plat::float16>,
-                       ops::EqualNPUKernel<int>);
+REGISTER_OP_NPU_KERNEL(
+    equal, ops::EqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int64_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, bool>);
+
+REGISTER_OP_NPU_KERNEL(
+    not_equal, ops::NotEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
 
 REGISTER_OP_NPU_KERNEL(
-    less_than,
-    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>);
+    less_than, ops::LessThanNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, float>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, double>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int64_t>);
 
 REGISTER_OP_NPU_KERNEL(
-    less_equal,
-    ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>);
+    less_equal, ops::LessEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
 
 REGISTER_OP_NPU_KERNEL(
     greater_than,
-    ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>);
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, float>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, double>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int64_t>);
 
 REGISTER_OP_NPU_KERNEL(
     greater_equal,
-    ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>);
-
-#endif
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py
index d8c22e2da090770b6cdbc7429a3c9a5aa2f2463d..66ce81756fc9d8dde53a93f060ca4c6483942108 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_compare_op_npu.py
@@ -142,11 +142,12 @@ def create_test_class(op_type, typename, callback):
     globals()[cls_name] = Cls
 
 
-for _type_name in {'float16', 'float32', 'int32'}:
-    if _type_name == 'int32':
+for _type_name in {'float16', 'float32', 'int32', 'int64', 'bool'}:
+    if _type_name == 'int32' or _type_name == 'bool':
         create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
         continue
     create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
+    create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
     create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
     create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
     create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
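
Outside the unit tests, the new `not_equal` NPU kernel can be exercised directly from Python. The sketch below is illustrative only and not part of the patch: it assumes a Paddle build with Ascend/NPU support (`PADDLE_WITH_ASCEND_CL`) and at least one NPU device; the device string `npu:0` and the `int64` dtype are arbitrary choices for the example.

```python
# Illustrative check of the not_equal NPU kernel added by this patch
# (assumes a Paddle build with Ascend/NPU support; otherwise use "cpu").
import numpy as np
import paddle

paddle.set_device("npu:0")  # example device; swap for "cpu" without NPU hardware

x = paddle.to_tensor(np.array([1, 2, 3], dtype="int64"))
y = paddle.to_tensor(np.array([1, 0, 3], dtype="int64"))

# On an NPU place this dispatches to NotEqualNPUKernel and, like the other
# compare kernels, produces a bool tensor.
out = paddle.not_equal(x, y)
print(out.numpy())  # expected: [False  True False]
```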