未验证 提交 0d9b45e8 编写于 作者: S sunnycase 提交者: GitHub

Allow cut postprocess but preserve input (#1275)

* Allow cut postprocess but preserve input

* apply code-format changes
Co-authored-by: Nsunnycase <sunnycase@users.noreply.github.com>
上级 66b2c22a
......@@ -185,9 +185,9 @@ int main(int argc, char* argv[])
// create VeriSilicon TIM-VX backend
context_t timvx_context = create_context("timvx", 1);
if ( set_context_device(timvx_context, "TIMVX", nullptr, 0) < 0 )
if (set_context_device(timvx_context, "TIMVX", nullptr, 0) < 0)
{
fprintf(stderr, "add_context_device failed.\n" );
fprintf(stderr, "add_context_device failed.\n");
return 1;
}
......
......@@ -60,7 +60,7 @@ def parse_args():
args = parse_args()
def cut_focus_output(input_node, in_name, out_name):
def cut_focus_output(input_node, in_name, out_name, cut_focus):
"""
cut the focus and postprocess nodes
Args:
......@@ -98,8 +98,11 @@ def cut_focus_output(input_node, in_name, out_name):
del input_node[i]
# cut input node
for n in in_name:
new_nodes = input_node[(node_dict[n] + 1):]
if cut_focus:
for n in in_name:
new_nodes = input_node[(node_dict[n] + 1):]
else:
new_nodes = input_node[:]
return new_nodes
......@@ -181,9 +184,8 @@ def main():
new_nodes = old_node[:]
# cut the focus and postprocess nodes
if args.cut_focus:
print("[Quant Tools Info]: Step 1, Remove the focus and postprocess nodes.")
new_nodes = cut_focus_output(old_node, in_tensor, out_tensor)
print("[Quant Tools Info]: Step 1, Remove the focus and postprocess nodes.")
new_nodes = cut_focus_output(old_node, in_tensor, out_tensor, args.cut_focus)
# op fusion, using HardSwish replace the Sigmoid and Mul
print("[Quant Tools Info]: Step 2, Using hardswish replace the sigmoid and mul.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册