Unverified commit d33d6db0, authored by zhangkaihuo, committed by GitHub

[Sparse]BatchNorm use inplace (#48254)

Parent: 41ba2722
@@ -82,6 +82,8 @@ def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
     backward_op_dict = to_named_dict(backward_ops)
 
     for op in ops:
+        if op['name'][-1] == '_':
+            op['name'] = op['name'][:-1]
         op['op_name'] = SPARSE_OP_PREFIX + op['name']
         op['name'] = op['op_name']
         if op["backward"] is not None:
...
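The two added lines normalize inplace-style op names before the sparse prefix is applied, so 'batch_norm_' and 'batch_norm' map to the same prefixed base name. A minimal, self-contained sketch of just this step, with SPARSE_OP_PREFIX stubbed to the value the generator's naming implies and toy op dicts standing in for the parsed YAML (the YAML loading and backward handling around it are omitted):

SPARSE_OP_PREFIX = 'sparse_'  # assumed value; defined elsewhere in the generator

ops = [
    {'name': 'batch_norm_', 'backward': 'batch_norm_grad'},  # inplace-style name
    {'name': 'cast', 'backward': 'cast_grad'},
]

for op in ops:
    # Strip the trailing '_' marking an inplace op before prefixing.
    if op['name'][-1] == '_':
        op['name'] = op['name'][:-1]
    op['op_name'] = SPARSE_OP_PREFIX + op['name']
    op['name'] = op['op_name']

print([op['name'] for op in ops])  # ['sparse_batch_norm', 'sparse_cast']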
@@ -101,7 +101,7 @@
     atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
 - backward_op : batch_norm_grad
-  forward : batch_norm (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
+  forward : batch_norm_ (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics)
   output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
   infer_meta :
...
@@ -87,7 +87,7 @@
   layout : x
   backward : atanh_grad
 
-- op : batch_norm
+- op : batch_norm_
   args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics)
   output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   infer_meta :
@@ -95,7 +95,7 @@
   kernel :
     func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
     data_type : x
-  view : (mean -> mean_out), (variance -> variance_out)
+  inplace : (mean -> mean_out), (variance -> variance_out)
   backward : batch_norm_grad
 
 - op : cast
...
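The substantive change in this file is view -> inplace on the running statistics: instead of exposing mean_out/variance_out as views of the inputs, the op now writes the updated running mean and variance back into the buffers it receives. A toy NumPy illustration of that contract (not Paddle code; the usual momentum-weighted running-average update is assumed):

import numpy as np

def update_running_stats_(x, running_mean, running_var, momentum=0.9):
    # Inplace contract: the incoming buffers double as mean_out/variance_out.
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    running_mean *= momentum
    running_mean += (1.0 - momentum) * batch_mean
    running_var *= momentum
    running_var += (1.0 - momentum) * batch_var
    return running_mean, running_var

x = np.random.randn(8, 4).astype('float32')
mean = np.zeros(4, dtype='float32')
var = np.ones(4, dtype='float32')
mean_out, var_out = update_running_stats_(x, mean, var)
assert mean_out is mean and var_out is var  # same storage, no extra copy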
@@ -23,24 +23,24 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void BatchNormKernel(const Context& dev_ctx,
-                     const SparseCooTensor& x,
-                     const DenseTensor& scale,
-                     const DenseTensor& bias,
-                     const DenseTensor& mean,
-                     const DenseTensor& variance,
-                     float momentum,
-                     float epsilon,
-                     const std::string& data_layout,
-                     bool is_test,
-                     bool use_global_stats,
-                     bool trainable_statistics,
-                     SparseCooTensor* y,
-                     DenseTensor* mean_out,
-                     DenseTensor* variance_out,
-                     DenseTensor* saved_mean,
-                     DenseTensor* saved_variance,
-                     DenseTensor* reserve_space);
+void BatchNormCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        const DenseTensor& mean,
+                        const DenseTensor& variance,
+                        const DenseTensor& scale,
+                        const DenseTensor& bias,
+                        bool is_test,
+                        float momentum,
+                        float epsilon,
+                        const std::string& data_layout,
+                        bool use_global_stats,
+                        bool trainable_statistics,
+                        SparseCooTensor* y,
+                        DenseTensor* mean_out,
+                        DenseTensor* variance_out,
+                        DenseTensor* saved_mean,
+                        DenseTensor* saved_variance,
+                        DenseTensor* reserve_space);
 
 }  // namespace sparse
 }  // namespace phi
@@ -138,7 +138,7 @@ class BatchNorm(paddle.nn.BatchNorm1D):
         data_format = 'NCHW' if self._data_format[1] == 'C' else 'NHWC'
 
         if in_dynamic_mode():
-            batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm(
+            batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm_(
                 input,
                 self._mean,
                 self._variance,
...
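For context, a hedged usage sketch of the layer this diff touches. It assumes a Paddle build that exposes the sparse API as paddle.sparse (as in recent releases) and follows the layer's channels-last (NDHWC) convention; with this commit, the call dispatches to _C_ops.sparse_batch_norm_ and updates the layer's running statistics in place:

import paddle

dense_x = paddle.randn([1, 6, 6, 6, 3])  # NDHWC, channels last
sparse_x = dense_x.to_sparse_coo(4)      # first 4 dims sparse, channel dim kept dense
bn = paddle.sparse.nn.BatchNorm(3)       # num_features = 3 channels
out = bn(sparse_x)
print(out.shape)                         # [1, 6, 6, 6, 3]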