提交 6a78f20a 编写于 作者: S SunAhong1993

fix the program

上级 55bfc88f
......@@ -16,7 +16,6 @@
from __future__ import print_function
from __future__ import division
import paddle.fluid as fluid
import os.path as osp
import paddle
from paddle.fluid.proto import framework_pb2
from collections import OrderedDict
......@@ -26,6 +25,7 @@ import os
import six
import pickle
import numpy as np
from os import path as osp
class PaddleLayer(object):
......@@ -232,7 +232,7 @@ class PaddleGraph(object):
return update(self.layers)
def gen_model(self, save_dir, jit_type=None):
if not os.path.exists(save_dir):
if not osp.exists(save_dir):
os.makedirs(save_dir)
if self.graph_type == "static":
self.gen_static_model(save_dir)
......@@ -240,8 +240,8 @@ class PaddleGraph(object):
self.gen_dygraph_model(save_dir, jit_type)
def gen_static_model(self, save_dir):
code_dir = os.path.join(save_dir, 'model_with_code')
infer_dir = os.path.join(save_dir, 'inference_model')
code_dir = osp.join(save_dir, 'model_with_code')
infer_dir = osp.join(save_dir, 'inference_model')
self.gen_static_code(code_dir)
sys.path.append(code_dir)
import x2paddle_model
......@@ -254,13 +254,13 @@ class PaddleGraph(object):
inputs, outputs = x2paddle_model.x2paddle_net()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
param_dir = os.path.join(code_dir, 'weights')
param_dir = osp.join(code_dir, 'weights')
for k, v in self.parameters.items():
if scope.find_var(k):
self.dump_parameter(k, v, param_dir)
def if_exist(var):
b = os.path.exists(
os.path.join(os.path.join(param_dir, var.name)))
b = osp.exists(
osp.join(osp.join(param_dir, var.name)))
return b
fluid.io.load_vars(
exe, param_dir, main_program, predicate=if_exist)
......@@ -282,6 +282,8 @@ class PaddleGraph(object):
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
# 动转静
code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py")
print("Exporting inference model from python code ('{}')... \n".format(code_path))
if len(self.inputs_info) > 0:
input_shapes = list()
input_types = list()
......@@ -290,13 +292,10 @@ class PaddleGraph(object):
input_types.append(self.inputs_info[input_name][1])
try:
self.dygraph2static(save_dir, input_shapes, input_types)
except Error as e:
print("The Dygraph2Static is failed! The possible reason are:\n" +
"1. The convertor of dygraph2static of current model is not supported yet.\n" +
"2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle_model.py in your save_dir to check whether the convertor of pytorch2paddle is wrong.\n" +
"The Error is: \n" +
e)
exit(0)
except Exception as e:
print("Fail to generate inference model! Problem happend while export inference model from python code '{}';\n".format(coda_path))
print("===================Error Information===============")
raise e
def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0):
......@@ -307,9 +306,9 @@ class PaddleGraph(object):
else:
f.write(indent_blank + code_line + '\n')
if not os.path.exists(code_dir):
if not osp.exists(code_dir):
os.makedirs(code_dir)
f = open(os.path.join(code_dir, 'x2paddle_model.py'), 'w')
f = open(osp.join(code_dir, 'x2paddle_model.py'), 'w')
write_code(
f, [
......@@ -372,7 +371,7 @@ class PaddleGraph(object):
def dump_parameter(self, param_name, param, save_dir):
if not os.path.exists(save_dir):
if not osp.exists(save_dir):
os.makedirs(save_dir)
dtype_map = {
"int16": [framework_pb2.VarType.INT16, 'h'],
......@@ -392,7 +391,7 @@ class PaddleGraph(object):
assert str(
param.dtype) in dtype_map, "Unknown dtype {} of params: {}.".format(
str(param.dtype), param_name)
fp = open(os.path.join(save_dir, param_name), 'wb')
fp = open(osp.join(save_dir, param_name), 'wb')
numpy.array([0], dtype='int32').tofile(fp)
numpy.array([0], dtype='int64').tofile(fp)
numpy.array([0], dtype='int32').tofile(fp)
......@@ -502,7 +501,7 @@ class PaddleGraph(object):
use_structured_name = False if self.source_type in ["tf", "onnx"] else True
self.run_func.extend(
gen_codes(["paddle.disable_static()",
"params = paddle.load('{}/model.pdparams')".format(os.path.abspath(code_dir)),
"params = paddle.load('{}/model.pdparams')".format(osp.abspath(code_dir)),
"model = {}()".format(self.name),
"model.set_dict(params, use_structured_name={})".format(use_structured_name),
"model.eval()",
......@@ -510,7 +509,7 @@ class PaddleGraph(object):
"return out"], indent=1))
def write_code(code_dir):
f = open(os.path.join(code_dir, 'x2paddle_code.py'), 'w')
f = open(osp.join(code_dir, 'x2paddle_code.py'), 'w')
for code_line in self.head:
f.write(code_line)
init_writen_codes = []
......@@ -622,7 +621,7 @@ class PaddleGraph(object):
return self.init_func, self.forward_func
def dump_dygraph_parameter(self, code_dir):
save_path = os.path.join(code_dir, 'model.pdparams')
save_path = osp.join(code_dir, 'model.pdparams')
paddle.save(self.parameters, save_path)
def dygraph2static(self, save_dir, input_shapes=[], input_types=[]):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册