diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
index 57681be58ae47fca118acc4ca4232f27260585ea..87b2ff986dc922de1a5c663b5dda2fc3f509e8fe 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
@@ -31,7 +31,8 @@ ops_to_fill_zero_for_empty_grads = set([
     "leaky_relu_double_grad", "sqrt_double_grad", "rsqrt_double_grad",
     "square_double_grad", "celu_double_grad", "pad_double_grad",
     "pad3d_double_grad", "squeeze_double_grad", "unsqueeze_double_grad",
-    "conv3d_double_grad", "depthwise_conv2d_grad_grad"
+    "instance_norm_double_grad", "conv3d_double_grad",
+    "depthwise_conv2d_grad_grad"
 ])
 
 # For API dispatch used at python-level
diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index d8b909c3bacc1bf4a1e54ea7d8b5e1bf39dbe1e0..d23d71b07626d035d6b8ef686d9436783a98f214 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -1404,7 +1404,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
   const auto& out_metas = OutputMeta();
   paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs});
   for (int i = 0; i < {slot_num_bwd_outputs}; ++i) {{
-    returns[i].resize(out_metas[i].size());
+    out_metas[i].size() == 0 ? returns[i].resize(1) : returns[i].resize(out_metas[i].size());
   }}
 """
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 521eb03fd770fc6cf7d11b9c03921811c084dc28..f59ea5549bd71409743176c471d7f633f62b7ca7 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -313,10 +313,10 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
 }
 
 void InstanceNormGradInferMeta(const MetaTensor& x,
-                               const MetaTensor& y_grad,
                                const MetaTensor& scale,
                                const MetaTensor& saved_mean,
                                const MetaTensor& saved_variance,
+                               const MetaTensor& y_grad,
                                float epsilon,
                                MetaTensor* x_grad,
                                MetaTensor* scale_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 93e2d4c43bc3fcc9479f3de4c35888420389676a..0e7ed640d8ffb55b622084bbb44cfbbe0879ba51 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -145,10 +145,10 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                 MetaTensor* dx);
 
 void InstanceNormGradInferMeta(const MetaTensor& x,
-                               const MetaTensor& y_grad,
                                const MetaTensor& scale,
                                const MetaTensor& saved_mean,
                                const MetaTensor& saved_variance,
+                               const MetaTensor& y_grad,
                                float epsilon,
                                MetaTensor* x_grad,
                                MetaTensor* scale_grad,
diff --git a/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc
index 340d2907a790999344ae88f13eddec5ce53c0274..867d43fd833de5caf97295d7ee99f4c6de2c6474 100644
--- a/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc
@@ -42,10 +42,10 @@ using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
 template <typename T, typename Context>
 void InstanceNormGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
-                            const DenseTensor& d_y,
                             const paddle::optional<DenseTensor>& scale,
                             const DenseTensor& saved_mean,
                             const DenseTensor& saved_variance,
+                            const DenseTensor& d_y,
                             float epsilon,
                             DenseTensor* d_x,
                             DenseTensor* d_scale,
diff --git a/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
index b72acc7073383d8926e7a5f6111a2e0a6eace1a8..b2c2df2d3f0559e1017058ce60e94332f79ac7d7 100644
--- a/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
@@ -290,10 +290,10 @@ __global__ void DoubleGradComputeDScale(const T *x,
 template <typename T, typename Context>
 void InstanceNormGradKernel(const Context &dev_ctx,
                             const DenseTensor &x,
-                            const DenseTensor &d_y,
                             const paddle::optional<DenseTensor> &scale,
                             const DenseTensor &saved_mean,
                             const DenseTensor &saved_variance,
+                            const DenseTensor &d_y,
                             float epsilon_f,
                             DenseTensor *d_x,
                             DenseTensor *d_scale,
diff --git a/paddle/phi/kernels/instance_norm_grad_kernel.h b/paddle/phi/kernels/instance_norm_grad_kernel.h
index be7e4ce3e34880ce3a0237b8913ba98ad0e370e6..2a661a3fd3853befe12fdb64b15c8164271d59da 100644
--- a/paddle/phi/kernels/instance_norm_grad_kernel.h
+++ b/paddle/phi/kernels/instance_norm_grad_kernel.h
@@ -21,10 +21,10 @@ namespace phi {
 template <typename T, typename Context>
 void InstanceNormGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
-                            const DenseTensor& y_grad,
                             const paddle::optional<DenseTensor>& scale,
                             const DenseTensor& saved_mean,
                             const DenseTensor& saved_variance,
+                            const DenseTensor& y_grad,
                             float epsilon,
                             DenseTensor* x_grad,
                             DenseTensor* scale_grad,
diff --git a/paddle/phi/ops/compat/instance_norm_sig.cc b/paddle/phi/ops/compat/instance_norm_sig.cc
index 2b490078512b1ef6c3574b08f15e55594c578435..6ccf12097988791d0c50dcc62afcc263377311ca 100644
--- a/paddle/phi/ops/compat/instance_norm_sig.cc
+++ b/paddle/phi/ops/compat/instance_norm_sig.cc
@@ -27,7 +27,7 @@ KernelSignature InstanceNormOpArgumentMapping(
 KernelSignature InstanceNormGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("instance_norm_grad",
-                         {"X", "Y@GRAD", "Scale", "SavedMean", "SavedVariance"},
+                         {"X", "Scale", "SavedMean", "SavedVariance", "Y@GRAD"},
                          {"epsilon"},
                          {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
 }
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 4d985097088f852d064344f49d0d6decb22f81f4..72114a275156d63981bff350c73895d12f3c174b 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -1137,7 +1137,11 @@ class InstanceNorm(layers.Layer):
             self.bias = None
 
     def forward(self, input):
-        if _non_static_mode():
+        if in_dygraph_mode():
+            out, _, _, = _C_ops.final_state_instance_norm(
+                input, self.scale, self.bias, self._epsilon)
+            return out
+        if _in_legacy_dygraph():
             out, _, _ = _C_ops.instance_norm(input, self.scale, self.bias,
                                              'epsilon', self._epsilon)
             return out
diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py
index aa184dd42e6fcd8d767a97b63a5d2de7deb47c9c..23c514334765da222b275c230e80156afe977a7a 100644
--- a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py
@@ -22,6 +22,7 @@ from paddle.fluid.op import Operator
 from op_test import OpTest
 from paddle.fluid import Program, program_guard
 from paddle.fluid.dygraph import to_variable
+from paddle.fluid.framework import _test_eager_guard
 
 
 def _reference_instance_norm_naive(x, scale, bias, epsilon, mean, var):
@@ -253,6 +254,10 @@ class TestElasticNormOp(unittest.TestCase):
                 outputs = instance_norm(to_variable(inputs))
                 self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6))
 
+    def test_eager_api(self):
+        with _test_eager_guard():
+            self.test_norm()
+
 
 class TestElasticNormOpCase2(unittest.TestCase):
     def init_test_case(self):
@@ -282,6 +287,10 @@ class TestElasticNormOpCase2(unittest.TestCase):
                 outputs = instance_norm(to_variable(inputs))
                 self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6))
 
+    def test_eager_api(self):
+        with _test_eager_guard():
+            self.test_norm()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
index 102e08e36a9e571db1f4dc0b77c5b2d90c6bb792..1656bc11869fd1f8da4008f35249856bda2f4b7a 100644
--- a/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
@@ -22,6 +22,7 @@ from op_test import OpTest, _set_use_system_allocator
 from paddle.fluid.framework import grad_var_name
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard
 import paddle
 
 
@@ -116,6 +117,11 @@ class TestInstanceNorm(unittest.TestCase):
             y2 = compute_v2(x)
             self.assertTrue(np.allclose(y1, y2))
 
+    def test_eager_api(self):
+        with _test_eager_guard():
+            self.test_dygraph()
+            self.test_error()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
index 1452b869d4f8b8569d211c083d1c3afc736d0b29..13c2edbf37cf7d7d216e811eef8ac1b76966b207 100644
--- a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
@@ -70,6 +70,72 @@ class TestInstanceNormDoubleGradCheckWithoutParamBias(
             [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
 
 
+class TestInstanceNormDoubleGradEagerCheck(unittest.TestCase):
+    def instance_norm_wrapper(self, x):
+        return paddle.nn.functional.instance_norm(x[0])
+
+    @prog_scope()
+    def func(self, place):
+        prog = fluid.Program()
+        with fluid.program_guard(prog):
+            np.random.seed()
+            shape = [2, 3, 4, 5]
+            dtype = "float32"
+            eps = 0.005
+            atol = 1e-4
+            x = layers.create_parameter(dtype=dtype, shape=shape, name='x')
+            z = paddle.nn.functional.instance_norm(x)
+            x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+            # check for static mode
+            gradient_checker.double_grad_check(
+                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
+            # check for eager mode
+            gradient_checker.double_grad_check_for_dygraph(
+                self.instance_norm_wrapper, [x],
+                z,
+                x_init=x_arr,
+                atol=atol,
+                place=place)
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestInstanceNormDoubleGradEagerCheckWithParams(
+        TestInstanceNormDoubleGradEagerCheck):
+    def instance_norm_wrapper(self, x):
+        instance_norm = paddle.nn.InstanceNorm2D(3)
+        return instance_norm(x[0])
+
+    @prog_scope()
+    def func(self, place):
+        prog = fluid.Program()
+        with fluid.program_guard(prog):
+            np.random.seed()
+            shape = [2, 3, 4, 5]
+            dtype = "float32"
+            eps = 0.005
+            atol = 1e-4
+            x = layers.create_parameter(dtype=dtype, shape=shape, name='x')
+            z = paddle.nn.InstanceNorm2D(3)(x)
+            x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+            # check for static mode
+            gradient_checker.double_grad_check(
+                [x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
+            # check for eager mode
+            gradient_checker.double_grad_check_for_dygraph(
+                self.instance_norm_wrapper, [x],
+                z,
+                x_init=x_arr,
+                atol=atol,
+                place=place)
+
+
 class TestBatchNormDoubleGradCheck(unittest.TestCase):
     def setUp(self):
         self.init_test()
diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index e719099b4b39da51406ba9a5305d2f8f522404d7..f64e731342ed2ba587a76af0cb9145e091ac20f9 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -407,8 +407,10 @@ def instance_norm(x,
           print(instance_norm_out)
 
     """
-
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        out, _, _, = _C_ops.final_state_instance_norm(x, weight, bias, eps)
+        return out
+    if _in_legacy_dygraph():
         out, _, _ = _C_ops.instance_norm(x, weight, bias, "epsilon", eps,
                                          "momentum", momentum, "data_format",
                                          data_format)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 44865940adb44df8ae5a5681e66f89029fb779eb..8ed4832a8f751701383044cec86e79280e7dcfc2 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -1030,6 +1030,17 @@
     data_type : x
   backward : index_select_grad
 
+- api : instance_norm
+  args : (Tensor x, Tensor scale, Tensor bias, float epsilon)
+  output : Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
+  infer_meta :
+    func : InstanceNormInferMeta
+  kernel :
+    func : instance_norm
+    data_type : x
+  optional : scale, bias
+  backward : instance_norm_grad
+
 # is_empty
 - api : is_empty
   args : (Tensor x)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index d6c148e6ca925815adf8efaf69666d424820f65d..6a555fd24a0665fb26f59d789ce00b64a05b9ac6 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -927,6 +927,29 @@
     data_type : x
   no_need_buffer : x
 
+- backward_api : instance_norm_double_grad
+  forward : instance_norm_grad(Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, float epsilon) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
+  args : (Tensor x, Tensor fwd_scale, Tensor saved_mean, Tensor saved_variance, Tensor grad_y, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float epsilon)
+  output : Tensor(x_grad), Tensor(fwd_scale_grad), Tensor(grad_y_grad)
+  infer_meta :
+    func : InstanceNormDoubleGradInferMeta
+  kernel :
+    func : instance_norm_double_grad
+    data_type : x
+  optional : fwd_scale, grad_x_grad, grad_scale_grad, grad_bias_grad
+
+- backward_api : instance_norm_grad
+  forward : instance_norm(Tensor x, Tensor scale, Tensor bias, float epsilon) -> Tensor(y), Tensor(saved_mean), Tensor(saved_variance)
+  args : (Tensor x, Tensor scale, Tensor saved_mean, Tensor saved_variance, Tensor y_grad, float epsilon)
+  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
+  infer_meta :
+    func : InstanceNormGradInferMeta
+  kernel :
+    func : instance_norm_grad
+    data_type : x
+  optional : scale
+  backward : instance_norm_double_grad
+
 - backward_api : kldiv_loss_grad
   forward : kldiv_loss(Tensor x, Tensor label, str reduction) -> Tensor(out)
   args : (Tensor x, Tensor label, Tensor out_grad, str reduction)
diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json
index 2502e248c5c48310dbf979d0339c0f8c97b321b3..75533311513e5a429a9d4bbee1243f62e6a87296 100644
--- a/tools/infrt/skipped_phi_api.json
+++ b/tools/infrt/skipped_phi_api.json
@@ -1,4 +1,4 @@
 {
-"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm"],
+"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm", "instance_norm"],
 "phi_kernels":["equal_all"]
 }