未验证 提交 bfa46c38 编写于 作者: Leo Chen 提交者: GitHub

bn supports reserve_space, test=develop (#24988)

上级 613303db
@@ -52,6 +52,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
std::map<std::string, std::set<std::string>> op_outs_map = {
    {"fake_quantize_dequantize_moving_average_abs_max",
     {"Out", "OutScale", "OutAccum", "OutState"}},
{"batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......
@@ -1310,9 +1310,10 @@ class BatchNorm(layers.Layer):
                self._fuse_with_relu, "use_global_stats",
                self._use_global_stats, 'trainable_statistics',
                self._trainable_statistics)
-           batch_norm_out, _, _, _, _ = core.ops.batch_norm(
+           batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
                input, self.weight, self.bias, self._mean, self._variance,
                mean_out, variance_out, *attrs)
            return dygraph_utils._append_activation_in_dygraph(
                batch_norm_out, act=self._act)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册