From b9e6b94d01e082170ea5e5873ac3d42e612f4294 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 20 Oct 2022 09:48:39 +0800 Subject: [PATCH] fix sparse inplace (#47167) --- paddle/phi/api/yaml/generator/sparse_api_gen.py | 2 +- paddle/phi/api/yaml/sparse_backward.yaml | 2 +- paddle/phi/api/yaml/sparse_ops.yaml | 4 +++- python/paddle/incubate/sparse/nn/layer/norm.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/yaml/generator/sparse_api_gen.py b/paddle/phi/api/yaml/generator/sparse_api_gen.py index 4236862e347..f013fed1b3c 100644 --- a/paddle/phi/api/yaml/generator/sparse_api_gen.py +++ b/paddle/phi/api/yaml/generator/sparse_api_gen.py @@ -238,7 +238,7 @@ class SparseAPI(ForwardAPI): kernel_name, inplace_flag) return f""" -PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{ +PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{ {kernel_dispatch_code} PADDLE_THROW(phi::errors::Unimplemented( "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors.")); diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index c57a3cae892..ffb5406436f 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -368,7 +368,7 @@ subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} - backward_op : sync_batch_norm_grad - forward : sync_batch_norm(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + forward : sync_batch_norm_(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) infer_meta : diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 7921f423d85..a7111e5dee3 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -95,6 +95,7 @@ kernel : func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense} data_type : x + view : (mean -> mean_out), (variance -> variance_out) backward : batch_norm_grad - op : cast @@ -480,7 +481,7 @@ layout : x backward : transpose_grad -- op : sync_batch_norm +- op : sync_batch_norm_ args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) infer_meta : @@ -489,6 +490,7 @@ func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense} data_type : x backward : sync_batch_norm_grad + inplace : (mean -> mean_out), (variance -> variance_out) - op : reshape args : (Tensor x, IntArray shape) diff --git a/python/paddle/incubate/sparse/nn/layer/norm.py b/python/paddle/incubate/sparse/nn/layer/norm.py index 440a0895395..b84db710d73 100644 --- a/python/paddle/incubate/sparse/nn/layer/norm.py +++ b/python/paddle/incubate/sparse/nn/layer/norm.py @@ -297,7 +297,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm): def forward(self, x): self._check_data_format() - sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm( + sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm_( x, self.weight, self.bias, self._mean, self._variance, self._momentum, self._epsilon, self._data_format, not self.training, False, False, False) -- GitLab