From acec26a1f3e6b85c78f293e0418857ddd34df0c8 Mon Sep 17 00:00:00 2001 From: taixiurong Date: Sat, 2 Apr 2022 10:52:36 +0800 Subject: [PATCH] xpu add dropout&cast unitest (#41120) --- paddle/fluid/operators/dropout_op_xpu.cc | 8 +- .../fluid/tests/unittests/op_test_xpu.py | 49 +++- .../tests/unittests/xpu/test_cast_op_xpu.py | 38 ++- .../unittests/xpu/test_dropout_op_xpu.py | 274 ++++++++++++------ 4 files changed, 259 insertions(+), 110 deletions(-) diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc index 7d8660f238a..851f26ee0e7 100644 --- a/paddle/fluid/operators/dropout_op_xpu.cc +++ b/paddle/fluid/operators/dropout_op_xpu.cc @@ -42,7 +42,13 @@ class DropoutXPUKernel : public framework::OpKernel { if (!context.Attr("is_test")) { int seed_data = 0; if (seed) { - seed_data = *(seed->data()); + if (platform::is_xpu_place(seed->place())) { + memory::Copy(platform::CPUPlace(), &seed_data, seed->place(), + seed->data(), sizeof(int)); + } else { + seed_data = *(seed->data()); + } + } else { seed_data = context.Attr("fix_seed") ? context.Attr("seed") : 0; diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py index 107f340d3a8..4a67af02bcf 100644 --- a/python/paddle/fluid/tests/unittests/op_test_xpu.py +++ b/python/paddle/fluid/tests/unittests/op_test_xpu.py @@ -54,13 +54,11 @@ class XPUOpTest(OpTest): """Restore random seeds""" def is_empty_grad_op(op_type): - all_op_kernels = core._get_all_register_op_kernels() grad_op = op_type + '_grad' - if grad_op in all_op_kernels.keys(): - grad_op_kernels = all_op_kernels[grad_op] - for grad_op_kernel in grad_op_kernels: - if 'XPU' in grad_op_kernel: - return False + xpu_version = core.get_xpu_device_version(0) + xpu_op_list = core.get_xpu_device_op_list(xpu_version) + if grad_op in xpu_op_list.keys(): + return False return True if cls.dtype == np.float16: @@ -70,9 +68,20 @@ class XPUOpTest(OpTest): super().tearDownClass() def _get_places(self): - places = [fluid.XPUPlace(0)] + places = [paddle.XPUPlace(0)] return places + def check_output(self, + atol=0.001, + no_check_set=None, + equal_nan=False, + check_dygraph=True, + inplace_atol=None, + check_eager=False): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, atol, no_check_set, equal_nan, + check_dygraph, inplace_atol, check_eager) + def check_output_with_place(self, place, atol=0.001, @@ -82,20 +91,37 @@ class XPUOpTest(OpTest): inplace_atol=None, check_eager=False): self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) - #xpu not support float64 if self.dtype == np.float64: return - if place == None: - place = paddle.XPUPlace(0) if self.dtype == np.float16: if core.is_float16_supported(place) == False: return + if self.dtype == np.float16: atol = 0.1 return super().check_output_with_place( place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol) + def check_grad(self, + inputs_to_check, + output_names, + no_grad_set=None, + numeric_grad_delta=0.005, + in_place=False, + max_relative_error=0.005, + user_defined_grads=None, + user_defined_grad_outputs=None, + check_dygraph=True, + numeric_place=None, + check_eager=False): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, inputs_to_check, output_names, + no_grad_set, numeric_grad_delta, in_place, + max_relative_error, user_defined_grads, + user_defined_grad_outputs, check_dygraph, + numeric_place, check_eager) + def check_grad_with_place(self, place, inputs_to_check, @@ -116,9 +142,6 @@ class XPUOpTest(OpTest): self._check_grad_helper() return - if place == None: - place = paddle.XPUPlace(0) - if self.dtype == np.float64: return diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index 08d4810a653..201e758c0ac 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -23,6 +23,9 @@ import paddle import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard +from op_test_xpu import XPUOpTest + +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper typeid_dict = { 'int32': int(core.VarDesc.VarType.INT32), @@ -33,10 +36,27 @@ typeid_dict = { } -def create_test_class(in_typename, out_typename): - class Cls(op_test.OpTest): +class XPUTestCastOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'cast' + self.use_dynamic_create_class = True + + def dynamic_create_class(self): + base_class = self.TestCastOp + classes = [] + for out_type in {'float16', 'float32', 'int32', 'int64'}: + class_name = 'XPUTestCastOp_outtype_' + out_type + attr_dict = {'out_typename': out_type} + classes.append([class_name, attr_dict]) + return base_class, classes + + class TestCastOp(XPUOpTest): def setUp(self): ipt = np.random.random(size=[10, 10]) + in_typename = self.in_type_str + out_typename = 'float32' if not hasattr( + self, 'out_typename') else self.out_typename + self.inputs = {'X': ipt.astype(in_typename)} self.outputs = {'Out': ipt.astype(in_typename).astype(out_typename)} self.attrs = { @@ -47,18 +67,12 @@ def create_test_class(in_typename, out_typename): self.__class__.no_need_check_grad = True def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - cls_name = "cast_{0}_{1}".format(in_typename, out_typename) - Cls.__name__ = cls_name - globals()[cls_name] = Cls + self.check_output() -for in_type in {'float16', 'float32', 'int32', 'int64', 'bool'}: - for out_type in {'float16', 'float32', 'int32', 'int64'}: - create_test_class(in_type, out_type) +support_types = get_xpu_op_support_types('cast') +for stype in support_types: + create_test_class(globals(), XPUTestCastOp, stype) class TestCastOpError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py index ca3b3a418ab..2baa837b23a 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py @@ -25,90 +25,196 @@ from paddle.fluid import Program, program_guard from op_test_xpu import XPUOpTest paddle.enable_static() - -class TestDropoutOp(XPUOpTest): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64)).astype('uint8') - } - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad_normal(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') - - -class TestDropoutOpInput1d(XPUOpTest): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((2000, )).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((2000)).astype('uint8') - } - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad_normal(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') - - -class TestDropoutOp2(TestDropoutOp): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': np.zeros((32, 64)).astype('float32'), - 'Mask': np.zeros((32, 64)).astype('uint8') - } - - -class TestDropoutOp3(TestDropoutOp): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64, 2)).astype('uint8') - } - - -class TestDropoutOp6(TestDropoutOp): - def setUp(self): - self.op_type = "dropout" - self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} - self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64, 2)).astype('uint8') - } - +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + + +class XPUTestDropoutOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'dropout' + self.use_dynamic_create_class = False + + class TestDropoutOp(XPUOpTest): + def setUp(self): + self.init_inputs_shape() + self.init_attrs() + self.dtype = self.in_type + self.op_type = 'dropout' + self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)} + self.attrs = { + 'dropout_prob': self.dropout_prob, + 'fix_seed': self.fix_seed, + 'is_test': self.is_test, + 'dropout_implementation': self.dropout_implementation + } + + out = self.inputs['X'] * (1.0 - self.dropout_prob) + if self.is_test == False: + mask = None + if self.dropout_prob == 0.0: + mask = np.ones(self.shape).astype(self.dtype) + elif self.dropout_prob == 1.0: + mask = np.zeros(self.shape).astype(self.dtype) + self.outputs = {'Out': out, 'Mask': mask} + else: + self.outputs = {'Out': out} + + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.__class__.no_need_check_grad = False + self.dropout_prob = 0.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + if hasattr(self.__class__, "no_need_check_grad" + ) and self.__class__.no_need_check_grad == True: + return + + self.check_grad(['X'], 'Out') + + class TestDropoutOpInput1d(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [2000] + + class TestDropoutOp2(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.dropout_prob = 1.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" + + class TestDropoutOp3(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64, 2] + + class TestDropoutOp4(TestDropoutOp): + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.35 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + class TestDropoutOp5(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64, 3] + + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.75 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + class TestDropoutOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of dropout must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], + fluid.CPUPlace()) + fluid.layers.dropout(x1, dropout_prob=0.5) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of dropout must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data( + name='x2', shape=[3, 4, 5, 6], dtype="int32") + fluid.layers.dropout(x2, dropout_prob=0.5) + + self.assertRaises(TypeError, test_dtype) + + class TestDropoutCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + self.places.append(fluid.XPUPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([40, 40]).astype(self.in_type) + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + m = paddle.nn.Dropout(p=0.) + m.eval() + result = m(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + class TestDropoutBackward(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + self.places.append(fluid.XPUPlace(0)) + + def cal_grad_upscale_train(self, mask, prob): + return mask.astype(self.in_type) / (1 - prob) + + def cal_grad_downscale_in_infer(self, mask): + return mask.astype(self.in_type) + + def test_backward_downscale_in_infer(self): + for place in self.places: + with fluid.dygraph.guard(place): + + input = paddle.uniform([40, 40], dtype=self.in_type) + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', 0.5) + out.backward() + + self.assertTrue( + np.array_equal(input.gradient( + ), self.cal_grad_downscale_in_infer(mask.numpy()))) + + def test_backward_upscale_train(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.5 + input = paddle.uniform([40, 40], dtype=self.in_type) + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + def test_backward_upscale_train_2(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.3 + input = paddle.uniform([40, 40], dtype=self.in_type) + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + +support_types = get_xpu_op_support_types('dropout') +for stype in support_types: + create_test_class(globals(), XPUTestDropoutOp, stype) if __name__ == '__main__': unittest.main() -- GitLab