Unverified · Commit b9e6b94d · Authored by zhangkaihuo, committed by GitHub

fix sparse inplace (#47167)

Parent b9c8c1b1
@@ -238,7 +238,7 @@ class SparseAPI(ForwardAPI):
                 kernel_name, inplace_flag)
         return f"""
-PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{
+PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
 {kernel_dispatch_code}
   PADDLE_THROW(phi::errors::Unimplemented(
       "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
...
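Context for the generator change above: the f-string previously called `get_return_type()` and `get_define_args()` without the flag, so the inplace variant of a sparse API was emitted with the out-of-place signature even though kernel dispatch already knew about `inplace_flag`. A minimal sketch of the pattern, with simplified illustrative names (this is not the actual Paddle generator):

```python
# Sketch only: shows why inplace_flag must reach both the return type and the
# argument list, not just the dispatch code. All names here are illustrative.
class SparseApiGen:
    def __init__(self, name, inputs):
        self.name = name
        self.inputs = inputs  # e.g. ["x", "mean", "variance"]

    def get_return_type(self, inplace_flag=False):
        # An inplace API returns a reference to the tensor it mutated.
        return "Tensor&" if inplace_flag else "Tensor"

    def get_define_args(self, inplace_flag=False):
        # Inplace inputs must be mutable references, not const references.
        qual = "Tensor&" if inplace_flag else "const Tensor&"
        return ", ".join(f"{qual} {name}" for name in self.inputs)

    def gen_declaration(self, inplace_flag=False):
        # Inplace APIs conventionally carry a trailing underscore.
        func_name = self.name + ("_" if inplace_flag else "")
        return (f"PADDLE_API {self.get_return_type(inplace_flag)} "
                f"{func_name}({self.get_define_args(inplace_flag)});")


gen = SparseApiGen("sync_batch_norm", ["x", "mean", "variance"])
print(gen.gen_declaration(inplace_flag=False))
# PADDLE_API Tensor sync_batch_norm(const Tensor& x, const Tensor& mean, const Tensor& variance);
print(gen.gen_declaration(inplace_flag=True))
# PADDLE_API Tensor& sync_batch_norm_(Tensor& x, Tensor& mean, Tensor& variance);
```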
@@ -368,7 +368,7 @@
     subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}

 - backward_op : sync_batch_norm_grad
-  forward : sync_batch_norm(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
+  forward : sync_batch_norm_(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
   output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
   infer_meta :
...
@@ -95,6 +95,7 @@
   kernel :
     func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
     data_type : x
+  view : (mean -> mean_out), (variance -> variance_out)
   backward : batch_norm_grad

 - op : cast
...
@@ -480,7 +481,7 @@
     layout : x
   backward : transpose_grad

-- op : sync_batch_norm
+- op : sync_batch_norm_
   args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
   output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   infer_meta :
@@ -489,6 +490,7 @@
     func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
     data_type : x
   backward : sync_batch_norm_grad
+  inplace : (mean -> mean_out), (variance -> variance_out)

 - op : reshape
   args : (Tensor x, IntArray shape)
...
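In these op definitions, the trailing underscore marks the inplace variant, and `inplace : (mean -> mean_out), (variance -> variance_out)` declares that those outputs reuse the inputs' storage, so the op writes the updated running statistics back into the caller's buffers; the `view` annotation on `batch_norm` declares a similar storage-sharing relationship, and the backward entry's `forward :` reference is updated to track the renamed op. A NumPy sketch of the semantics the aliasing declares (illustrative only, not Paddle's kernel):

```python
import numpy as np

def sync_batch_norm_(x, mean, variance, momentum=0.9, eps=1e-5):
    """Illustrative inplace batch norm: mutates mean/variance in place."""
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    # In-place update of the caller's running statistics; this is the
    # behaviour that the (mean -> mean_out) aliasing expresses.
    mean *= momentum
    mean += (1.0 - momentum) * batch_mean
    variance *= momentum
    variance += (1.0 - momentum) * batch_var
    return (x - batch_mean) / np.sqrt(batch_var + eps)

running_mean = np.zeros(3)
running_var = np.ones(3)
y = sync_batch_norm_(np.random.randn(8, 3), running_mean, running_var)
print(running_mean)  # no longer all zeros: the caller's buffer was mutated
```

Without the aliasing, the updated statistics come back as fresh tensors that the caller can silently discard, which is exactly the failure mode the layer change below addresses.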
@@ -297,7 +297,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
     def forward(self, x):
         self._check_data_format()
-        sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm(
+        sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm_(
             x, self.weight, self.bias, self._mean, self._variance,
             self._momentum, self._epsilon, self._data_format, not self.training,
             False, False, False)
...
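With `forward()` now calling the inplace binding, the running statistics the layer holds (`self._mean`, `self._variance`) are actually updated during training rather than returned as fresh tensors and discarded. A hedged usage sketch: it assumes a CUDA build of Paddle from the 2.4 era, where the sparse layers lived under `paddle.incubate.sparse`, and a single-process setting in which SyncBatchNorm falls back to ordinary batch statistics; module paths may differ in other versions.

```python
import paddle
from paddle.incubate import sparse  # module path circa Paddle 2.4

paddle.set_device("gpu")  # the sparse sync_batch_norm kernels are GPU-only

# A 3x3 spatial grid with 2 channels as a sparse COO tensor
# (channels-last layout, as the sparse norm layers expect).
indices = [[0, 1, 2], [0, 1, 2]]
values = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
x = sparse.sparse_coo_tensor(indices, values, shape=[3, 3, 2])

bn = sparse.nn.SyncBatchNorm(2)
mean_before = bn._mean.numpy().copy()
out = bn(x)
# After this fix, bn._mean is mutated in place by the underlying op,
# so the difference below is nonzero in training mode.
print(bn._mean.numpy() - mean_before)
```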