提交 8d1f3025 编写于 作者: M Megvii Engine Team

fix(mge): fix batch norm dump

GitOrigin-RevId: eb739437ef48fc6e8ddf55a9ebc54e8979b55cbd
上级 40e778fb
......@@ -47,13 +47,17 @@ class XORNet(M.Module):
self.num_class = 2
super().__init__()
self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
self.bn0 = M.BatchNorm1d(self.mid_dim)
self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
self.bn1 = M.BatchNorm1d(self.mid_dim)
self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
def forward(self, x):
    """Run the XOR classifier: two (Linear -> BatchNorm -> tanh) hidden
    stages followed by a plain linear head; returns raw class scores."""
    for linear, norm in ((self.fc0, self.bn0), (self.fc1, self.bn1)):
        x = F.tanh(norm(linear(x)))
    return self.fc2(x)
......
......@@ -44,7 +44,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
m_force_inplace = false;
}
if (m_force_inplace) {
if (m_force_inplace && param.fwd_mode == Param::FwdMode::TRAINING) {
auto check_dest = [&](VarNode* dest) {
auto dest_opr = dest->owner_opr();
mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
......@@ -62,7 +62,14 @@ BatchNormForward::BatchNormForward(VarNode *x,
add_input({x, scale, bias, mean, variance});
if (m_force_inplace) {
if (param.fwd_mode == Param::FwdMode::INFERENCE) {
auto mark_empty_var = [&](VarNode *var) {
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
};
mark_empty_var(output(0));
mark_empty_var(output(1));
} else if (m_force_inplace) {
output(0)->
set_fwd_in2out_writable_force(input(3)).
add_flag(VarNode::Flag::NO_MEM_RECLAIM);
......@@ -129,7 +136,7 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
cg::OperatorNodeBase::NodeProp*
BatchNormForward::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
if (input().size() == 5) {
if (need_stats()) {
ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
}
return ret;
......@@ -140,7 +147,7 @@ void BatchNormForward::scn_do_execute() {
auto &&y = output(4)->dev_tensor();
mgb_assert(x.layout().is_contiguous() &&
y.layout().is_contiguous());
if (input().size() == 5) { // need running mean/variance
if (need_stats()) {
auto &&o0 = output(0)->dev_tensor(),
&&o1 = output(1)->dev_tensor(),
&&i0 = input(3)->dev_tensor(),
......@@ -164,8 +171,14 @@ void BatchNormForward::scn_do_execute() {
}
auto scale = input(1)->dev_tensor().as_megdnn();
auto bias = input(2)->dev_tensor().as_megdnn();
auto mean = output(0)->dev_tensor().as_megdnn();
auto variance = output(1)->dev_tensor().as_megdnn();
megdnn::TensorND mean, variance;
if (param().fwd_mode == Param::FwdMode::INFERENCE) {
mean = input(3)->dev_tensor().as_megdnn();
variance = input(4)->dev_tensor().as_megdnn();
} else {
mean = output(0)->dev_tensor().as_megdnn();
variance = output(1)->dev_tensor().as_megdnn();
}
auto save_mean = output(2)->dev_tensor().as_megdnn();
auto save_variance = output(3)->dev_tensor().as_megdnn();
auto workspace = intl::get_megdnn_workspace_from_var(output().back());
......@@ -180,12 +193,11 @@ void BatchNormForward::add_input_layout_constraint() {
// Infer output shapes for BatchNormForward.
// NOTE(review): this span is a rendered diff without +/- markers — both the
// old condition (`if (nr_inp == 3)`) and its replacement
// (`if (!need_stats())`) appear below; only one belongs in the real source.
// The `nr_inp` local was presumably removed together with the old
// condition (hunk header shows a net -1 line) — confirm against the repo.
void BatchNormForward::get_output_var_shape(
const TensorShapeArray &inp_shape,
TensorShapeArray &out_shape) const {
size_t nr_inp = input().size();
// output(4) is y (see scn_do_execute): same shape as input x (inp 0).
out_shape[4] = inp_shape[0];
// outputs 0-3 (running/save mean & variance) take the per-channel shape
// of inp 1, i.e. `scale` per add_input({x, scale, bias, mean, variance}).
for (size_t i = 0; i < 4; ++ i) {
out_shape[i] = inp_shape[1];
}
// old diff line:
if (nr_inp == 3) {
// new diff line:
if (!need_stats()) {
// no running statistics are produced -> outputs 0/1 are empty tensors.
out_shape[0] = out_shape[1] = {0};
}
}
......@@ -221,7 +233,7 @@ void BatchNormForward::init_output_dtype() {
}
void BatchNormForward::mem_plan_fwd_in2out_writable() {
if (!m_force_inplace && input().size() == 5) {
if (need_stats() && !m_force_inplace) {
// TODO: testing
output(0)->set_fwd_in2out_writable(input(3));
output(1)->set_fwd_in2out_writable(input(4));
......
......@@ -79,6 +79,8 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
// if set to True, running mean/variance will be updated inplace
bool m_force_inplace = true;
// need running mean/variance
// True iff running mean/variance must be maintained: the operator was
// built with all five inputs (x, scale, bias, mean, variance) *and* the
// forward pass runs in training mode.
bool need_stats() const {
    const bool has_running_inputs = input().size() == 5;
    const bool training = param().fwd_mode == Param::FwdMode::TRAINING;
    return has_running_inputs && training;
}
};
using BatchNorm = BatchNormForward;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册