提交 3adae34a 编写于 作者: B BBuf

fix convert bug

上级 9e266cab
...@@ -36,7 +36,7 @@ class AlexNet(nn.Module): ...@@ -36,7 +36,7 @@ class AlexNet(nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), nn.MaxPool2d(kernel_size=3, stride=2),
) )
self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) self.avgpool = nn.AvgPool2d((1, 1))
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.Dropout(), nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096), nn.Linear(256 * 6 * 6, 4096),
...@@ -48,10 +48,10 @@ class AlexNet(nn.Module): ...@@ -48,10 +48,10 @@ class AlexNet(nn.Module):
) )
def forward(self, x: flow.Tensor) -> flow.Tensor: def forward(self, x: flow.Tensor) -> flow.Tensor:
x = self.features(x) # x = self.features(x)
x = self.avgpool(x) x = self.avgpool(x)
x = flow.flatten(x, 1) # x = flow.flatten(x, 1)
x = self.classifier(x) # x = self.classifier(x)
return x return x
alexnet = AlexNet() alexnet = AlexNet()
......
...@@ -188,6 +188,8 @@ def FlowOnnxMapping(g, ops_mapping): ...@@ -188,6 +188,8 @@ def FlowOnnxMapping(g, ops_mapping):
continue continue
op = node.op_type op = node.op_type
if op == "output":
continue
map_info = ops_mapping.get(op) map_info = ops_mapping.get(op)
if map_info is None: if map_info is None:
unmapped_op[op] += 1 unmapped_op[op] += 1
......
...@@ -122,12 +122,11 @@ def convert_to_onnx_and_check( ...@@ -122,12 +122,11 @@ def convert_to_onnx_and_check(
ipt_dict, onnx_res = run_onnx( ipt_dict, onnx_res = run_onnx(
onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize
) )
oneflow_res = graph(*ipt_dict.values()) oneflow_res = graph(flow.tensor(*ipt_dict.values(), dtype=flow.float32))
if not isinstance(oneflow_res, np.ndarray): if not isinstance(oneflow_res, np.ndarray):
oneflow_res = oneflow_res.get().numpy() oneflow_res = oneflow_res.numpy()
compare_result(oneflow_res, onnx_res, print_outlier=print_outlier) compare_result(oneflow_res, onnx_res, print_outlier=print_outlier)
flow.clear_default_session()
# cleanup() # cleanup()
...@@ -182,7 +182,7 @@ class Node(object): ...@@ -182,7 +182,7 @@ class Node(object):
# return self.op_type in ["Const", "ConstV2"] # return self.op_type in ["Const", "ConstV2"]
def is_graph_output(self): def is_graph_output(self):
return self.op_type in ["return"] return self.op_type in ["output"]
def is_graph_input(self): def is_graph_input(self):
return self.op_type in ["input"] return self.op_type in ["input"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册