From dddc5d9d10317ff90f8f6c3ab48f6ee7a3a1a919 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 31 Jan 2023 15:02:36 +0800 Subject: [PATCH] [cherry-pick]BatchNorm use inplace (#49529) att, cherry-pick#48254, and resolve conflict --- .../api/yaml/generator/generate_sparse_op.py | 2 + paddle/phi/api/yaml/sparse_backward.yaml | 2 +- paddle/phi/api/yaml/sparse_ops.yaml | 4 +- paddle/phi/kernels/sparse/batch_norm_kernel.h | 38 +++++++++---------- python/paddle/sparse/nn/layer/norm.py | 2 +- 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py index 48ba0d81eca..22c6ef0866b 100644 --- a/paddle/phi/api/yaml/generator/generate_sparse_op.py +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -84,6 +84,8 @@ def main( backward_api_dict = to_named_dict(backward_apis) for api in apis: + if api['name'][-1] == '_': + api['name'] = api['name'][:-1] api['op_name'] = SPARSE_OP_PREFIX + api['name'] api['name'] = api['op_name'] if api["backward"] is not None: diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 6503dbb46e8..d8250867398 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -101,7 +101,7 @@ atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} - backward_op : batch_norm_grad - forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + forward : batch_norm_ (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) infer_meta : diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 015e7aef0ff..c3b39e5dc3d 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -87,7 +87,7 @@ layout : x backward : atanh_grad -- op : batch_norm +- op : batch_norm_ args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) infer_meta : @@ -95,7 +95,7 @@ kernel : func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense} data_type : x - view : (mean -> mean_out), (variance -> variance_out) + inplace : (mean -> mean_out), (variance -> variance_out) backward : batch_norm_grad - op : cast diff --git a/paddle/phi/kernels/sparse/batch_norm_kernel.h b/paddle/phi/kernels/sparse/batch_norm_kernel.h index 282a8de7b39..a94a1203c8c 100644 --- a/paddle/phi/kernels/sparse/batch_norm_kernel.h +++ b/paddle/phi/kernels/sparse/batch_norm_kernel.h @@ -23,25 +23,25 @@ namespace phi { namespace sparse { template -void BatchNormKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, - const DenseTensor& mean, - const DenseTensor& variance, - float momentum, - float epsilon, - const std::string& data_layout, - bool is_test, - bool use_global_stats, - bool trainable_statistics, - bool fuse_with_relu, - SparseCooTensor* y, - DenseTensor* mean_out, - DenseTensor* variance_out, - DenseTensor* saved_mean, - DenseTensor* saved_variance, - DenseTensor* reserve_space); +void BatchNormCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const DenseTensor& mean, + const DenseTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu, + SparseCooTensor* y, + DenseTensor* mean_out, + DenseTensor* variance_out, + DenseTensor* saved_mean, + DenseTensor* saved_variance, + DenseTensor* reserve_space); } // namespace sparse } // namespace phi diff --git a/python/paddle/sparse/nn/layer/norm.py b/python/paddle/sparse/nn/layer/norm.py index 34ed96f9e43..987e1835ecd 100644 --- a/python/paddle/sparse/nn/layer/norm.py +++ b/python/paddle/sparse/nn/layer/norm.py @@ -138,7 +138,7 @@ class BatchNorm(paddle.nn.BatchNorm1D): data_format = 'NCHW' if self._data_format[1] == 'C' else 'NHWC' if in_dynamic_mode(): - batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm( + batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm_( input, self.weight, self.bias, -- GitLab