diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 9d82425cb1c29dac14511d969d4065b097b31b43..b4566d128666a92296b4c243b7f551e80a658f36 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -151,7 +151,7 @@ def onnx2paddle(model_path, save_dir): print("Now translating model from onnx to paddle.") from x2paddle.decoder.onnx_decoder import ONNXDecoder - model = ONNXDecoder(model_path, save_dir) + model = ONNXDecoder(model_path) from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper mapper = ONNXOpMapper(model, save_dir) diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py index 73cdd4832d2509f1fdd441cb4c293e4a474e2785..8955f1e0453904c42ebc5f6eb449aab201a18cbf 100644 --- a/x2paddle/decoder/onnx_decoder.py +++ b/x2paddle/decoder/onnx_decoder.py @@ -132,16 +132,14 @@ class ONNXGraphDataNode(GraphNode): class ONNXGraph(Graph): - def __init__(self, onnx_model, save_dir): + def __init__(self, onnx_model): super(ONNXGraph, self).__init__(onnx_model.graph) self.onnx_model = onnx_model self.initializer = {} self.place_holder_nodes = list() self.get_place_holder_nodes() - self.tmp_data_dir = os.path.join(save_dir, 'tmp_data') self.value_infos = self.inferred_model_value_info(self.model) self.results_of_inference = dict() - self.is_inference = False def get_inner_nodes(self): """ @@ -295,7 +293,7 @@ class ONNXGraph(Graph): class ONNXDecoder(object): - def __init__(self, onnx_model, save_dir): + def __init__(self, onnx_model): model = onnx.load(onnx_model) print('model ir_version: {}, op version: {}'.format( model.ir_version, model.opset_import[0].version)) @@ -314,7 +312,7 @@ class ONNXDecoder(object): self.model = model graph = model.graph - self.onnx_graph = ONNXGraph(model, save_dir) + self.onnx_graph = ONNXGraph(model) self.onnx_graph.build() def build_value_refs(self, nodes):