diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 6736c379dfcdd107c8a836f947f2767eaca6be88..ddb2863cf48a0f59deb7bbfc5cdd0425070063a5 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -115,7 +115,7 @@ func : GeneralTernaryGradInferMeta param : [x, scale, x] kernel : - func : batch_norm_grad_grad + func : batch_norm_double_grad data_type : x optional : out_mean, out_variance, grad_x_grad, grad_scale_grad, grad_bias_grad inplace : (grad_out -> grad_out_grad) diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index c2da486e9f7521b21f2953977738d52c2e5ddf87..9eec65e92a38f36b866a56438576004f43955528 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -662,7 +662,7 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw, float, double) {} -PD_REGISTER_KERNEL(batch_norm_grad_grad, +PD_REGISTER_KERNEL(batch_norm_double_grad, CPU, ALL_LAYOUT, phi::BatchNormDoubleGradKernel, diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 3b09890e2242282fc20d0c5a5e9916ba4b65bd81..ede2458744902695974f552c87c86336b2034ef2 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -1345,14 +1345,14 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw, #endif #ifdef PADDLE_WITH_HIP -PD_REGISTER_KERNEL(batch_norm_grad_grad, +PD_REGISTER_KERNEL(batch_norm_double_grad, GPU, ALL_LAYOUT, phi::BatchNormDoubleGradKernel, float, double) {} #else -PD_REGISTER_KERNEL(batch_norm_grad_grad, +PD_REGISTER_KERNEL(batch_norm_double_grad, GPU, ALL_LAYOUT, phi::BatchNormDoubleGradKernel, diff --git a/paddle/phi/ops/compat/batch_norm_sig.cc b/paddle/phi/ops/compat/batch_norm_sig.cc index ff7a582142513a3bc56227ed22f60538fb0843b1..5f6efcd9ce769ff9332866d4fcb7a672743c50fd 100644 --- a/paddle/phi/ops/compat/batch_norm_sig.cc +++ b/paddle/phi/ops/compat/batch_norm_sig.cc @@ -79,7 +79,7 @@ KernelSignature BatchNormGradOpArgumentMapping( KernelSignature BatchNormGradGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("batch_norm_grad_grad", + return KernelSignature("batch_norm_double_grad", {"X", "Scale", "Mean",