Commit 8d1f3025 authored by Megvii Engine Team

fix(mge): fix batch norm dump

GitOrigin-RevId: eb739437ef48fc6e8ddf55a9ebc54e8979b55cbd
Parent 40e778fb
@@ -47,13 +47,17 @@ class XORNet(M.Module):
         self.num_class = 2
         super().__init__()
         self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.bn0 = M.BatchNorm1d(self.mid_dim)
         self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.bn1 = M.BatchNorm1d(self.mid_dim)
         self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)

     def forward(self, x):
         x = self.fc0(x)
+        x = self.bn0(x)
         x = F.tanh(x)
         x = self.fc1(x)
+        x = self.bn1(x)
         x = F.tanh(x)
         x = self.fc2(x)
         return x
......
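The test above adds BatchNorm1d layers to XORNet so that dumping a graph containing batch norm is exercised, which is the path this commit fixes. A minimal sketch of that dump flow, assuming the megengine.jit.trace API with capture_as_const and an illustrative file name (exact trace/dump arguments may differ between MegEngine versions):

    import numpy as np
    import megengine as mge
    from megengine.jit import trace

    net = XORNet()   # the module defined in the test above
    net.eval()       # inference mode: BatchNorm reads the running mean/variance

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return net(data)

    data = mge.tensor(np.random.random((4, 2)).astype("float32"))
    fwd(data)                                    # run once to record the graph
    fwd.dump("xornet.mge", arg_names=["data"])   # serialize, including the BN oprs

The C++ changes below make BatchNormForward in INFERENCE mode read mean/variance from its inputs and leave the running-statistics outputs empty, which is what this dump path exercises.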
@@ -44,7 +44,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
         m_force_inplace = false;
     }
-    if (m_force_inplace) {
+    if (m_force_inplace && param.fwd_mode == Param::FwdMode::TRAINING) {
         auto check_dest = [&](VarNode* dest) {
             auto dest_opr = dest->owner_opr();
             mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
@@ -62,7 +62,14 @@ BatchNormForward::BatchNormForward(VarNode *x,
     add_input({x, scale, bias, mean, variance});
-    if (m_force_inplace) {
+    if (param.fwd_mode == Param::FwdMode::INFERENCE) {
+        auto mark_empty_var = [&](VarNode *var) {
+            var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
+        };
+        mark_empty_var(output(0));
+        mark_empty_var(output(1));
+    } else if (m_force_inplace) {
         output(0)->
             set_fwd_in2out_writable_force(input(3)).
             add_flag(VarNode::Flag::NO_MEM_RECLAIM);
@@ -129,7 +136,7 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
 cg::OperatorNodeBase::NodeProp*
 BatchNormForward::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
-    if (input().size() == 5) {
+    if (need_stats()) {
         ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
     }
     return ret;
@@ -140,7 +147,7 @@ void BatchNormForward::scn_do_execute() {
     auto &&y = output(4)->dev_tensor();
     mgb_assert(x.layout().is_contiguous() &&
                y.layout().is_contiguous());
-    if (input().size() == 5) { // need running mean/variance
+    if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
              &&i0 = input(3)->dev_tensor(),
@@ -164,8 +171,14 @@ void BatchNormForward::scn_do_execute() {
     }
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
-    auto mean = output(0)->dev_tensor().as_megdnn();
-    auto variance = output(1)->dev_tensor().as_megdnn();
+    megdnn::TensorND mean, variance;
+    if (param().fwd_mode == Param::FwdMode::INFERENCE) {
+        mean = input(3)->dev_tensor().as_megdnn();
+        variance = input(4)->dev_tensor().as_megdnn();
+    } else {
+        mean = output(0)->dev_tensor().as_megdnn();
+        variance = output(1)->dev_tensor().as_megdnn();
+    }
     auto save_mean = output(2)->dev_tensor().as_megdnn();
     auto save_variance = output(3)->dev_tensor().as_megdnn();
     auto workspace = intl::get_megdnn_workspace_from_var(output().back());
@@ -180,12 +193,11 @@ void BatchNormForward::add_input_layout_constraint() {
 void BatchNormForward::get_output_var_shape(
         const TensorShapeArray &inp_shape,
         TensorShapeArray &out_shape) const {
-    size_t nr_inp = input().size();
     out_shape[4] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
     }
-    if (nr_inp == 3) {
+    if (!need_stats()) {
         out_shape[0] = out_shape[1] = {0};
     }
 }
@@ -221,7 +233,7 @@ void BatchNormForward::init_output_dtype() {
 }

 void BatchNormForward::mem_plan_fwd_in2out_writable() {
-    if (!m_force_inplace && input().size() == 5) {
+    if (need_stats() && !m_force_inplace) {
         // TODO: testing
         output(0)->set_fwd_in2out_writable(input(3));
         output(1)->set_fwd_in2out_writable(input(4));
......
@@ -79,6 +79,8 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
     // if set to True, running mean/variance will be updated inplace
     bool m_force_inplace = true;
+    // need running mean/variance
+    bool need_stats() const { return input().size() == 5 && param().fwd_mode == Param::FwdMode::TRAINING; }
 };
 using BatchNorm = BatchNormForward;
......