diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py index 48ba0d81eca3d621fb3e39c317a93da675544da1..22c6ef0866b346d51166be9cdf2c7c71bbb152e2 100644 --- a/paddle/phi/api/yaml/generator/generate_sparse_op.py +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -84,6 +84,8 @@ def main( backward_api_dict = to_named_dict(backward_apis) for api in apis: + if api['name'][-1] == '_': + api['name'] = api['name'][:-1] api['op_name'] = SPARSE_OP_PREFIX + api['name'] api['name'] = api['op_name'] if api["backward"] is not None: diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 6503dbb46e85765c57945d2e02e9c8b707f3f233..d8250867398613c5da4593b1195a538428313c73 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -101,7 +101,7 @@ atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} - backward_op : batch_norm_grad - forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + forward : batch_norm_ (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) infer_meta : diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 015e7aef0ff0b01f8655f059671eb9f29d19aea4..c3b39e5dc3d1fab8d11ea6c8d3fc78e36ad4e141 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -87,7 +87,7 @@ layout : x backward : atanh_grad -- op : batch_norm +- op : batch_norm_ args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) infer_meta : @@ -95,7 +95,7 @@ kernel : func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense} data_type : x - view : (mean -> mean_out), (variance -> variance_out) + inplace : (mean -> mean_out), (variance -> variance_out) backward : batch_norm_grad - op : cast diff --git a/paddle/phi/kernels/sparse/batch_norm_kernel.h b/paddle/phi/kernels/sparse/batch_norm_kernel.h index 282a8de7b39d4cba940fc1dda140fffe5ce2b376..a94a1203c8cdc19870b2440e8772bc9853990d83 100644 --- a/paddle/phi/kernels/sparse/batch_norm_kernel.h +++ b/paddle/phi/kernels/sparse/batch_norm_kernel.h @@ -23,25 +23,25 @@ namespace phi { namespace sparse { template -void BatchNormKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& scale, - const DenseTensor& bias, - const DenseTensor& mean, - const DenseTensor& variance, - float momentum, - float epsilon, - const std::string& data_layout, - bool is_test, - bool use_global_stats, - bool trainable_statistics, - bool fuse_with_relu, - SparseCooTensor* y, - DenseTensor* mean_out, - DenseTensor* variance_out, - DenseTensor* saved_mean, - DenseTensor* saved_variance, - DenseTensor* reserve_space); +void BatchNormCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const DenseTensor& mean, + const DenseTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu, + SparseCooTensor* y, + DenseTensor* mean_out, + DenseTensor* variance_out, + DenseTensor* saved_mean, + DenseTensor* saved_variance, + DenseTensor* reserve_space); } // namespace sparse } // namespace phi diff --git a/python/paddle/sparse/nn/layer/norm.py b/python/paddle/sparse/nn/layer/norm.py index 34ed96f9e434cec316fc3a27cd2eae5eff686464..987e1835ecd0377d15501dab9a7dbe66b88b0dae 100644 --- a/python/paddle/sparse/nn/layer/norm.py +++ b/python/paddle/sparse/nn/layer/norm.py @@ -138,7 +138,7 @@ class BatchNorm(paddle.nn.BatchNorm1D): data_format = 'NCHW' if self._data_format[1] == 'C' else 'NHWC' if in_dynamic_mode(): - batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm( + batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm_( input, self.weight, self.bias,