From 8557ed2062041ae77f63963d675e018b28197132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Mon, 25 Sep 2017 17:02:07 +0800 Subject: [PATCH] Fix: if reshape in BN, skip it --- mace/python/tools/tf_converter_lib.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 5e9acbfd..671230e5 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -12,6 +12,11 @@ pooling_type_mode = { 'MaxPool': 2 } +def get_input_tensor(op, index): + input_tensor = op.inputs[index] + if input_tensor.op.type == 'Reshape': + input_tensor = get_input_tensor(input_tensor.op, 0) + return input_tensor def convert_ops(unresolved_ops, net_def): ops_count = len(unresolved_ops) @@ -19,7 +24,7 @@ def convert_ops(unresolved_ops, net_def): first_op = unresolved_ops[0] - if first_op.type == 'Placeholder': + if first_op.type == 'Placeholder' or first_op.type == 'Reshape': pass elif first_op.type == 'Const': tf_tensor = first_op.outputs[0].eval() @@ -76,12 +81,12 @@ def convert_ops(unresolved_ops, net_def): if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add': raise Exception('Invalid BatchNorm Op') - input_name = mul_1_op.inputs[0].name - gamma = mul_op.inputs[1].name - beta = sub_op.inputs[0].name - mean = mul_2_op.inputs[0].name - variance = add_op.inputs[0].name - epsilon = add_op.inputs[1].name + input_name = get_input_tensor(mul_1_op, 0).name + gamma = get_input_tensor(mul_op, 1).name + beta = get_input_tensor(sub_op, 0).name + mean = get_input_tensor(mul_2_op, 0).name + variance = get_input_tensor(add_op, 0).name + epsilon = get_input_tensor(add_op, 1).name op_def = net_def.op.add() op_def.name = first_op.name[:-4] # remove /add -- GitLab