提交 eb2eb0b8 编写于 作者: W wjj19950828

fixed for CI

上级 f81f905b
...@@ -81,7 +81,7 @@ ...@@ -81,7 +81,7 @@
"source": [ "source": [
"## 模型迁移\n", "## 模型迁移\n",
"### 1. 获取MobileNetV1的FrozenModel\n", "### 1. 获取MobileNetV1的FrozenModel\n",
"由于X2Paddle只支持TensorFlow中FrozenModel的转换,如果为纯checkpoint模型,需要参考参考X2Paddle官方[文档](https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/user_guides/export_tf_model.md),将其转换为FrozenModel,本示例中提供的模型为FrozenModel,所以无需转换。" "由于X2Paddle只支持TensorFlow中FrozenModel的转换,如果为纯checkpoint模型,需要参考参考X2Paddle官方[文档](https://github.com/PaddlePaddle/X2Paddle/blob/release-1.1/docs/user_guides/export_tf_model.md),将其转换为FrozenModel,本示例中提供的模型为FrozenModel,所以无需转换。"
] ]
}, },
{ {
......
...@@ -532,9 +532,9 @@ class PaddleGraph(object): ...@@ -532,9 +532,9 @@ class PaddleGraph(object):
paddle.save(self.parameters, save_path) paddle.save(self.parameters, save_path)
def dygraph2static(self, save_dir, input_shapes=[], input_types=[]): def dygraph2static(self, save_dir, input_shapes=[], input_types=[]):
sepc_list = list() spec_list = list()
for i, name in enumerate(self.inputs): for i, name in enumerate(self.inputs):
sepc_list.append( spec_list.append(
paddle.static.InputSpec( paddle.static.InputSpec(
shape=input_shapes[i], name=name, dtype=input_types[i])) shape=input_shapes[i], name=name, dtype=input_types[i]))
path = osp.abspath(save_dir) path = osp.abspath(save_dir)
...@@ -548,7 +548,7 @@ class PaddleGraph(object): ...@@ -548,7 +548,7 @@ class PaddleGraph(object):
else: else:
model.set_dict(restore) model.set_dict(restore)
model.eval() model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list) static_model = paddle.jit.to_static(model, input_spec=spec_list)
try: try:
paddle.jit.save(static_model, paddle.jit.save(static_model,
osp.join(save_dir, "inference_model/model")) osp.join(save_dir, "inference_model/model"))
......
...@@ -583,7 +583,8 @@ class ONNXDecoder(object): ...@@ -583,7 +583,8 @@ class ONNXDecoder(object):
item.name = self.make_variable_name(item.name) item.name = self.make_variable_name(item.name)
for node in graph.node: for node in graph.node:
node.name = node.output[0] node.name = node.output[0]
if ":" in node.name and len(node.output) > 1: if ":" in node.name and len(
node.output) > 1 and node.op_type != "LSTM":
node.name = node.name.split(':')[0] node.name = node.name.split(':')[0]
node.name = self.make_variable_name(node.name) node.name = self.make_variable_name(node.name)
for i in range(len(node.input)): for i in range(len(node.input)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册