未验证 提交 5df464fe 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] sync_batch_norm_grad delete mean and variance (#45411)

* sync_batch_norm_grad delete mean and variance
上级 1cd7e68b
......@@ -2460,7 +2460,7 @@
- backward_api : sync_batch_norm_grad
forward : sync_batch_norm_ (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
......@@ -2468,7 +2468,7 @@
kernel :
func : sync_batch_norm_grad
data_type : out_grad
optional : mean_out, variance_out, reserve_space
optional : reserve_space
- backward_api : take_along_axis_grad
forward : take_along_axis (Tensor x, Tensor index, int axis) -> Tensor(out)
......
......@@ -24,8 +24,6 @@ void SyncBatchNormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& mean,
const paddle::optional<DenseTensor>& variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
......
......@@ -25,8 +25,6 @@ void SyncBatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& mean,
const paddle::optional<DenseTensor>& variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
......
......@@ -42,8 +42,6 @@ KernelSignature SyncBatchNormGradOpArgumentMapping(
"X",
"Scale",
"Bias",
"Mean",
"Variance",
"SavedMean",
"SavedVariance",
"ReserveSpace",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册