未验证 提交 e3e15792 编写于 作者: M Meiyim 提交者: GitHub

[NPU] support npu kernel for `less_than` (#31327)

* [npu] support npu kernel for `less_than`

* remove int* kernel

* cleanup
上级 a3cc4a4a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -21,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#ifdef PADDLE_WITH_ASCEND_CL
namespace paddle {
namespace operators {
......@@ -42,6 +42,23 @@ class EqualNPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class LessThanNPUKernel : public framework::OpKernel<T> {
 public:
  // Computes elementwise Out = (X < Y) on the NPU by dispatching the CANN
  // "Less" operator. Inputs X and Y are LoDTensors of dtype T; the output
  // is always a bool tensor, regardless of T.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");
    // Comparison results are bool no matter what the input dtype is.
    z->mutable_data<bool>(ctx.GetPlace());
    auto runner = NpuOpRunner("Less", {*x, *y}, {*z});
    // Run asynchronously on the NPU stream owned by the device context.
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};
} // namespace operators
} // namespace paddle
......@@ -51,3 +68,11 @@ namespace plat = paddle::platform;
// Register the NPU `equal` kernel for float, float16 and int inputs.
REGISTER_OP_NPU_KERNEL(equal, ops::EqualNPUKernel<float>,
ops::EqualNPUKernel<plat::float16>,
ops::EqualNPUKernel<int>);
// Register the NPU `less_than` kernel for float and float16 only.
// NOTE(review): no int kernel is registered here — presumably the CANN
// "Less" operator does not support it on this target; confirm before adding.
REGISTER_OP_NPU_KERNEL(
less_than,
ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -56,6 +56,36 @@ class TestEqual(OpTest):
self.check_output_with_place(self.place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestLessthan(OpTest):
    """Checks the NPU ``less_than`` kernel against numpy's ``<`` operator."""

    def setUp(self):
        self.set_npu()
        self.op_type = "less_than"
        self.place = paddle.NPUPlace(0)
        self.init_dtype()
        np.random.seed(SEED)
        lhs = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        rhs = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        expected = lhs < rhs
        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(lhs),
            'Y': OpTest.np_dtype_to_fluid_dtype(rhs)
        }
        self.outputs = {'Out': expected}

    def set_npu(self):
        # Mark the whole test class as targeting the NPU backend.
        self.__class__.use_npu = True

    def init_dtype(self):
        # Subclasses override this to run the same case under other dtypes.
        self.dtype = np.float32

    def test_check_output(self):
        # NOTE(review): dygraph checking is disabled here — presumably the
        # compare ops lack dygraph support on NPU; confirm before enabling.
        self.check_output_with_place(self.place, check_dygraph=False)
class TestEqual2(TestEqual):
def setUp(self):
self.set_npu()
......@@ -76,6 +106,26 @@ class TestEqual2(TestEqual):
self.outputs = {'Out': out}
class TestLessthan2(TestLessthan):
    """Variant where X and Y agree everywhere except one perturbed element."""

    def setUp(self):
        self.set_npu()
        self.op_type = "less_than"
        self.place = paddle.NPUPlace(0)
        self.init_dtype()
        np.random.seed(SEED)
        lhs = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        rhs = lhs.copy()
        # All elements equal except position [0][1], which is set to 1.
        rhs[0][1] = 1
        expected = lhs < rhs
        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(lhs),
            'Y': OpTest.np_dtype_to_fluid_dtype(rhs)
        }
        self.outputs = {'Out': expected}
class TestEqual2FP16(TestEqual2):
    """Runs the TestEqual2 case with float16 inputs."""

    def init_dtype(self):
        self.dtype = np.float16
......@@ -86,5 +136,10 @@ class TestEqual2Int(TestEqual2):
self.dtype = np.int32
class TestLessthan2FP16(TestLessthan2):
    """Runs the TestLessthan2 case with float16 inputs."""

    def init_dtype(self):
        self.dtype = np.float16
# Run all test cases in this module when executed directly.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册