提交 eb48eac9 编写于 作者: S SunAhong1993

get inference model

上级 e2a1d944
......@@ -88,6 +88,12 @@ def arg_parser():
action="store_true",
default=False,
help="define whether merge the params")
parser.add_argument(
"--input_shapes",
"-is",
action='append',
default=[],
help="define the inputs' shape")
return parser
......@@ -174,7 +180,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
print("Paddle model and code generated.")
def pytorch2paddle(model_path, save_dir):
def pytorch2paddle(model_path, save_dir, input_shapes):
# check pytorch installation and version
try:
import torch
......@@ -201,7 +207,13 @@ def pytorch2paddle(model_path, save_dir):
graph_opt = GraphOptimizer()
graph_opt.optimize(mapper.graph)
print("Model optimized.")
mapper.graph.gen_model(save_dir)
real_input_shapes = list()
for shape in input_shapes:
sp = shape[1:-1].split(",")
for i, s in enumerate(sp):
sp[i] = int(s)
real_input_shapes.append(sp)
mapper.graph.gen_model(save_dir, real_input_shapes)
def paddle2onnx(model_path, save_dir, opset_version=10):
......@@ -275,7 +287,7 @@ def main():
onnx2paddle(args.model, args.save_dir, params_merge)
elif args.framework == "pytorch":
assert args.model is not None, "--model should be defined while translating pytorch model"
pytorch2paddle(args.model, args.save_dir)
pytorch2paddle(args.model, args.save_dir, args.input_shapes)
elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx"
......
......@@ -15,6 +15,8 @@
from __future__ import print_function
from __future__ import division
import paddle.fluid as fluid
import os.path as osp
import paddle
from paddle.fluid.proto import framework_pb2
from collections import OrderedDict
import numpy
......@@ -204,7 +206,7 @@ class PaddleGraph(object):
indent=1)
f.close()
def gen_model(self, save_dir):
def gen_model(self, save_dir, input_shapes):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if self.graph_type == "static":
......@@ -242,6 +244,7 @@ class PaddleGraph(object):
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
self.dygraph2static(save_dir, input_shapes) #[[None, 3, 224, 224]]
def dump_parameter(self, param_name, param, save_dir):
if not os.path.exists(save_dir):
......@@ -342,7 +345,7 @@ class PaddleGraph(object):
indent=1))
def write_code(code_dir):
f = open(os.path.join(code_dir, 'code.py'), 'w')
f = open(os.path.join(code_dir, 'x2paddle_code.py'), 'w')
for code_line in self.head:
f.write(code_line)
init_writen_codes = []
......@@ -441,3 +444,24 @@ class PaddleGraph(object):
params_output = open(os.path.join(code_dir, 'model.pdparams'), 'wb')
pickle.dump(self.parameters, params_output)
params_output.close()
def dygraph2static(self, save_dir, input_shapes=[]):
from paddle.fluid.dygraph.jit import declarative
sepc_list = list()
for i, name in enumerate(self.inputs):
sepc_list.append(
paddle.static.InputSpec(
shape=input_shapes[i], name=name))
import sys
path = osp.abspath(save_dir)
sys.path.insert(0, save_dir)
import x2paddle_code
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)(restore)
model.set_dict(restore)
model.eval()
model.forward = declarative(model.forward, sepc_list)
fluid.dygraph.jit.save(
layer=model, model_path=osp.join(save_dir, "inference"))
......@@ -646,8 +646,8 @@ def aten_conv2d(mapper, graph, node):
# 处理输入1,即%25
weights = mapper.pytorch_params[inputs_name[1]]
mapper.paddle_params[conv2d_name + ".weight"] = weights
layer_attrs["num_filters"] = weights.shape[0]
layer_attrs["filter_size"] = weights.shape[2:]
layer_attrs["out_channels"] = weights.shape[0]
layer_attrs["kernel_size"] = weights.shape[2:]
# 处理输入2,即%27
if inputs_name[2] in mapper.pytorch_params:
bias = mapper.pytorch_params[inputs_name[2]]
......@@ -665,11 +665,10 @@ def aten_conv2d(mapper, graph, node):
layer_attrs["dilation"] = mapper.attrs[inputs_name[5]]
# 处理输入6,即%26
layer_attrs["groups"] = mapper.attrs[inputs_name[6]]
layer_attrs['num_channels'] = weights.shape[1] * mapper.attrs[inputs_name[
6]]
layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[inputs_name[6]]
graph.add_layer(
"paddle.nn.Conv2D",
"paddle.nn.Conv2d",
inputs=layer_inputs,
outputs=layer_outputs,
**layer_attrs)
......
......@@ -18,6 +18,8 @@ from .batchnorm2d_fuser import BatchNorm2dFuser
from .batchnorm2d_fuse_pass import BatchNorm2dFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
from .dropout_fuser import DropoutFuser
from .dropout_fuse_pass import DropoutFusePass
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
from .interpolate_bilinear_fuser import InterpolateBilinearFuser
......
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion import DropoutFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class DropoutFusePass(Pass):
name = "dropout_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = DropoutFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
dropout_fuse_pass = DropoutFuser()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class DropoutFuser(FuseBase):
def __init__(self):
super(DropoutFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的constant图结构。
constant层模式python实现代码示例:
x3 = 10
for _x70 in range(x3):
...
"""
self.pattern.add_layer(
"paddle.nn.Dropout",
inputs={"input": "dropout-input-0"},
outputs=["dropout0", "x1"])
self.pattern.build(inputs={"input-0": "dropout-input-0"})
self.pattern.outputs = ["dropout0", "x1"]
def insert_new_layer(self, graph, parameters, matches):
def replace_value(layer_connect, match_name, match_input):
for k, v in layer_connect.inputs.items():
if v == match_name:
layer_connect.inputs[k] = match_input
break
if layer_connect.kernel == "prim.loop" or \
layer_connect.kernel == "prim.if":
for block in layer_connect.blocks:
for b_layer_id, b_layer in block.layers.items():
if block.edges_in.get(b_layer_id, 0) != 0 and \
-1 in block.edges_in[b_layer_id]:
replace_value(b_layer, match_name, match_input)
layer_id = list(matches.keys())[0]
layer = list(matches.values())[0]
layer_output_name = layer.outputs[1]
layer_input = layer.inputs["input"]
if graph.edges_out.get(layer_id, 0) != 0:
for layer_id_out in graph.edges_out[layer_id]:
layer_connect = graph.layers[layer_id_out]
replace_value(layer_connect, layer_output_name, layer_input)
......@@ -143,8 +143,8 @@ class FcFuser(FuseBase):
layer = matches[layers_id[6]]
bias_name = layer.attrs["value"][8:-2]
attrs = dict()
attrs["input_dim"] = parameters[weight_name].shape[1]
attrs["output_dim"] = parameters[weight_name].shape[0]
attrs["in_features"] = parameters[weight_name].shape[1]
attrs["out_features"] = parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index)
self.linear_index += 1
parameters["{}.weight".format(linear_name)] = parameters[
......
......@@ -21,7 +21,7 @@ class GraphOptimizer(object):
self.passes = [
"interpolate_bilinear_fuse_pass", "fc_fuse_pass",
"adaptive_pool2d_fuse_pass", "batchnorm2d_fuse_pass",
"constant_fuse_pass", "reshape_fuse_pass"
"constant_fuse_pass", "reshape_fuse_pass", "dropout_fuse_pass"
]
def optimize(self, graph):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册