diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 95d7f23b2c0ac6e46cf85bef4340eb4180dc3dba..4b4fcd487ceda989491b1917bd54e9556a899b37 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -418,7 +418,7 @@ class BatchNormGradKernel } } else { if (d_x) { - BNBackwardData<<< + BNBackwardData<<< grid2, block, 0, dev_ctx.stream()>>>( d_y->data(), scale->data>(), saved_mean_data, x->data(), saved_var_data, C, N, H * W * D, diff --git a/paddle/fluid/operators/instance_norm_op.cu b/paddle/fluid/operators/instance_norm_op.cu index c0609509037786f9a90edbfe2634869d154aab31..4c04f6c315b1319b80b7db10b58a746f0491eeae 100644 --- a/paddle/fluid/operators/instance_norm_op.cu +++ b/paddle/fluid/operators/instance_norm_op.cu @@ -328,7 +328,7 @@ class InstanceNormGradKernel epsilon, saved_mean_data, saved_var_data)); } else { if (d_x) { - GradComputeDX<<>>( + GradComputeDX<<>>( d_y->data(), scale->data>(), saved_mean_data, x->data(), saved_var_data, C, H * W * D, d_x->data()); diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 860e7092f6958c477668695bf28aec03d876aa9f..3eca39c2b180439a3ea1647b616482198352f9d6 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1247,9 +1247,13 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): target = targets[i] if grad is None: grad_name = _append_grad_suffix_(target.name) - target_shape = paddle.fluid.layers.shape(target) + target_shape = target.name + '_shape' + block.desc.append_op().copy_from( + _create_op_desc_("shape", {'Input': [target.name]}, + {"Out": [target_shape]}, {})) + input_grad_names_set.add(target_shape) op_desc = _create_op_desc_("fill_constant", - {"ShapeTensor": [target_shape.name]}, + {"ShapeTensor": [target_shape]}, {"Out": [grad_name]}, { "shape": target.shape, "value": 1.0, diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py index ccdf12849c7ee61fad5916aea0403b760a0302db..c02e48bd715b06fcf3950881447d3189377edf05 100644 --- a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py @@ -79,7 +79,7 @@ class TestInstanceNormOpTraining(unittest.TestCase): self.init_test_case() def init_test_case(self): - self.use_global_stats = False + self.shape = [2, 3, 4, 5] self.no_grad_set = set() self.fetch_list = [ 'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD', @@ -181,12 +181,19 @@ class TestInstanceNormOpTraining(unittest.TestCase): "instance_norm"): places.append(core.CUDAPlace(0)) for place in places: - test_with_place(place, [2, 3, 4, 5]) + test_with_place(place, self.shape) class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining): def init_test_case(self): - self.use_global_stats = False + self.shape = [2, 3, 4, 5] + self.no_grad_set = set(['scale@GRAD', 'bias@GRAD']) + self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD'] + + +class TestInstanceNormOpTrainingCase2(TestInstanceNormOpTraining): + def init_test_case(self): + self.shape = [20, 50, 4, 5] self.no_grad_set = set(['scale@GRAD', 'bias@GRAD']) self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']