From 369154744d3e1136d7ac96b754dcf7642f28b4a9 Mon Sep 17 00:00:00 2001 From: qipengh Date: Sat, 10 Sep 2022 16:48:34 +0800 Subject: [PATCH] [MLU] fix compute error of dropout op (#45923) --- paddle/fluid/operators/dropout_op_mlu.cc | 71 +++--- paddle/fluid/operators/pool_op_mlu.cc | 7 +- .../unittests/mlu/test_dropout_op_mlu.py | 208 ++++++++++-------- 3 files changed, 151 insertions(+), 135 deletions(-) diff --git a/paddle/fluid/operators/dropout_op_mlu.cc b/paddle/fluid/operators/dropout_op_mlu.cc index 923e6cc5ed9..142e047e6c2 100644 --- a/paddle/fluid/operators/dropout_op_mlu.cc +++ b/paddle/fluid/operators/dropout_op_mlu.cc @@ -39,8 +39,17 @@ class DropoutMLUKernel : public framework::OpKernel { MLUCnnlTensorDesc x_desc(*x); MLUCnnlTensorDesc out_desc(*out); - if (!is_test) { - // exec dropout op for training only. + if (is_test && is_upscale) { + // dropout op for inference: out = input. + framework::TensorCopy( + *x, + ctx.GetPlace(), + ctx.template device_context(), + out); + return; + } else if (!is_test) { + // dropout op for training: out = input * mask / ( 1.0 - dropout_prob ) or + // out = input * mask. int seed_data = 0; if (seed_tensor) { if (platform::is_mlu_place(seed_tensor->place())) { @@ -79,50 +88,44 @@ class DropoutMLUKernel : public framework::OpKernel { const int device_id = ctx.GetPlace().GetDeviceId(); auto mlu_gen_random = GetMLURandomGenerator(ctx, device_id, seed_data); - const float prob = is_upscale ? dropout_prob : 0.0f; + // compute out = input * mask / ( 1.0 - dropout_prob ) MLUCnnl::FusedDropout(ctx, mlu_gen_random->get(), x_desc.get(), GetBasePtr(x), - prob, + dropout_prob, GetBasePtr(&(mlu_gen_random->get_state())), mask_desc.get(), GetBasePtr(mask), out_desc.get(), GetBasePtr(out)); - } else { - // exec dropout op for inference only. + if (is_upscale) { - framework::TensorCopy( - *x, - ctx.GetPlace(), - ctx.template device_context(), - out); - } else { - auto scale = static_cast(1.0f - dropout_prob); - Tensor scale_tensor(x->dtype()); - scale_tensor.mutable_data({1}, ctx.GetPlace()); - MLUCnnlTensorDesc scale_desc(scale_tensor); - MLUCnnl::Fill(ctx, - CNNL_POINTER_MODE_HOST, - &scale, - scale_desc.get(), - GetBasePtr(&scale_tensor)); - - auto data_type = ToCnnlDataType(); - MLUCnnlOpTensorDesc op_tensor_desc( - CNNL_OP_TENSOR_MUL, data_type, CNNL_NOT_PROPAGATE_NAN); - MLUCnnl::OpTensor(ctx, - op_tensor_desc.get(), - x_desc.get(), - GetBasePtr(x), - scale_desc.get(), - GetBasePtr(&scale_tensor), - out_desc.get(), - GetBasePtr(out), - data_type); + return; } } + + // In downgrade_in_infer mode, need to multiply (1.0f - dropout_prob). + Tensor scale_tensor(x->dtype()); + Tensor bias_tensor(x->dtype()); + scale_tensor.mutable_data({1}, ctx.GetPlace()); + bias_tensor.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc bias_desc(bias_tensor); + FillMLUTensorWithHostValue( + ctx, static_cast(1.0f - dropout_prob), &scale_tensor); + FillMLUTensorWithHostValue(ctx, static_cast(0.0f), &bias_tensor); + + MLUCnnl::Scale(ctx, + 0, + is_test ? x_desc.get() : out_desc.get(), + is_test ? GetBasePtr(x) : GetBasePtr(out), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(out)); } }; diff --git a/paddle/fluid/operators/pool_op_mlu.cc b/paddle/fluid/operators/pool_op_mlu.cc index 5eaf8dbff88..988eb182a16 100644 --- a/paddle/fluid/operators/pool_op_mlu.cc +++ b/paddle/fluid/operators/pool_op_mlu.cc @@ -141,10 +141,9 @@ class MLUPoolOpKernel : public framework::OpKernel { handle, pool_mode, out_w, out_h, &extra_input_size); if (extra_input_size > 0) { - phi::CPUContext cpu_ctx; - framework::Tensor extra_host_tensor = - ctx.AllocateTmpTensor( - {static_cast(extra_input_size)}, cpu_ctx); + framework::Tensor extra_host_tensor; + extra_host_tensor.mutable_data( + {static_cast(extra_input_size)}, platform::CPUPlace()); cnnlInitPoolingExtraInput(handle, pool_desc.get(), trans_in_x_desc.get(), diff --git a/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py index 8497853561d..8f3740eb994 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py @@ -31,24 +31,43 @@ SEED = 2022 class TestDropoutOp(OpTest): def setUp(self): - self.op_type = "dropout" self.set_mlu() self.init_dtype() - self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} + self.init_inputs_shape() + self.init_attrs() + self.op_type = 'dropout' + self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)} self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64)).astype('uint8') + 'dropout_prob': self.dropout_prob, + 'fix_seed': self.fix_seed, + 'is_test': self.is_test, + 'dropout_implementation': self.dropout_implementation } + out = self.inputs['X'] * (1.0 - self.dropout_prob) + if self.is_test == False: + mask = None + if self.dropout_prob == 0.0: + mask = np.ones(self.shape).astype('uint8') + elif self.dropout_prob == 1.0: + mask = np.zeros(self.shape).astype('uint8') + self.outputs = {'Out': out, 'Mask': mask} + else: + self.outputs = {'Out': out} + def init_dtype(self): self.dtype = np.float32 + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.__class__.no_need_check_grad = False + self.dropout_prob = 0.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" + def set_mlu(self): self.__class__.use_mlu = True self.place = paddle.device.MLUPlace(0) @@ -57,84 +76,111 @@ class TestDropoutOp(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): + if hasattr(self.__class__, "no_need_check_grad" + ) and self.__class__.no_need_check_grad == True: + return + self.check_grad_with_place(self.place, ['X'], 'Out') class TestDropoutOpInput1d(TestDropoutOp): - # change input shape - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((3, 62)).astype(self.dtype)} - self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((3, 62)).astype('uint8') - } - -class TestDropoutOpInput1d_1(TestDropoutOp): - # the input is 1-D - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((2000)).astype(self.dtype)} - self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((2000)).astype('uint8') - } + def init_inputs_shape(self): + self.shape = [2000] class TestDropoutOp2(TestDropoutOp): - # the dropout_prob is 1.0 - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} - self.attrs = { - 'dropout_prob': 1.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': np.zeros((32, 64)).astype('float32'), - 'Mask': np.zeros((32, 64)).astype('uint8') - } + + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.dropout_prob = 1.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" class TestDropoutOp3(TestDropoutOp): - # the input dim is 3 + + def init_inputs_shape(self): + self.shape = [32, 64, 2] + + +class TestDropoutOp4(TestDropoutOp): + + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.35 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + +class TestDropoutOp5(TestDropoutOp): + + def init_inputs_shape(self): + self.shape = [32, 64, 3] + + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.75 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + +class TestDropoutOp6(TestDropoutOp): + + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "downgrade_in_infer" + + +class TestDropoutOpWithSeed(TestDropoutOp): + # the seed is a Tensor def setUp(self): self.op_type = "dropout" self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((32, 64, 2)).astype(self.dtype)} + self.dtype = np.float32 + self.inputs = { + "X": np.random.random((32, 64)).astype(self.dtype), + "Seed": np.asarray([125], dtype="int32") + } self.attrs = { 'dropout_prob': 0.0, - 'fix_seed': True, 'is_test': False, 'dropout_implementation': 'upscale_in_train' } self.outputs = { 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64, 2)).astype('uint8') + 'Mask': np.ones((32, 64)).astype('uint8') } + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestDropoutOpFp16(TestDropoutOp): + # float16 + def init_dtype(self): + self.dtype = np.float16 + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + self.__class__.no_need_check_grad = True + @skip_check_grad_ci(reason="For inference, check_grad is not required.") class TestDropoutOpInference(OpTest): @@ -179,38 +225,6 @@ class TestDropoutOpInference2(TestDropoutOpInference): self.outputs = {'Out': self.inputs['X']} -class TestDropoutOpWithSeed(TestDropoutOp): - # the seed is a Tensor - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = { - "X": np.random.random((32, 64)).astype(self.dtype), - "Seed": np.asarray([125], dtype="int32") - } - self.attrs = { - 'dropout_prob': 0.0, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64)).astype('uint8') - } - - -class TestDropoutOpFp16(TestDropoutOp): - # float16 - def init_dtype(self): - self.dtype = np.float16 - - def set_mlu(self): - self.__class__.use_mlu = True - self.place = paddle.device.MLUPlace(0) - self.__class__.no_need_check_grad = True - - class TestDropoutAPI(unittest.TestCase): def setUp(self): -- GitLab