Unverified commit 7e707ce8, authored by baoachun and committed by GitHub

add not_equal NPU op (#34560)

* add not_equal NPU op

* add not_equal NPU op

* add not_equal NPU op

* add not_equal NPU op
Parent 8144a730
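For context, the kernel this commit adds backs Paddle's element-wise not_equal comparison: given inputs X and Y, it writes a boolean tensor Out by dispatching to the NPU "NotEqual" operator through NpuOpRunner. A minimal usage sketch follows; it is illustrative only and assumes a Paddle build with Ascend NPU support (the device string and the is_compiled_with_npu() check are the assumed parts; the same call runs on CPU otherwise).

import numpy as np
import paddle

# Illustrative: pick the NPU when this build has one, otherwise fall back to CPU.
paddle.set_device("npu" if paddle.is_compiled_with_npu() else "cpu")

x = paddle.to_tensor(np.array([1.0, 2.0, 3.0], dtype="float32"))
y = paddle.to_tensor(np.array([1.0, 0.0, 3.0], dtype="float32"))

# Element-wise !=; on an NPU place this is served by the NotEqualNPUKernel added below.
out = paddle.not_equal(x, y)
print(out.numpy())  # [False  True False]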
@@ -11,21 +11,15 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <algorithm>
-#include <string>
-#include <vector>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/controlflow/compare_op.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
-#ifdef PADDLE_WITH_ASCEND_CL
 namespace paddle {
 namespace operators {
-template <typename T>
+template <typename DeviceContext, typename T>
 class EqualNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -42,16 +36,33 @@ class EqualNPUKernel : public framework::OpKernel<T> {
   }
 };
+template <typename DeviceContext, typename T>
+class NotEqualNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("NotEqual", {*x, *y}, {*out}, {});
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
 template <typename DeviceContext, typename T>
 class LessThanNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    // int axis = context.Attr<int>("axis");
-    z->mutable_data<bool>(ctx.GetPlace());  // allocate
-    const auto& runner = NpuOpRunner("Less", {*x, *y}, {*z});
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("Less", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -65,9 +76,10 @@ class LessEqualNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<bool>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*z});
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -81,10 +93,10 @@ class GreaterThanNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<bool>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*z});
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*out}, {});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -98,10 +110,10 @@ class GreaterEqualNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<framework::LoDTensor>("X");
     auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<bool>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*z});
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<bool>(ctx.GetPlace());
+    const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*out}, {});
    auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -115,32 +127,64 @@ class GreaterEqualNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel<float>,
-                       ops::EqualNPUKernel<plat::float16>,
-                       ops::EqualNPUKernel<int>);
+REGISTER_OP_NPU_KERNEL(
+    equal, ops::EqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, int64_t>,
+    ops::EqualNPUKernel<plat::NPUDeviceContext, bool>);
+REGISTER_OP_NPU_KERNEL(
+    not_equal, ops::NotEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::NotEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
 REGISTER_OP_NPU_KERNEL(
-    less_than,
-    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext,
-                           paddle::platform::float16>);
+    less_than, ops::LessThanNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, float>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, double>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int>,
+    ops::LessThanNPUKernel<plat::NPUDeviceContext, int64_t>);
 REGISTER_OP_NPU_KERNEL(
-    less_equal,
-    ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext,
-                            paddle::platform::float16>);
+    less_equal, ops::LessEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::LessEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
 REGISTER_OP_NPU_KERNEL(
     greater_than,
-    ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext,
-                              paddle::platform::float16>);
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, float>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, double>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int16_t>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int>,
+    ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int64_t>);
 REGISTER_OP_NPU_KERNEL(
     greater_equal,
-    ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext,
-                               paddle::platform::float16>);
-#endif
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, float>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, double>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int>,
+    ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
@@ -142,11 +142,12 @@ def create_test_class(op_type, typename, callback):
     globals()[cls_name] = Cls
-for _type_name in {'float16', 'float32', 'int32'}:
-    if _type_name == 'int32':
+for _type_name in {'float16', 'float32', 'int32', 'int64', 'bool'}:
+    if _type_name == 'int32' or _type_name == 'bool':
         create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
         continue
     create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
+    create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
     create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
     create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
     create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
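For readers unfamiliar with the test helper modified above, the sketch below shows roughly what a create_test_class-style factory does; this is a hypothetical simplification for illustration, not the actual helper from the test file. The idea is to build one unittest case per (op, dtype) pair that runs the Paddle comparison op and checks it against the NumPy result produced by the callback.

import unittest
import numpy as np
import paddle

def create_test_class(op_type, typename, callback):
    # Hypothetical sketch: generate random inputs of dtype `typename`, run the
    # Paddle comparison op, and compare against the NumPy callback result.
    class Cls(unittest.TestCase):
        def test_output(self):
            a = (np.random.random((10, 7)) * 10).astype(typename)
            b = (np.random.random((10, 7)) * 10).astype(typename)
            expected = callback(a, b)
            result = getattr(paddle, op_type)(paddle.to_tensor(a), paddle.to_tensor(b))
            np.testing.assert_array_equal(result.numpy(), expected)

    cls_name = "{}_{}".format(op_type, typename)
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls

# Example: create_test_class('not_equal', 'float32', lambda _a, _b: _a != _b)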