diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 5e9acbfd639384bf55e473c2f7c8f7099fbe2916..671230e51dfb01fa11d31e96e677817e2ac0ed88 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -12,6 +12,11 @@ pooling_type_mode = { 'MaxPool': 2 } +def get_input_tensor(op, index): + input_tensor = op.inputs[index] + if input_tensor.op.type == 'Reshape': + input_tensor = get_input_tensor(input_tensor.op, 0) + return input_tensor def convert_ops(unresolved_ops, net_def): ops_count = len(unresolved_ops) @@ -19,7 +24,7 @@ def convert_ops(unresolved_ops, net_def): first_op = unresolved_ops[0] - if first_op.type == 'Placeholder': + if first_op.type == 'Placeholder' or first_op.type == 'Reshape': pass elif first_op.type == 'Const': tf_tensor = first_op.outputs[0].eval() @@ -76,12 +81,12 @@ def convert_ops(unresolved_ops, net_def): if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add': raise Exception('Invalid BatchNorm Op') - input_name = mul_1_op.inputs[0].name - gamma = mul_op.inputs[1].name - beta = sub_op.inputs[0].name - mean = mul_2_op.inputs[0].name - variance = add_op.inputs[0].name - epsilon = add_op.inputs[1].name + input_name = get_input_tensor(mul_1_op, 0).name + gamma = get_input_tensor(mul_op, 1).name + beta = get_input_tensor(sub_op, 0).name + mean = get_input_tensor(mul_2_op, 0).name + variance = get_input_tensor(add_op, 0).name + epsilon = get_input_tensor(add_op, 1).name op_def = net_def.op.add() op_def.name = first_op.name[:-4] # remove /add