未验证 提交 7e707ce8 编写于 作者: B baoachun 提交者: GitHub

add not_equal NPU op (#34560)

* add not_equal NPU op

* add not_equal NPU op

* add not_equal NPU op

* add not_equal NPU op
上级 8144a730
......@@ -11,21 +11,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#ifdef PADDLE_WITH_ASCEND_CL
namespace paddle {
namespace operators {
template <typename T>
template <typename DeviceContext, typename T>
class EqualNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -42,16 +36,33 @@ class EqualNPUKernel : public framework::OpKernel<T> {
}
};
// NPU kernel for the `not_equal` op: feeds inputs "X" and "Y" to the
// "NotEqual" NPU operator and writes its result into the bool output "Out".
template <typename DeviceContext, typename T>
class NotEqualNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* lhs = ctx.Input<framework::LoDTensor>("X");
    const auto* rhs = ctx.Input<framework::LoDTensor>("Y");
    auto* result = ctx.Output<framework::LoDTensor>("Out");

    // The output buffer is allocated as bool on the execution place before
    // the runner is constructed with it.
    result->mutable_data<bool>(ctx.GetPlace());

    // Stream of the NPU device context on which the operator is launched.
    auto npu_stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    const NpuOpRunner not_equal_runner("NotEqual", {*lhs, *rhs}, {*result},
                                       {});
    not_equal_runner.Run(npu_stream);
  }
};
template <typename DeviceContext, typename T>
class LessThanNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
// int axis = context.Attr<int>("axis");
z->mutable_data<bool>(ctx.GetPlace()); // allocate
const auto& runner = NpuOpRunner("Less", {*x, *y}, {*z});
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("Less", {*x, *y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -65,9 +76,10 @@ class LessEqualNPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*z});
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -81,10 +93,10 @@ class GreaterThanNPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* out = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*z});
out->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -98,10 +110,10 @@ class GreaterEqualNPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* out = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*z});
out->mutable_data<bool>(ctx.GetPlace());
const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -115,32 +127,64 @@ class GreaterEqualNPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel<float>,
ops::EqualNPUKernel<plat::float16>,
ops::EqualNPUKernel<int>);
// Registers ops::EqualNPUKernel as the NPU kernel of the `equal` op,
// instantiated for float16, float, double, int8_t, uint8_t, int16_t, int,
// int64_t and bool element types.
REGISTER_OP_NPU_KERNEL(
equal, ops::EqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::EqualNPUKernel<plat::NPUDeviceContext, float>,
ops::EqualNPUKernel<plat::NPUDeviceContext, double>,
ops::EqualNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::EqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
ops::EqualNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::EqualNPUKernel<plat::NPUDeviceContext, int>,
ops::EqualNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::EqualNPUKernel<plat::NPUDeviceContext, bool>);
// Registers ops::NotEqualNPUKernel as the NPU kernel of the `not_equal` op,
// instantiated for float16, float, double, int8_t, uint8_t, int16_t, int and
// int64_t element types. Unlike the `equal` registration, no bool
// instantiation is listed here.
REGISTER_OP_NPU_KERNEL(
not_equal, ops::NotEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, float>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, double>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, int>,
ops::NotEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
less_than,
ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
less_than, ops::LessThanNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, float>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, double>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, uint8_t>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, int>,
ops::LessThanNPUKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
less_equal,
ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
less_equal, ops::LessEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, float>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, double>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int>,
ops::LessEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
greater_than,
ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, float>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, double>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, uint8_t>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int>,
ops::GreaterThanNPUKernel<plat::NPUDeviceContext, int64_t>);
// NOTE(review): the plat::NPUDeviceContext instantiation list for
// greater_equal below has no int16_t entry, while the equal, not_equal,
// less_than, less_equal and greater_than registrations in this file all
// include one — confirm whether the omission is intentional.
REGISTER_OP_NPU_KERNEL(
greater_equal,
ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, float>,
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, double>,
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, uint8_t>,
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int>,
ops::GreaterEqualNPUKernel<plat::NPUDeviceContext, int64_t>);
......@@ -142,11 +142,12 @@ def create_test_class(op_type, typename, callback):
globals()[cls_name] = Cls
for _type_name in {'float16', 'float32', 'int32'}:
if _type_name == 'int32':
for _type_name in {'float16', 'float32', 'int32', 'int64', 'bool'}:
if _type_name == 'int32' or _type_name == 'bool':
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
continue
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册