未验证 提交 fd3c33a8 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #12 from PaddlePaddle/develop

new
......@@ -10,12 +10,13 @@ X2Paddle在多个主流的CV模型上,测试过TensorFlow/Caffe/ONNX模型的
## 环境依赖
python == 2.7 | python >= 3.5
paddlepaddle >= 1.8.0
paddlepaddle >= 2.0.0
**按需安装以下依赖**
tensorflow : tensorflow == 1.14.0
caffe : 无
onnx : onnx >= 1.6.0
pytorch:torch >=1.5.0 (script方式暂不支持1.7.0)
## 安装
### 安装方式一(推荐)
......@@ -45,10 +46,13 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
```
### PyTorch
> PyTorch不支持命令行使用方式,详见[PyTorch2Paddle](pytorch2paddle.md)
### Paddle2ONNX
```
Paddle2ONNX功能已迁移至新的github: https://github.com/PaddlePaddle/paddle2onnx, 欢迎大家去新的代码仓库查看详细介绍以及新功能。
```
> Paddle2ONNX功能已迁移至新的github: https://github.com/PaddlePaddle/paddle2onnx, 欢迎大家去新的代码仓库查看详细介绍以及新功能。
### 参数选项
| 参数 | |
|----------|--------------|
......
# PyTorch2Paddle
PyTorch2Paddle支持trace和script两种方式的转换,均是PyTorch动态图到Paddle动态图的转换,转换后的Paddle动态图运用动转静可转换为静态图模型。trace方式生成的代码可读性较强,较为接近原版PyTorch代码的组织结构;script方式不需要知道输入数据的类型和大小即可转换,使用上较为方便,但目前PyTorch支持的script代码方式有所限制,所以支持转换的代码也有所限制。用户可根据自身需求,选择转换方式。
## 环境依赖
python == 2.7 | python >= 3.5
paddlepaddle >= 1.8.0
pytorch:torch >=1.5.0 (script方式暂不支持1.7.0)
**使用trace方式需安装以下依赖**
pandas
treelib
## 使用方式
``` python
from x2paddle.convert import pytorch2paddle
pytorch2paddle(module=torch_module,
save_dir="./pd_model",
jit_type="trace",
input_examples=[torch_input])
# module (torch.nn.Module): PyTorch的Module。
# save_dir (str): 转换后模型的保存路径。
# jit_type (str): 转换方式。默认为"trace"。
# input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。
```
**注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静;
当jit_type为"script"时",input_examples不为None时,才可以进行动转静。
## 使用示例
``` python
import torch
import numpy as np
from torchvision.models import AlexNet
from torchvision.models.utils import load_state_dict_from_url
# 构建输入
input_data = np.random.rand(1, 3, 224, 224).astype("float32")
# 获取PyTorch Module
torch_module = AlexNet()
torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_model,
save_dir="pd_model_trace",
jit_type="trace",
input_examples=[torch.tensor(input_data)])
```
......@@ -4,6 +4,7 @@
```
python tools/check_for_lite.py paddle_model/inference_model/__model__
```
> 附:check_for_lite工具并不能完全判断模型是否被支持,PaddleLite详细支持的算子请参考[PaddleLite支持算子集](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/introduction/support_operation_list.md)
### 二、模型参数合并
X2Paddle转换后产出的路径下包括两个目录,
......
__version__ = "0.8.8"
__version__ = "1.0.0rc0"
from .core.program import PaddleGraph
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from six import text_type as _text_type
from x2paddle import program
import argparse
import sys
......@@ -63,12 +64,6 @@ def arg_parser():
action="store_true",
default=False,
help="get version of x2paddle")
parser.add_argument(
"--without_data_format_optimization",
"-wo",
type=_text_type,
default="True",
help="tf model conversion without data format optimization")
parser.add_argument(
"--define_input_shape",
"-d",
......@@ -82,11 +77,19 @@ def arg_parser():
default=False,
help="define whether merge the params")
parser.add_argument(
"--input_shapes",
"-is",
action='append',
default=None,
help="define the inputs' shape")
"--paddle_type",
"-pt",
type=_text_type,
default="dygraph",
help="define the paddle model type after converting(dygraph/static)"
)
parser.add_argument(
"--without_data_format_optimization",
"-wo",
type=_text_type,
default="True",
help="tf model conversion without data format optimization")
return parser
......@@ -94,6 +97,7 @@ def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
define_input_shape=False,
paddle_type="dygraph",
params_merge=False):
# check tensorflow installation and version
try:
......@@ -111,33 +115,46 @@ def tf2paddle(model_path,
"[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
)
return
from x2paddle import program
from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
from x2paddle.optimizer.tensorflow.prelu import PReLUOpt
if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.tf2paddle.tf_op_mapper import TFOpMapper
else:
from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model)
program.build()
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
prelu_opt = PReLUOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
prelu_opt.run(program)
transpose_opt.run(program)
program.gen_model(save_dir)
mapper.paddle_graph.build()
if paddle_type == "dygraph":
from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer(source_frame="tf", paddle_type=paddle_type)
graph_opt.optimize(mapper.paddle_graph)
else:
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
from x2paddle.optimizer.tensorflow.prelu import PReLUOpt
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
prelu_opt = PReLUOpt()
bias_opt.run(mapper.paddle_graph)
batch_norm_opt.run(mapper.paddle_graph)
prelu_opt.run(mapper.paddle_graph)
transpose_opt.run(mapper.paddle_graph)
mapper.paddle_graph.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
def caffe2paddle(proto, weight, save_dir, caffe_proto,
paddle_type, params_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.caffe2paddle.caffe_op_mapper import CaffeOpMapper
else:
from x2paddle.op_mapper.static.caffe2paddle.caffe_op_mapper import CaffeOpMapper
import google.protobuf as gpb
ver_part = gpb.__version__.split('.')
version_satisfy = False
......@@ -148,13 +165,16 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model)
optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale()
optimizer.merge_op_activation()
mapper.save_inference_model(save_dir, params_merge)
mapper.paddle_graph.build()
print("Model optimizing ...")
from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer(source_frame="caffe", paddle_type=paddle_type)
graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.")
mapper.paddle_graph.gen_model(save_dir)
def onnx2paddle(model_path, save_dir, params_merge=False):
def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False):
# check onnx installation and version
try:
import onnx
......@@ -167,22 +187,26 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
return
print("Now translating model from onnx to paddle.")
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.onnx2paddle.onnx_op_mapper import ONNXOpMapper
else:
from x2paddle.op_mapper.static.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path)
mapper = ONNXOpMapper(model)
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
print("Model optimized.")
print("Paddle model and code generating ...")
mapper.save_inference_model(save_dir, params_merge)
print("Paddle model and code generated.")
if paddle_type == "dygraph":
mapper.paddle_graph.build()
mapper.paddle_graph.gen_model(save_dir)
else:
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
print("Model optimized.")
mapper.save_inference_model(save_dir, params_merge)
def pytorch2paddle(model_path, save_dir, input_shapes):
def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None):
# check pytorch installation and version
try:
import torch
......@@ -198,27 +222,22 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
)
return
print("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
from x2paddle.op_mapper.dygraph.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper
from x2paddle.decoder.pytorch_decoder import PyTorchDecoder
from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper
model = PyTorchDecoder(model_path)
mapper = pytorch_op_mapper.PyTorchOpMapper(model)
mapper.graph.build()
if jit_type == "trace":
model = TraceDecoder(module, input_examples)
else:
model = ScriptDecoder(module, input_examples)
mapper = PyTorchOpMapper(model)
mapper.paddle_graph.build()
print("Model optimizing ...")
from x2paddle.optimizer.pytorch_optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer()
graph_opt.optimize(mapper.graph)
from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer(source_frame="pytorch", paddle_type="dygraph", jit_type=jit_type)
graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.")
if input_shapes is not None:
real_input_shapes = list()
for shape in input_shapes:
sp = shape[1:-1].split(",")
for i, s in enumerate(sp):
sp[i] = int(s)
real_input_shapes.append(sp)
else:
real_input_shapes = None
mapper.graph.gen_model(save_dir, real_input_shapes)
mapper.paddle_graph.gen_model(save_dir, jit_type=jit_type)
def main():
......@@ -239,6 +258,7 @@ def main():
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
assert args.paddle_type in ["dygraph", "static"], "--paddle_type must be 'dygraph' or 'static'"
try:
import paddle
......@@ -246,8 +266,8 @@ def main():
print("paddle.__version__ = {}".format(paddle.__version__))
if v0 == '0' and v1 == '0' and v2 == '0':
print("[WARNING] You are use develop version of paddlepaddle")
elif int(v0) != 1 or int(v1) < 6:
print("[ERROR] paddlepaddle>=1.6.0 is required")
elif int(v0) != 2 or int(v1) < 0:
print("[ERROR] paddlepaddle>=2.0.0 is required")
return
except:
print(
......@@ -267,7 +287,7 @@ def main():
if args.params_merge:
params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, params_merge)
define_input_shape, args.paddle_type, params_merge)
elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
......@@ -275,21 +295,20 @@ def main():
if args.params_merge:
params_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto, params_merge)
args.caffe_proto, args.paddle_type, params_merge)
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
params_merge = False
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
onnx2paddle(args.model, args.save_dir, args.paddle_type, params_merge)
elif args.framework == "paddle2onnx":
print("Paddle to ONNX tool has been migrated to the new github: https://github.com/PaddlePaddle/paddle2onnx")
else:
raise Exception(
"--framework only support tensorflow/caffe/onnx/ now")
"--framework only support tensorflow/caffe/onnx now")
if __name__ == "__main__":
......
......@@ -41,6 +41,8 @@ class Layer(object):
layer_code = layer_code
elif self.use_fluid:
layer_code = layer_code + "fluid." + self.op + "("
elif self.op == "full_like":
layer_code = layer_code + "paddle." + self.op + "("
else:
layer_code = layer_code + "fluid.layers." + self.op + "("
......
......@@ -128,6 +128,7 @@ class OpMapper(object):
self.add_codes("from paddle.fluid.initializer import Constant")
self.add_codes("from paddle.fluid.param_attr import ParamAttr")
self.add_codes("import paddle.fluid as fluid")
self.add_codes("import paddle")
self.add_codes("")
def save_inference_model(self, save_dir, params_merge):
......@@ -214,6 +215,7 @@ class OpMapper(object):
self.add_codes("", 0)
self.add_codes("\ndef x2paddle_net():", 0)
self.add_codes("paddle.enable_static()", 1)
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i]
node = self.graph.get_node(node_name)
......
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -18,3 +18,11 @@ import os
def string(param):
return "\'{}\'".format(param)
def name_generator(nn_name, nn_name2id):
if nn_name in nn_name2id:
nn_name2id[nn_name] += 1
else:
nn_name2id[nn_name] = 0
real_nn_name = nn_name + str(nn_name2id[nn_name])
return real_nn_name
\ No newline at end of file
......@@ -18,7 +18,7 @@ from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from x2paddle.op_mapper import caffe_shape
from x2paddle.decoder import caffe_shape_inference
class CaffeResolver(object):
......@@ -50,22 +50,51 @@ class CaffeGraphNode(GraphNode):
def __init__(self, layer, type_str, layer_name=None):
if layer_name is None:
super(CaffeGraphNode, self).__init__(
layer, layer.name.replace('/', '_').replace('-', '_'))
layer, layer.name.replace('/', '_').replace('-', '_').lower())
else:
super(CaffeGraphNode, self).__init__(
layer, layer_name.replace('/', '_').replace('-', '_'))
layer, layer_name.replace('/', '_').replace('-', '_').lower())
self.layer_type = type_str
self.fluid_code = FluidCode()
self.data = None
def set_params(self, params):
self.data = params
@property
def name(self):
if hasattr(self, 'index'):
return "{}_p{}".format(self.layer_name, self.index)
return self.layer_name
@property
def out_shapes(self):
return self._out_shapes
@out_shapes.setter
def out_shapes(self, value):
self._out_shapes = value
@property
def in_shapes(self):
return self._in_shapes
@in_shapes.setter
def in_shapes(self, value):
self._in_shapes = value
class CaffeGraph(Graph):
def __init__(self, model, params, caffe_pb):
self.params = params
self.caffe_pb = caffe_pb
if hasattr(model, "name"):
if model.name == "":
self.graph_name = "CaffeModel"
else:
self.graph_name = model.name
else:
self.graph_name = "CaffeModel"
super(CaffeGraph, self).__init__(model)
def filter_layers(self, layers):
......@@ -220,8 +249,11 @@ class CaffeGraph(Graph):
layer_name)
super(CaffeGraph, self).build()
for i, node_name in enumerate(self.topo_sort):
node = self.get_node(node_name)
self.set_node_shape(node)
def get_bottom_node(self, node, idx=0, copy=False):
def get_input_node(self, node, idx=0, copy=False):
input_node_name = node.inputs[idx]
assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format(
name)
......@@ -232,6 +264,19 @@ class CaffeGraph(Graph):
else:
name = input_node_name
return self.get_node(name, copy=copy)
def set_node_shape(self, node):
inputs = node.inputs
input_shape = []
for i, nm in enumerate(inputs):
last_node = self.get_node(nm)
tmp = node.layer.bottom[i]
idx = list(last_node.layer.top).index(tmp)
input_shape.append(last_node.out_shapes[idx])
node.in_shapes = input_shape
func_name = 'shape_' + node.layer_type.lower()
node.out_shapes = getattr(caffe_shape_inference, func_name)(node.layer,
input_shape)
class CaffeDecoder(object):
......@@ -244,7 +289,7 @@ class CaffeDecoder(object):
with open(proto_path, 'rb') as proto_file:
proto_str = proto_file.read()
text_format.Merge(proto_str, self.net)
self.load_using_pb()
self.caffe_graph = CaffeGraph(self.net, self.params,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
from functools import reduce
def get_kernel_parameters(params):
[k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2
elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1]
elif params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
elif len(params.pad) > 0:
p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
p_w = params.pad_w if params.pad_w > 0 else params.pad[len(params.pad) -
1]
elif params.pad_h > 0 or params.pad_w > 0:
p_h = params.pad_h
p_w = params.pad_w
dila_h = dila_w = 1
if hasattr(params, 'dilation'):
dila_len = len(params.dilation)
if dila_len == 2:
dila_h = params.dilation[0]
dila_w = params.dilation[1]
elif dila_len == 1:
dila_h = dila_w = params.dilation[0]
else:
assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
dila_len)
return dila_h, dila_w, p_h, p_w, k_h, k_w, s_h, s_w
def get_strided_kernel_output_shape(params, input_shape, round_func):
i_h = input_shape[2]
i_w = input_shape[3]
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params)
o_h = (i_h + 2 * pad_h - (dila_h *
(kernel_h - 1) + 1)) / float(stride_h) + 1
o_w = (i_w + 2 * pad_w - (dila_w *
(kernel_w - 1) + 1)) / float(stride_w) + 1
o_h = int(round_func(o_h))
o_w = int(round_func(o_w))
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape[1]
return [[input_shape[0], c, o_h, o_w]]
def shape_convolution(layer, input_shape):
params = layer.convolution_param
return get_strided_kernel_output_shape(params, input_shape[0], math.floor)
def shape_depthwiseconvolution(layer, input_shape):
return shape_convolution(layer, input_shape)
def shape_deconvolution(layer, input_shape):
h_i = input_shape[0][2]
w_i = input_shape[0][3]
params = layer.convolution_param
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params)
h_o = (h_i - 1) * stride_h - 2 * pad_h + dila_h * (kernel_h - 1) + 1
w_o = (w_i - 1) * stride_w - 2 * pad_w + dila_w * (kernel_w - 1) + 1
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape.channels
return [[input_shape[0][0], c, h_o, w_o]]
def shape_pooling(layer, input_shape):
params = layer.pooling_param
global_pool = getattr(params, 'global_pooling', False)
if global_pool:
return [[input_shape[0][0], input_shape[0][1], 1, 1]]
ceil_mode = getattr(params, 'ceil_mode', True)
if ceil_mode is True:
method = math.ceil
else:
method = math.floor
return get_strided_kernel_output_shape(params, input_shape[0], method)
def shape_convolutiondepthwise(layer, input_shape):
params = layer.convolution_param
return get_strided_kernel_output_shape(params, input_shape[0], math.floor)
def shape_innerproduct(layer, input_shape):
params = layer.inner_product_param
return [[input_shape[0][0], params.num_output]]
def shape_lrn(layer, input_shape):
return input_shape
def shape_relu(layer, input_shape):
return input_shape
def shape_softmax(layer, input_shape):
return input_shape
def shape_input(layer, input_shape):
return [list(layer.input_param.shape[0].dim)]
def shape_memorydata(layer, input_shape):
params = layer.memory_data_param
shape = []
shape.append(int(params.batch_size))
shape.append(int(params.channels))
shape.append(int(params.height))
shape.append(int(params.width))
return [shape]
def shape_concat(layer, input_shape):
params = layer.concat_param
axis = params.axis
output_shape = None
for shape in input_shape:
if output_shape is None:
output_shape = []
for i in range(len(shape)):
output_shape.append(shape[i])
else:
output_shape[axis] += shape[axis]
return [output_shape]
def shape_slice(layer, input_shape):
inshape = input_shape[0]
top_len = len(layer.top)
params = layer.slice_param
axis = params.axis
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point)
count = inshape[axis]
if len(points) == 0:
assert count % top_len == 0, "the parameter of Slice is wrong"
part = count / top_len
t = part
while t < count:
points.append(int(t))
t += part
points = [0] + points + [count]
output_shape = []
for i in range(len(points)):
shape = []
for ii in range(len(inshape)):
shape.append(inshape[ii])
size = points[i + 1] - points[i]
shape[axis] = size
output_shape.append(shape)
if i == len(points) - 2:
break
return output_shape
def shape_prelu(layer, input_shape):
return input_shape
def shape_sigmoid(layer, input_shape):
return input_shape
def shape_absval(layer, input_shape):
return input_shape
def shape_accuracy(layer, input_shape):
return [[1]]
def shape_tanh(layer, input_shape):
return input_shape
def shape_eltwise(layer, input_shape):
return [input_shape[0]]
def shape_batchnorm(layer, input_shape):
return input_shape
def shape_scale(layer, input_shape):
return input_shape
def shape_reshape(layer, input_shape):
def count(num_list):
return reduce(lambda a, b: a * b, num_list)
inshape = input_shape[0]
params = layer.reshape_param
axis = params.axis if hasattr(params, 'axis') else 0
num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
if inshape[0] == -1:
inshape[0] = 1
input_count = count(inshape)
input_num_axes = len(inshape)
input_start_axis = axis
start_axis = input_start_axis if input_start_axis >= 0 \
else input_num_axes + input_start_axis + 1
assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\
% (input_start_axis, input_num_axes)
assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"
end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\
% (end_axis, start_axis, num_axes)
num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(list(params.shape.dim))
output_shape = []
for i in range(start_axis):
output_shape.append(inshape[i])
for i in range(num_new_axes):
output_shape.append(params.shape.dim[i])
for i in range(end_axis, input_num_axes):
output_shape.append(inshape[i])
assert len(output_shape) == num_axes_retained + num_new_axes,\
"[Reshape]invalid dims of output shape[%s]" % (str(output_shape))
inferred_axis = -1
copy_axes = []
constant_count = 1
for i in range(num_new_axes):
top_dim = params.shape.dim[i]
if top_dim == 0:
copy_axes.append(i)
copy_axis_index = start_axis + i
output_shape[copy_axis_index] = inshape[copy_axis_index]
elif top_dim == -1:
assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
inferred_axis = i
else:
constant_count *= top_dim
if inferred_axis >= 0:
explicit_count = constant_count
l = inshape[0:start_axis]
if len(l) > 0:
explicit_count *= count(l)
l = inshape[end_axis:]
if len(l) > 0:
explicit_count *= count(l)
for i in range(len(copy_axes)):
explicit_count *= output_shape[start_axis + copy_axes[i]]
assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count)
output_shape[start_axis + inferred_axis] = int(input_count / explicit_count)
output_count = count(output_shape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
output_shape[0] = -1
return [output_shape]
def shape_argmax(layer, input_shape):
inshape = input_shape[0]
params = layer.argmax_param
out_max_val = params.out_max_val if hasattr(params, out_max_val) else False
top_k = params.top_k if hasattr(params, top_k) else 1
axis = parmas.axis if hasattr(params, axis) else -1
if axis < 0:
axis += len(inshape)
assert (axis + 1 == len(inshape)
), 'only can be applied on the last dimension[axis:%d, %s] now,'\
'make sure you have set axis param in xxx.prototxt file' \
% (axis, str(inshape))
output_shape = inshape
output_shape[-1] = top_k
if out_max_val is True:
output_shape[-1] *= 2
return [output_shape]
def shape_crop(layer, input_shape):
assert len(input_shape) == 2, "the number of crop's inputs must be 2"
return [input_shape[1]]
def shape_flatten(layer, input_shape):
assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
inshape = input_shape[0]
params = layer.flatten_param
start_axis = params.axis
end_axis = params.end_axis
if start_axis < 0:
start_axis += len(inshape)
if end_axis < 0:
end_axis += len(inshape) + 1
assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
% (start_axis, end_axis)
output_shape = inshape[0:start_axis]
if len(inshape[start_axis:end_axis]) != 0:
flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
output_shape += [flat_sz]
output_shape += inshape[end_axis:len(inshape)]
output_shape[0] = -1
return [output_shape]
def shape_power(layer, input_shape):
return input_shape
def shape_reduction(layer, input_shape):
params = layer.reduction_param
axis = params.axis
if axis < 0:
axis += len(input_shape[0]) + 1
assert axis <= len(input_shape[0]), 'invalid axis[%d] error' % (axis)
return [input_shape[0:axis]]
def shape_axpy(layer, input_shape):
assert len(input_shapes) == 3, "not valid input shape for axpy layer"
assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
output_shape = input_shapes[1]
assert (input_shapes[2] == output_shape),\
"shape not consistent for axpy[%s <--> %s]" \
% (str(output_shape), str(input_shapes[2]))
return [output_shape]
def shape_detectionoutput(layer, input_shape):
return [[-1, 6]]
def shape_normalize(layer, input_shape):
return input_shape
def shape_permute(layer, input_shape):
order = layer.permute_param.order
inshape = input_shape[0]
output_shape = []
order = list(order)
for ii in order:
assert ii < len(inshape), "invalid order for permute[%s]" % (name)
output_shape.append(inshape[ii])
return [output_shape]
def shape_priorbox(layer, input_shape):
max_size = layer.prior_box_param.max_size
aspect_ratio = layer.prior_box_param.aspect_ratio
fc_shape = input_shape[0]
N = 1
if not max_size == None:
N += 1
if not aspect_ratio == None:
N += 2 * len(aspect_ratio)
N_bbx = fc_shape[2] * fc_shape[3] * N
output_shape = [1, 2, 4 * N_bbx]
return [output_shape]
def shape_relu6(layer, input_shape):
return input_shape
def shape_roipooling(layer, input_shape):
pooled_w = layer.roi_pooling_param.pooled_w
pooled_h = layer.roi_pooling_param.pooled_h
base_fea_shape = input_shapes[0]
rois_shape = input_shapes[1]
output_shape = base_fea_shape
output_shape[0] = rois_shape[0]
output_shape[2] = pooled_h
output_shape[3] = pooled_w
return [output_shape]
def shape_shufflechannel(layer, input_shape):
return input_shape
def shape_upsample(layer, input_shape):
scale = layer.upsample_param.scale
assert len(input_shapes) == 1, "not valid input shape for upsample layer"
assert type(scale) is int
input_shape = input_shapes[0]
new_h = scale * input_shape[2]
new_w = scale * input_shape[3]
output_shape = [input_shape[0], input_shape[1], new_h, new_w]
return [output_shape]
def shape_select(layer, input_shape):
slice_point = layer.select_param.slice_point
axis = layer.select_param.axis
input_shape = input_shapes[0]
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
assert end > start, "invalid slice_point with [start:%d, end:%d]"\
% (start, end)
output_shape = input_shape
output_shape[axis] = end - start
return [output_shape]
......@@ -64,6 +64,12 @@ class ONNXGraphNode(GraphNode):
if 'value' not in self.attr_map:
return None
return self.attr_map['value']
@property
def name(self):
if hasattr(self, 'index'):
return "{}_p{}".format(self.layer_name, self.index)
return self.layer_name
def get_attribute_value(self, attr):
"""
......@@ -118,6 +124,10 @@ class ONNXGraphDataNode(GraphNode):
out_shapes = list()
out_shapes.append(values)
return out_shapes
@property
def name(self):
return self.layer_name
@property
def dtype(self):
......@@ -144,11 +154,12 @@ class ONNXGraph(Graph):
if self.graph is None:
print('[WARNING] Shape inference by ONNX offical interface.')
onnx_model = shape_inference.infer_shapes(onnx_model)
self.graph = onnx_model.graph
self.graph = onnx_model.graph
print("shape inferenced.")
self.build()
self.collect_value_infos()
self.allocate_shapes()
self.graph_name = "ONNXModel"
def get_inner_nodes(self):
"""
......@@ -307,6 +318,7 @@ class ONNXGraph(Graph):
if ipt_node.layer_name in node.which_child:
ipt_node.index = node.which_child[ipt_node.layer_name]
return ipt_node
def graph_weights(self):
"""
......@@ -538,4 +550,4 @@ class ONNXDecoder(object):
node.input[i] = self.make_variable_name(node.input[i])
for i in range(len(node.output)):
node.output[i] = self.make_variable_name(node.output[i])
return model
return model
\ No newline at end of file
......@@ -1601,11 +1601,11 @@ class SymbolicShapeInference:
in_mp)
symbolic_shape_inference._update_output_from_vi()
if not all_shapes_inferred:
print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
print('[INFO] Complete symbolic shape inference.')
except:
print('[WARNING] Incomplete symbolic shape inference.')
print('[WARNING] Incomplete symbolic shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph
return symbolic_shape_inference.out_mp_.graph
\ No newline at end of file
......@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
import numpy as np
class PyTorchDecoder(object):
def __init__(self, script_path):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
def _optimize_graph(self, graph):
class Decoder(object):
def _optimize_graph(self, graph):
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
......@@ -31,4 +30,40 @@ class PyTorchDecoder(object):
torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_constant_propagation(graph)
return graph
return graph
class ScriptDecoder(Decoder):
""" 当script_path非None,直接load ScriptModule;
当model_path非None,load PyTorchModule后使用script方式转换为ScriptModule。
Args:
script_path (str): ScriptModule保存路径。
model_path (str): PyTorchModule保存路径。
"""
def __init__(self, module, input_examples=None):
self.script = torch.jit.script(module)
self.graph = self._optimize_graph(self.script.inlined_graph)
self.input_examples = input_examples
class TraceDecoder(Decoder):
""" PyTorchModule后使用trace方式转换为ScriptModule。
Args:
model_path (str): PyTorchModule保存路径。
input_files (list): 输入网络的numpy,每个numpy保存成.npy文件,
文件路径存储在input_files中。
"""
def __init__(self, module, input_examples):
try:
self.script = torch.jit.trace(module, input_examples)
except RuntimeError as e:
if "strict" in str(e):
self.script = torch.jit.trace(module, input_examples, strict=False)
else:
print(e)
exit(0)
self.graph = self._optimize_graph(self.script.inlined_graph)
self.input_examples = input_examples
......@@ -132,6 +132,7 @@ class TFGraph(Graph):
self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
self.tf_data_format = data_format
self.graph_name = "TFModel"
def build(self):
for layer in self.model.node:
......@@ -188,6 +189,10 @@ class TFGraph(Graph):
if len(items) == 1 and node.layer_type in self.multi_out_ops:
node.index = 0
return node
def get_input_node(self, node, idx=0, copy=False):
input_node_name = node.layer.input[idx]
return self.get_node(input_node_name, copy)
def remove_node(self, node_name):
if node_name not in self.node_map:
......@@ -315,7 +320,7 @@ class TFDecoder(object):
self.sess = tf.compat.v1.Session()
except:
self.sess = tf.Session()
self.input_info = dict()
self.inputs_info = dict()
self.define_input_shape = define_input_shape
with open(pb_model, 'rb') as f:
try:
......@@ -397,7 +402,7 @@ class TFDecoder(object):
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = raw_input(
shape = input(
"Shape of Input(e.g. None,224,224,3): ")
except:
shape = input("Shape of Input(e.g. None,224,224,3): ")
......@@ -425,50 +430,40 @@ class TFDecoder(object):
input_map["{}:0".format(layer.name)] = x2paddle_input
if shape.count(None) > 0:
shape[shape.index(None)] = -1
self.input_info["x2paddle_{}".format(layer.name)] = (shape,
self.inputs_info["x2paddle_{}".format(layer.name)] = (shape,
dtype)
else:
value = graph_node.layer.attr["shape"].shape
shape = [dim.size for dim in value.dim]
self.input_info[layer.name] = (shape, dtype)
self.inputs_info[layer.name] = (shape, dtype)
return input_map
# trick method
# should be removed after PaddlePaddle V1.6 been released
def infer_tensor(self, graph_node):
def infer_tensor(self, graph_node, out_shape=None, use_diff_inputs=True):
if hasattr(graph_node, "index"):
tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
else:
tensor_name = graph_node.layer.name + ":0"
feed = dict()
for input_name, info in self.input_info.items():
(shape, dtype) = cp.deepcopy(info)
input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
if shape.count(-1) > 0:
shape[shape.index(-1)] = 2
feed[input_tensor] = numpy.random.random_sample(shape)
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
return self.sess.run([output_tensor], feed)[0]
def infer_shape_tensor(self, graph_node, out_shape=None):
if hasattr(graph_node, "index"):
tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
if use_diff_inputs:
batch_size = [2, 3, 5]
else:
tensor_name = graph_node.layer.name + ":0"
feed = dict()
batch_size = [2, 3, 5]
batch_size = [2]
results = list()
for b in batch_size:
for input_name, info in self.input_info.items():
for input_name, info in self.inputs_info.items():
(shape, dtype) = cp.deepcopy(info)
input_tensor = self.sess.graph.get_tensor_by_name(input_name +
":0")
input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
if shape.count(-1) > 0:
shape[shape.index(-1)] = b
feed[input_tensor] = numpy.random.random_sample(shape)
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
results.append(self.sess.run([output_tensor], feed)[0].flatten())
if use_diff_inputs:
results.append(self.sess.run([output_tensor], feed)[0].flatten())
else:
return self.sess.run([output_tensor], feed)[0]
compare01 = (results[0] == results[1])
compare12 = (results[1] == results[2])
......@@ -493,38 +488,4 @@ class TFDecoder(object):
return results[0].tolist()
else:
raise Exception("Couldn't infer a stable shape shape tensor value")
def infer_tensor_shape(self, graph_node):
if hasattr(graph_node, "index"):
tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
else:
tensor_name = graph_node.layer.name + ":0"
feed = dict()
batch_size = [2, 3, 5]
shapes = list()
for b in batch_size:
for input_name, info in self.input_info.items():
(shape, dtype) = cp.deepcopy(info)
input_tensor = self.sess.graph.get_tensor_by_name(input_name +
":0")
if shape.count(-1) > 0:
shape[shape.index(-1)] = b
feed[input_tensor] = numpy.random.random_sample(shape)
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
shape = self.sess.run([output_tensor], feed)[0].shape
shapes.append(numpy.array(shape))
compare01 = (shapes[0] == shapes[1])
compare12 = (shapes[1] == shapes[2])
if compare01.all() and compare12.all():
return shape[0].tolist()
if (compare01 == compare12).all():
index = numpy.argwhere(compare01 == False).flatten()
if index.shape[0] != 1:
raise Exception("There's not only one unstable dimension")
if index[0] != 0:
raise Exception("Batch size not in the first dimension")
shapes[0][0] = -1
return shapes[0].tolist()
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .detectionoutput import DetectionOutput
from .normalize import Normalize
from .priorbox import PriorBox
from .roipooling import ROIPooling
from .select import Select
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class DetectionOutput(object):
def __init__(self, nms_threshold, nms_top_k, keep_top_k, nms_eta, score_threshold, background_label):
self.detection_output_layer_attrs = {
"background_label": background_label,
"nms_threshold": nms_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"score_threshold": score_threshold,
"nms_eta": nms_eta}
def __call__(self, x0, x1, x2):
priorbox_list = paddle.split(x2, num_or_sections=2, axis=1)
pb = priorbox_list[0]
pbv = priorbox_list[1]
pb = paddle.reshape(x=pb, shape=[-1, 4])
pbv = paddle.reshape(x=pbv, shape=[-1, 4])
pb_dim = fluid.layers.shape(pb)[0]
loc = paddle.reshape(x0, shape=[-1, pb_dim, 4])
conf_flatten = paddle.reshape(x1, shape=[0, pb_dim, -1])
out = fluid.layers.detection_output(loc=loc,
scores=conf_flatten,
prior_box=pb,
prior_box_var=pbv,
**self.detection_output_layer_attrs)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Normalize(object):
def __init__(self, axis, param_name, param_shape):
self.axis = axis
self.param_name = param_name
self.param_shape = param_shape
def __call__(self, x):
l2 = fluid.layers.prior_box(x=x, p=2, axis=1)
attr = fluid.ParamAttr(name=self.param_name, trainable=False)
param = paddle.nn.Layer.create_parameter(shape=self.param_shape,
attr=atr)
out = paddle.multiply(x=l2, y=param, axis=self.axis)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class PriorBox(object):
def __init__(self, min_sizes, max_sizes,
aspect_ratios, variance, flip,
clip, steps, offset,
min_max_aspect_ratios_order):
self.priorbox_layer_attrs = {
"min_sizes": min_sizes,
"max_sizes": max_sizes,
"aspect_ratios": aspect_ratios,
"variance": variance,
"flip": flip,
"clip": clip,
"steps": steps,
"offset": offset,
"min_max_aspect_ratios_order": min_max_aspect_ratios_order}
def __call__(self, x0, x1):
box, var = fluid.layers.prior_box(input=x0,
image=x1,
**self.priorbox_layer_attrs)
box = paddle.reshape(x=box, shape=[1, 1, -1])
var = paddle.reshape(x=var, shape=[1, 1, -1])
out = paddle.concat(x=[box, var], axis=1)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class ROIPooling(object):
def __init__(self, pooled_height, pooled_width, spatial_scale):
self.roipooling_layer_attrs = {
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale}
def __call__(self, x0, x1):
slice_x1 = paddle.slice(input=x1, axes=[1],
starts=[1], ends=[5])
out = fluid.layers.roi_pool(input=x0,
rois=slice_x1,
**self.roipooling_layer_attrs)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Select(object):
def __init__(self, input_shape, point, axis):
self.point = point
self.input_shape = input_shape
self.axis = axis
def __call__(self, x):
start = self.point[0]
if len(self.point) == 2:
end = self.point[1]
else:
end = self.input_shape[self.axis]
out = paddle.slice(x=x,
start=start,
end=end,
axes=[self.axis])
return out
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from x2paddle.op_mapper.dygraph.onnx2paddle.opset9 import OpSet9
from x2paddle.core.op_mapper import OpMapper
from x2paddle.decoder.onnx_decoder import ONNXGraphNode
from x2paddle.core.program import PaddleGraph
class ONNXOpMapper(OpMapper):
def __init__(self, decoder):
super(ONNXOpMapper, self).__init__()
self.support_op_sets = [9, ]
self.default_op_set = 9
self.graph = decoder.graph
self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="dygraph", source_type="onnx")
self.paddle_graph.outputs = self.graph.output_nodes
self.opset = self.create_opset(decoder)
if not self.op_checker():
raise Exception("Model is not supported yet.")
print("Total nodes: {}".format(
sum([
isinstance(node, ONNXGraphNode)
for name, node in self.graph.node_map.items()
])))
print("Nodes converting ...")
for i, node_name in enumerate(self.graph.topo_sort):
sys.stderr.write("\rConverting node {} ... ".format(i + 1))
node = self.graph.get_node(node_name)
op = node.layer_type
if hasattr(self.opset, op):
func = getattr(self.opset, op)
func(node)
elif op in self.opset.directly_map_ops:
self.opset.directly_map(node)
elif op in self.opset.elementwise_ops:
self.opset.elementwise_map(node)
print("\nNodes converted.")
self.paddle_graph.set_name(self.graph.graph_name)
self.paddle_graph.set_parameters(self.opset.weights)
self.paddle_graph.set_inputs_info(self.opset.inputs_info)
def op_checker(self):
unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if not hasattr(self.opset, op) and \
op not in self.opset.directly_map_ops and \
op not in self.opset.elementwise_ops:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
if len(unsupported_ops) > 0:
print("\n========= {} OPs are not supported yet ===========".format(
len(unsupported_ops)))
for op in unsupported_ops:
print("========== {} ============".format(op))
return False
def create_opset(self, decoder):
run_op_set = self.default_op_set
opset = ''
if decoder.op_set in self.support_op_sets:
opset = 'OpSet' + str(decoder.op_set)
elif decoder.op_set < self.default_op_set:
opset = 'OpSet' + str(self.default_op_set)
else:
for op_set in self.support_op_sets:
if decoder.op_set > op_set:
run_op_set = op_set
else:
break
opset = 'OpSet' + str(run_op_set)
print(
'Now, onnx2paddle support convert onnx model opset_verison {},'
'opset_verison of your onnx model is {}, automatically treated as op_set: {}.'
.format(self.support_op_sets, decoder.op_set, run_op_set))
return eval(opset)(decoder, self.paddle_graph)
此差异已折叠。
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .gather import Gather
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from itertools import product
import numpy as np
class Gather(object):
def __init__(self, dim):
self.dim = dim
self.dtype_mapping = {"VarType.INT32": "int32",
"VarType.INT64": "int64"}
def __call__(self, x, index):
if self.dim < 0:
self.dim += len(x.shape)
x_range = list(range(len(x.shape)))
x_range[0] = self.dim
x_range[self.dim] = 0
x_swaped = paddle.transpose(x, perm=x_range)
index_range = list(range(len(index.shape)))
index_range[0] = self.dim
index_range[self.dim] = 0
index_swaped = paddle.transpose(index, perm=index_range)
dtype = self.dtype_mapping[str(index.dtype)]
x_shape = paddle.shape(x_swaped)
index_shape = paddle.shape(index_swaped)
prod = paddle.prod(x_shape, dtype=dtype) / x_shape[0]
x_swaped_flattend = paddle.flatten(x_swaped)
index_swaped_flattend = paddle.flatten(index_swaped)
index_swaped_flattend *= prod
bias = paddle.arange(start=0, end=prod, dtype=dtype)
bias = paddle.reshape(bias, x_shape[1:])
bias = paddle.crop(bias, index_shape[1:])
bias = paddle.flatten(bias)
bias = paddle.tile(bias, [index_shape[0]])
index_swaped_flattend += bias
gathered = paddle.index_select(x_swaped_flattend, index_swaped_flattend)
gathered = paddle.reshape(gathered, index_swaped.shape)
out = paddle.transpose(gathered, perm=x_range)
return out
此差异已折叠。
......@@ -14,18 +14,6 @@ def detectionoutput_layer(inputs,
confidence_threshold=0.1,
input_shape=None,
name=None):
nms_param_str = nms_param
nms_param = {}
part = nms_param_str.split(',')
for s in part:
if s == '':
break
else:
name, obj = s.split(': ')
if name == 'top_k':
nms_param[name] = int(obj)
else:
nms_param[name] = float(obj)
if nms_param is None:
nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
mbox_conf_flatten = inputs[1]
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.op_mapper.onnx2paddle.opset9 import OpSet9, custom_layers
from x2paddle.op_mapper.static.onnx2paddle.opset9 import OpSet9, custom_layers
from x2paddle.core.op_mapper import OpMapper
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
......
......@@ -17,7 +17,7 @@ from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string
from x2paddle.op_mapper.onnx2paddle.opset9.custom_layer import *
from x2paddle.op_mapper.static.onnx2paddle.opset9.custom_layer import *
from functools import reduce
import numpy as np
import onnx
......
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册