提交 8557ed20 编写于 作者: 李寅

Fix: if reshape in BN, skip it

上级 a407488f
......@@ -12,6 +12,11 @@ pooling_type_mode = {
'MaxPool': 2
}
def get_input_tensor(op, index):
input_tensor = op.inputs[index]
if input_tensor.op.type == 'Reshape':
input_tensor = get_input_tensor(input_tensor.op, 0)
return input_tensor
def convert_ops(unresolved_ops, net_def):
ops_count = len(unresolved_ops)
......@@ -19,7 +24,7 @@ def convert_ops(unresolved_ops, net_def):
first_op = unresolved_ops[0]
if first_op.type == 'Placeholder':
if first_op.type == 'Placeholder' or first_op.type == 'Reshape':
pass
elif first_op.type == 'Const':
tf_tensor = first_op.outputs[0].eval()
......@@ -76,12 +81,12 @@ def convert_ops(unresolved_ops, net_def):
if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add':
raise Exception('Invalid BatchNorm Op')
input_name = mul_1_op.inputs[0].name
gamma = mul_op.inputs[1].name
beta = sub_op.inputs[0].name
mean = mul_2_op.inputs[0].name
variance = add_op.inputs[0].name
epsilon = add_op.inputs[1].name
input_name = get_input_tensor(mul_1_op, 0).name
gamma = get_input_tensor(mul_op, 1).name
beta = get_input_tensor(sub_op, 0).name
mean = get_input_tensor(mul_2_op, 0).name
variance = get_input_tensor(add_op, 0).name
epsilon = get_input_tensor(add_op, 1).name
op_def = net_def.op.add()
op_def.name = first_op.name[:-4] # remove /add
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册