Unverified commit 68b4a2c3, authored by Aganlengzi, committed by GitHub

[NPU] add NPU ops of compare, test=develop (#34365)

* [NPU] add NPU ops&uts of compare, test=develop

* testing

* try style-format

* [NPU] update compare_op_npu uts

* [NPU] fix code style of test_compare_op_npu.py
Parent 5e27d16d
@@ -59,6 +59,56 @@ class LessThanNPUKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class LessEqualNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");
    z->mutable_data<bool>(ctx.GetPlace());

    const auto& runner = NpuOpRunner("LessEqual", {*x, *y}, {*z});

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};

template <typename DeviceContext, typename T>
class GreaterThanNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");
    z->mutable_data<bool>(ctx.GetPlace());

    const auto& runner = NpuOpRunner("Greater", {*x, *y}, {*z});

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};

template <typename DeviceContext, typename T>
class GreaterEqualNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");
    auto* z = ctx.Output<framework::LoDTensor>("Out");
    z->mutable_data<bool>(ctx.GetPlace());

    const auto& runner = NpuOpRunner("GreaterEqual", {*x, *y}, {*z});

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle
@@ -75,4 +125,22 @@ REGISTER_OP_NPU_KERNEL(
    ops::LessThanNPUKernel<paddle::platform::NPUDeviceContext,
                           paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    less_equal,
    ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::LessEqualNPUKernel<paddle::platform::NPUDeviceContext,
                            paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    greater_than,
    ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::GreaterThanNPUKernel<paddle::platform::NPUDeviceContext,
                              paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
    greater_equal,
    ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::GreaterEqualNPUKernel<paddle::platform::NPUDeviceContext,
                               paddle::platform::float16>);

#endif
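
For reference, a minimal dynamic-graph usage sketch of the comparison ops registered above. It assumes a PaddlePaddle build with NPU support, in which case these calls dispatch to the new kernels; the inputs and printed outputs are illustrative only.

import numpy as np
import paddle

# Route subsequent ops to the NPU, mirroring what the test below does.
paddle.set_device('npu:0')

x = paddle.to_tensor(np.array([1.0, 2.0, 3.0], dtype='float32'))
y = paddle.to_tensor(np.array([2.0, 2.0, 2.0], dtype='float32'))

print(paddle.less_equal(x, y).numpy())     # [ True  True False]
print(paddle.greater_than(x, y).numpy())   # [False False  True]
print(paddle.greater_equal(x, y).numpy())  # [False  True  True]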
@@ -21,121 +21,136 @@ sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

paddle.enable_static()


def create_test_class(op_type, typename, callback):
    class Cls(OpTest):
        def setUp(self):
            self.set_npu()
            self.place = paddle.NPUPlace(0)
            x = np.random.random(size=(10, 7)).astype(typename)
            y = np.random.random(size=(10, 7)).astype(typename)
            out = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': out}
            self.op_type = op_type

        def set_npu(self):
            self.__class__.use_npu = True

        def test_output(self):
            self.check_output_with_place(place=self.place)

        def test_errors(self):
            paddle.enable_static()
            with program_guard(Program(), Program()):
                a = fluid.layers.data(name='a', shape=[2], dtype='float32')
                b = fluid.layers.data(name='b', shape=[2], dtype='float32')
                c = fluid.layers.data(name='c', shape=[2], dtype='int16')
                d = fluid.create_lod_tensor(np.array([[-1]]), [[1]], self.place)

                op = eval("fluid.layers.%s" % self.op_type)
                self.assertRaises(TypeError, op, x=a, y=b, axis=True)
                self.assertRaises(TypeError, op, x=a, y=b, force_cpu=1)
                self.assertRaises(TypeError, op, x=a, y=b, cond=1)
                self.assertRaises(TypeError, op, x=a, y=c)
                self.assertRaises(TypeError, op, x=c, y=a)
                self.assertRaises(TypeError, op, x=a, y=d)
                self.assertRaises(TypeError, op, x=d, y=a)
                self.assertRaises(TypeError, op, x=c, y=d)

        def test_dynamic_api(self):
            paddle.disable_static()
            paddle.set_device('npu:0')
            x = np.random.random(size=(10, 7)).astype(typename)
            y = np.random.random(size=(10, 7)).astype(typename)
            real_result = callback(x, y)
            x = paddle.to_tensor(x, dtype=typename)
            y = paddle.to_tensor(y, dtype=typename)
            op = eval("paddle.%s" % (self.op_type))
            out = op(x, y)
            self.assertEqual((out.numpy() == real_result).all(), True)

        @unittest.skipIf(typename == 'float16', "float16 is not supported now")
        def test_broadcast_api_1(self):
            paddle.enable_static()
            with program_guard(Program(), Program()):
                x = paddle.static.data(
                    name='x', shape=[1, 2, 1, 3], dtype=typename)
                y = paddle.static.data(name='y', shape=[1, 2, 3], dtype=typename)
                op = eval("paddle.%s" % (self.op_type))
                out = op(x, y)
                exe = paddle.static.Executor(self.place)
                input_x = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(typename)
                input_y = np.arange(0, 6).reshape((1, 2, 3)).astype(typename)
                real_result = callback(input_x, input_y)
                res, = exe.run(feed={"x": input_x,
                                     "y": input_y},
                               fetch_list=[out])
                self.assertEqual((res == real_result).all(), True)

        @unittest.skipIf(typename == 'float16', "float16 is not supported now")
        def test_broadcast_api_2(self):
            paddle.enable_static()
            with program_guard(Program(), Program()):
                x = paddle.static.data(name='x', shape=[1, 2, 3], dtype=typename)
                y = paddle.static.data(
                    name='y', shape=[1, 2, 1, 3], dtype=typename)
                op = eval("paddle.%s" % (self.op_type))
                out = op(x, y)
                exe = paddle.static.Executor(self.place)
                input_x = np.arange(0, 6).reshape((1, 2, 3)).astype(typename)
                input_y = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(typename)
                real_result = callback(input_x, input_y)
                res, = exe.run(feed={"x": input_x,
                                     "y": input_y},
                               fetch_list=[out])
                self.assertEqual((res == real_result).all(), True)

        @unittest.skipIf(typename == 'float16', "float16 is not supported now")
        def test_broadcast_api_3(self):
            paddle.enable_static()
            with program_guard(Program(), Program()):
                x = paddle.static.data(name='x', shape=[5], dtype=typename)
                y = paddle.static.data(name='y', shape=[3, 1], dtype=typename)
                op = eval("paddle.%s" % (self.op_type))
                out = op(x, y)
                exe = paddle.static.Executor(self.place)
                input_x = np.arange(0, 5).reshape((5)).astype(typename)
                input_y = np.array([5, 3, 2]).reshape((3, 1)).astype(typename)
                real_result = callback(input_x, input_y)
                res, = exe.run(feed={"x": input_x,
                                     "y": input_y},
                               fetch_list=[out])
                self.assertEqual((res == real_result).all(), True)

        @unittest.skipIf(typename == 'float16', "float16 is not supported now")
        def test_attr_name(self):
            paddle.enable_static()
            with program_guard(Program(), Program()):
                x = fluid.layers.data(name='x', shape=[4], dtype=typename)
                y = fluid.layers.data(name='y', shape=[4], dtype=typename)
                op = eval("paddle.%s" % (self.op_type))
                out = op(x=x, y=y, name="name_%s" % (self.op_type))
            self.assertEqual("name_%s" % (self.op_type) in out.name, True)

    cls_name = "{0}_{1}".format(op_type, typename)
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls


for _type_name in {'float16', 'float32', 'int32'}:
    if _type_name == 'int32':
        create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
        continue
    create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
    create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
    create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
    create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
    create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)


if __name__ == '__main__':
    unittest.main()
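
The factory above registers one unittest class per (op, dtype) pair under names of the form "{op}_{dtype}", for example less_equal_float32. A hedged sketch of selecting a single generated case with the standard unittest loader, assuming the file is importable as test_compare_op_npu:

import unittest

# Load one generated case by its "{op}_{dtype}" class name.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_compare_op_npu.less_equal_float32")
unittest.TextTestRunner(verbosity=2).run(suite)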