diff --git a/paddle/phi/api/yaml/generator/sparse_api_gen.py b/paddle/phi/api/yaml/generator/sparse_api_gen.py
index 4236862e347b320d7f3fe3ff80f5720a16f014a5..f013fed1b3ce1c0fb7720791b29e8f2e80bfa3d6 100644
--- a/paddle/phi/api/yaml/generator/sparse_api_gen.py
+++ b/paddle/phi/api/yaml/generator/sparse_api_gen.py
@@ -238,7 +238,7 @@ class SparseAPI(ForwardAPI):
                 kernel_name, inplace_flag)
 
         return f"""
-PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{
+PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
 {kernel_dispatch_code}
   PADDLE_THROW(phi::errors::Unimplemented(
       "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml
index c57a3cae8927f3e10d4874e0a77518b0a371ede5..ffb5406436faa2f37eb0c6108366594b778ddadc 100644
--- a/paddle/phi/api/yaml/sparse_backward.yaml
+++ b/paddle/phi/api/yaml/sparse_backward.yaml
@@ -368,7 +368,7 @@
            subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
 - backward_op : sync_batch_norm_grad
-  forward : sync_batch_norm(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
+  forward : sync_batch_norm_(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
   output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
   infer_meta :
diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml
index 7921f423d859d6a2bd7f669dcd55d9e5ce371e00..a7111e5dee3c6b70b2127e9ee1330b6be95a957b 100644
--- a/paddle/phi/api/yaml/sparse_ops.yaml
+++ b/paddle/phi/api/yaml/sparse_ops.yaml
@@ -95,6 +95,7 @@
   kernel :
     func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
     data_type : x
+  view : (mean -> mean_out), (variance -> variance_out)
   backward : batch_norm_grad
 
 - op : cast
@@ -480,7 +481,7 @@
     layout : x
   backward : transpose_grad
 
-- op : sync_batch_norm
+- op : sync_batch_norm_
   args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
   output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   infer_meta :
@@ -489,6 +490,7 @@
     func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
     data_type : x
   backward : sync_batch_norm_grad
+  inplace : (mean -> mean_out), (variance -> variance_out)
 
 - op : reshape
   args : (Tensor x, IntArray shape)
diff --git a/python/paddle/incubate/sparse/nn/layer/norm.py b/python/paddle/incubate/sparse/nn/layer/norm.py
index 440a08953951f2fe6ec1f8e1cc65a6f7eca7cf21..b84db710d732e11aae486942d4fb8e96f7d38f5a 100644
--- a/python/paddle/incubate/sparse/nn/layer/norm.py
+++ b/python/paddle/incubate/sparse/nn/layer/norm.py
@@ -297,7 +297,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
 
     def forward(self, x):
         self._check_data_format()
-        sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm(
+        sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm_(
             x, self.weight, self.bias, self._mean, self._variance,
             self._momentum, self._epsilon, self._data_format,
             not self.training, False, False, False)
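
This patch makes the sparse sync batch norm op in-place over its running statistics: the op is renamed to `sync_batch_norm_`, the YAML entry gains `inplace : (mean -> mean_out), (variance -> variance_out)` (with the corresponding `view` mapping on `batch_norm`), and the sparse API generator is fixed to forward `inplace_flag` into `get_return_type`/`get_define_args` so the generated `PADDLE_API` signature matches the in-place variant. A minimal usage sketch follows; it is not part of the patch, assumes a CUDA build of Paddle at this commit (the sync batch norm kernel is GPU-only), and the shapes and the `mean_before` name are illustrative only.

import paddle
import paddle.incubate.sparse as sparse

if paddle.is_compiled_with_cuda():
    # Channel-last dense input, converted to a sparse COO tensor whose
    # trailing (channel) dimension stays dense, as the sparse layers expect.
    dense_x = paddle.randn([1, 2, 2, 3])
    sparse_x = dense_x.to_sparse_coo(3)

    layer = sparse.nn.SyncBatchNorm(3)
    mean_before = layer._mean.clone()

    out = layer(sparse_x)  # forward() now calls _C_ops.sparse_sync_batch_norm_(...)

    # With the inplace mapping, the kernel writes mean_out/variance_out back
    # into the tensors that were passed as mean/variance, so layer._mean is
    # mutated in place rather than returned as a fresh tensor.
    print(paddle.allclose(mean_before, layer._mean))  # expected: False in training mode

Since the Python layer discards every output except `sync_batch_norm_out`, the in-place mapping is what lets the updated running statistics survive in `self._mean` and `self._variance` across training steps.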