未验证 提交 f326e2c3 编写于 作者: J Jason 提交者: GitHub

Merge pull request #421 from SunAhong1993/paddle-2.0

add caffe2paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from six import text_type as _text_type from six import text_type as _text_type
from x2paddle import program
import argparse import argparse
import sys import sys
...@@ -88,11 +89,13 @@ def arg_parser(): ...@@ -88,11 +89,13 @@ def arg_parser():
default=False, default=False,
help="define whether merge the params") help="define whether merge the params")
parser.add_argument( parser.add_argument(
"--input_shapes", "--paddle_type",
"-is", "-pt",
action='append', type=_text_type,
default=None, default="dygraph",
help="define the inputs' shape") help="define the paddle model type after converting(dygraph/static)"
)
return parser return parser
...@@ -117,7 +120,7 @@ def tf2paddle(model_path, ...@@ -117,7 +120,7 @@ def tf2paddle(model_path,
"[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"." "[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
) )
return return
from x2paddle import program
from x2paddle.decoder.tf_decoder import TFDecoder from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt from x2paddle.optimizer.tensorflow.bias import BiasOpt
...@@ -126,6 +129,7 @@ def tf2paddle(model_path, ...@@ -126,6 +129,7 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape) model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model) mapper = TFOpMapper(model)
program.build() program.build()
bias_opt = BiasOpt() bias_opt = BiasOpt()
...@@ -137,10 +141,13 @@ def tf2paddle(model_path, ...@@ -137,10 +141,13 @@ def tf2paddle(model_path,
program.gen_model(save_dir) program.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False): def caffe2paddle(proto, weight, save_dir, caffe_proto,
paddle_type, params_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper if paddle_type == "dygraph":
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer from x2paddle.op_mapper.dygraph.caffe2paddle.caffe_op_mapper import CaffeOpMapper
else:
from x2paddle.op_mapper.static.caffe2paddle.caffe_op_mapper import CaffeOpMapper
import google.protobuf as gpb import google.protobuf as gpb
ver_part = gpb.__version__.split('.') ver_part = gpb.__version__.split('.')
version_satisfy = False version_satisfy = False
...@@ -151,10 +158,13 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False): ...@@ -151,10 +158,13 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
print("Now translating model from caffe to paddle.") print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto) model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model) mapper = CaffeOpMapper(model)
optimizer = CaffeOptimizer(mapper) mapper.paddle_graph.build()
optimizer.merge_bn_scale() print("Model optimizing ...")
optimizer.merge_op_activation() from x2paddle.optimizer.optimizer import GraphOptimizer
mapper.save_inference_model(save_dir, params_merge) graph_opt = GraphOptimizer(source_frame="caffe", paddle_type=paddle_type)
graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.")
mapper.paddle_graph.gen_model(save_dir)
def onnx2paddle(model_path, save_dir, params_merge=False): def onnx2paddle(model_path, save_dir, params_merge=False):
...@@ -185,7 +195,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -185,7 +195,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
print("Paddle model and code generated.") print("Paddle model and code generated.")
def pytorch2paddle(model_path, save_dir, input_shapes): def pytorch2paddle(model_path, save_dir, jit_type, input_files):
# check pytorch installation and version # check pytorch installation and version
try: try:
import torch import torch
...@@ -202,9 +212,12 @@ def pytorch2paddle(model_path, save_dir, input_shapes): ...@@ -202,9 +212,12 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
return return
print("Now translating model from pytorch to paddle.") print("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import PyTorchDecoder from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper
model = PyTorchDecoder(model_path) if jit_type == "trace":
model = TraceDecoder(model_path, input_files)
else:
model = ScriptDecoder(model_path)
mapper = pytorch_op_mapper.PyTorchOpMapper(model) mapper = pytorch_op_mapper.PyTorchOpMapper(model)
mapper.graph.build() mapper.graph.build()
print("Model optimizing ...") print("Model optimizing ...")
...@@ -212,16 +225,7 @@ def pytorch2paddle(model_path, save_dir, input_shapes): ...@@ -212,16 +225,7 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
graph_opt = GraphOptimizer() graph_opt = GraphOptimizer()
graph_opt.optimize(mapper.graph) graph_opt.optimize(mapper.graph)
print("Model optimized.") print("Model optimized.")
if input_shapes is not None: mapper.graph.gen_model(save_dir, jit_type, input_files)
real_input_shapes = list()
for shape in input_shapes:
sp = shape[1:-1].split(",")
for i, s in enumerate(sp):
sp[i] = int(s)
real_input_shapes.append(sp)
else:
real_input_shapes = None
mapper.graph.gen_model(save_dir, real_input_shapes)
def paddle2onnx(model_path, save_dir, opset_version=10): def paddle2onnx(model_path, save_dir, opset_version=10):
...@@ -260,6 +264,7 @@ def main(): ...@@ -260,6 +264,7 @@ def main():
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)" assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined" assert args.save_dir is not None, "--save_dir is not defined"
assert args.paddle_type in ["dygraph", "static"], "--paddle_type must be 'dygraph' or 'static'"
try: try:
import paddle import paddle
...@@ -267,8 +272,8 @@ def main(): ...@@ -267,8 +272,8 @@ def main():
print("paddle.__version__ = {}".format(paddle.__version__)) print("paddle.__version__ = {}".format(paddle.__version__))
if v0 == '0' and v1 == '0' and v2 == '0': if v0 == '0' and v1 == '0' and v2 == '0':
print("[WARNING] You are use develop version of paddlepaddle") print("[WARNING] You are use develop version of paddlepaddle")
elif int(v0) != 1 or int(v1) < 6: elif int(v0) != 2 or int(v1) < 0:
print("[ERROR] paddlepaddle>=1.6.0 is required") print("[ERROR] paddlepaddle>=2.0.0 is required")
return return
except: except:
print( print(
...@@ -296,7 +301,7 @@ def main(): ...@@ -296,7 +301,7 @@ def main():
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir, caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto, params_merge) args.caffe_proto, args.paddle_type, params_merge)
elif args.framework == "onnx": elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model" assert args.model is not None, "--model should be defined while translating onnx model"
params_merge = False params_merge = False
...@@ -304,7 +309,10 @@ def main(): ...@@ -304,7 +309,10 @@ def main():
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge) onnx2paddle(args.model, args.save_dir, params_merge)
elif args.framework == "pytorch":
assert args.model is not None, "--model should be defined while translating pytorch model"
pytorch2paddle(args.model, args.save_dir, args.jit_type, args.input_files)
elif args.framework == "paddle2onnx": elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx" assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset) paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
......
# -*- coding:UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
...@@ -21,15 +20,15 @@ import paddle ...@@ -21,15 +20,15 @@ import paddle
from paddle.fluid.proto import framework_pb2 from paddle.fluid.proto import framework_pb2
from collections import OrderedDict from collections import OrderedDict
import numpy import numpy
import collections
import sys import sys
import os import os
import six import six
import pickle import pickle
import numpy as np
class PaddleLayer(object): class PaddleLayer(object):
def __init__(self, id, kernel, inputs, outputs, **kwargs): def __init__(self, id, kernel, inputs, outputs, scope_name="", **kwargs):
assert isinstance( assert isinstance(
inputs, inputs,
dict), "parameter 'inputs' for PaddleLayer should be type of dict" dict), "parameter 'inputs' for PaddleLayer should be type of dict"
...@@ -53,16 +52,18 @@ class PaddleLayer(object): ...@@ -53,16 +52,18 @@ class PaddleLayer(object):
self.kernel = kernel self.kernel = kernel
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.scope_name = scope_name
self.attrs = kwargs self.attrs = kwargs
self.id = id self.id = id
self.blocks = list() self.blocks = list()
def add_block(self, block): def add_block(self, block):
self.blocks.append(block) self.blocks.append(block)
class PaddleGraph(object): class PaddleGraph(object):
def __init__(self, parent_layer=None, graph_type="static"): def __init__(self, source_type=None, parent_layer=None, graph_type="static"):
self.layers = OrderedDict() self.layers = OrderedDict()
self.edges_out = dict() self.edges_out = dict()
self.edges_in = dict() self.edges_in = dict()
...@@ -71,12 +72,24 @@ class PaddleGraph(object): ...@@ -71,12 +72,24 @@ class PaddleGraph(object):
self.parameters = dict() self.parameters = dict()
self.parent_layer = parent_layer self.parent_layer = parent_layer
self.graph_type = graph_type self.graph_type = graph_type
self.source_type = source_type
self.custom_code = None
self.inputs_info = None
def set_name(self, name): def set_name(self, name):
self.name = name self.name = name.replace("-", "_").replace("/", "_")
def set_parameters(self, parameters): def set_parameters(self, parameters):
self.parameters = parameters self.parameters = parameters
def set_custom(self, custom_code):
self.custom_code = custom_code
def set_inputs_info(self, inputs_info):
self.inputs_info = inputs_info
def set_script(self, script):
self.script = script
def clear(self): def clear(self):
self.layers = OrderedDict() self.layers = OrderedDict()
...@@ -90,13 +103,13 @@ class PaddleGraph(object): ...@@ -90,13 +103,13 @@ class PaddleGraph(object):
self.edges_out = dict() self.edges_out = dict()
self.edges_in = dict() self.edges_in = dict()
def add_layer(self, kernel, inputs, outputs, **kwargs): def add_layer(self, kernel, inputs, outputs, scope_name="", **kwargs):
layer_id = str(len(self.layers)) layer_id = str(len(self.layers))
if self.parent_layer is not None: if self.parent_layer is not None:
layer_id = "{}.{}.{}".format(self.parent_layer.id, layer_id = "{}.{}.{}".format(self.parent_layer.id,
len(self.parent_layer.blocks), len(self.parent_layer.blocks),
layer_id) layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, **kwargs) layer = PaddleLayer(layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs)
self.layers[layer_id] = layer self.layers[layer_id] = layer
return layer_id return layer_id
...@@ -215,10 +228,67 @@ class PaddleGraph(object): ...@@ -215,10 +228,67 @@ class PaddleGraph(object):
block_global_layers = update(block.layers) block_global_layers = update(block.layers)
global_layers.update(block_global_layers) global_layers.update(block_global_layers)
return global_layers return global_layers
return update(self.layers) return update(self.layers)
def gen_code(self, code_dir): def gen_model(self, save_dir, jit_type=None):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if self.graph_type == "static":
self.gen_static_model(save_dir)
else:
self.gen_dygraph_model(save_dir, jit_type)
def gen_static_model(self, save_dir):
code_dir = os.path.join(save_dir, 'model_with_code')
infer_dir = os.path.join(save_dir, 'inference_model')
self.gen_static_code(code_dir)
sys.path.append(code_dir)
import x2paddle_model
paddle.enable_static()
scope = paddle.static.Scope()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_program, startup_program):
inputs, outputs = x2paddle_model.x2paddle_net()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
param_dir = os.path.join(code_dir, 'weights')
for k, v in self.parameters.items():
if scope.find_var(k):
self.dump_parameter(k, v, param_dir)
def if_exist(var):
b = os.path.exists(
os.path.join(os.path.join(param_dir, var.name)))
return b
fluid.io.load_vars(
exe, param_dir, main_program, predicate=if_exist)
fluid.io.save_inference_model(
dirname=infer_dir,
feeded_var_names=[i.name for i in inputs],
target_vars=outputs,
executor=exe)
def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace":
from x2paddle.optimizer.code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items():
hierarchical_tree.insert(layer)
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
input_shapes = list()
input_types = list()
for input_name in self.inputs:
input_shapes.append(self.inputs_info[input_name][0])
input_types.append(self.inputs_info[input_name][1])
# 如果input_files非空,则导出推理模型;其值类似[[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes, input_types)
def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0): def write_code(f, code_list, indent=0):
indent_blank = " " * indent indent_blank = " " * indent
for code_line in code_list: for code_line in code_list:
...@@ -235,10 +305,24 @@ class PaddleGraph(object): ...@@ -235,10 +305,24 @@ class PaddleGraph(object):
f, [ f, [
"from paddle.fluid.initializer import Constant", "from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr", "from paddle.fluid.param_attr import ParamAttr",
"import paddle.fluid as fluid", "import math", "", "import paddle.fluid as fluid",
"def x2paddle_net():" "import paddle", "import math", "",
], ],
indent=0) indent=0)
if self.custom_code is not None:
write_code(
f,
list(self.custom_code.values()),
indent=0)
write_code(f,
["", "def x2paddle_net():"],
indent=0)
write_code(
f, [
"paddle.enable_static()"
],
indent=1)
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
edges_in = self.edges_in.get(layer_id, []) edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, []) edges_out = self.edges_out.get(layer_id, [])
...@@ -253,8 +337,10 @@ class PaddleGraph(object): ...@@ -253,8 +337,10 @@ class PaddleGraph(object):
for output in layer.outputs: for output in layer.outputs:
line += "{}, ".format(output) line += "{}, ".format(output)
line = line.strip(", ") line = line.strip(", ")
if layer.kernel.startswith("custom_layer"):
line += " = {}(".format(layer.kernel) line += " = {}(".format(layer.kernel.split(":")[-1].lower() + "_layer")
else:
line += " = {}(".format(layer.kernel)
for k, v in layer.inputs.items(): for k, v in layer.inputs.items():
if isinstance(v, list): if isinstance(v, list):
line += "{}=[{}], ".format(k, ", ".join(v)) line += "{}=[{}], ".format(k, ", ".join(v))
...@@ -274,47 +360,6 @@ class PaddleGraph(object): ...@@ -274,47 +360,6 @@ class PaddleGraph(object):
indent=1) indent=1)
f.close() f.close()
def gen_model(self, save_dir, input_shapes=None):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if self.graph_type == "static":
code_dir = os.path.join(save_dir, 'model_with_code')
infer_dir = os.path.join(save_dir, 'inference_model')
self.gen_code(code_dir)
sys.path.append(code_dir)
import x2paddle_model
scope = fluid.Scope()
startup_program = fluid.Program()
main_program = fluid.Program()
with fluid.scope_guard(scope):
with fluid.program_guard(main_program, startup_program):
inputs, outputs = x2paddle_model.x2paddle_net()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
param_dir = os.path.join(code_dir, 'weights')
for k, v in self.parameters.items():
if scope.find_var(k):
self.dump_parameter(k, v, param_dir)
def if_exist(var):
b = os.path.exists(
os.path.join(os.path.join(param_dir, var.name)))
return b
fluid.io.load_vars(
exe, param_dir, main_program, predicate=if_exist)
fluid.io.save_inference_model(
dirname=infer_dir,
feeded_var_names=[i.name for i in inputs],
target_vars=outputs,
executor=exe)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
if input_shapes is not None:
# 如果input_shapes非空,则导出推理模型;其值类似[[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes)
def dump_parameter(self, param_name, param, save_dir): def dump_parameter(self, param_name, param, save_dir):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
...@@ -356,10 +401,10 @@ class PaddleGraph(object): ...@@ -356,10 +401,10 @@ class PaddleGraph(object):
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0: layer_id, 0) == 0:
continue continue
if layer.kernel == "fluid.dygraph.base.to_variable": if layer.kernel == "paddle.to_tensor":
value = layer.attrs["value"] data = layer.attrs["data"]
if not value.startswith("params["): if not data.startswith("params["):
self.inputs.append(value) self.inputs.append(data)
if len(layer.blocks) > 0: if len(layer.blocks) > 0:
for block in layer.blocks: for block in layer.blocks:
block.get_dygraph_inputs() block.get_dygraph_inputs()
...@@ -376,11 +421,15 @@ class PaddleGraph(object): ...@@ -376,11 +421,15 @@ class PaddleGraph(object):
layer_id, 0) == 0: layer_id, 0) == 0:
continue continue
if self.edges_out.get(layer_id, 0) == 0: if self.edges_out.get(layer_id, 0) == 0:
for output_name in layer.outputs:
if not output_name.startswith("x"): for i, output_name in enumerate(layer.outputs):
continue if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
self.outputs.append(output_name) (layer.kernel == "paddle.to_tensor" and layer.attrs["data"].startswith("params["))or \
self.outputs = list(set(self.outputs)) "paddle.fluid.dygraph" in layer.kernel:
if i == 0:
continue
if output_name not in self.outputs:
self.outputs.append(output_name)
def gen_dygraph_code(self, code_dir=None, indent=2): def gen_dygraph_code(self, code_dir=None, indent=2):
def gen_codes(code_list, indent=0): def gen_codes(code_list, indent=0):
...@@ -394,20 +443,24 @@ class PaddleGraph(object): ...@@ -394,20 +443,24 @@ class PaddleGraph(object):
return codes return codes
def gen_head(): def gen_head():
if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn"
self.head = gen_codes( self.head = gen_codes(
[ [
"from paddle.fluid.initializer import Constant", "from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr", "from paddle.fluid.param_attr import ParamAttr",
"import paddle", "import paddle",
"import paddle.fluid as fluid", "import paddle.fluid as fluid",
custom_import,
"", "",
"class {}(fluid.dygraph.Layer):".format(self.name), "class {}(paddle.nn.Layer):".format(self.name),
], ],
indent=0) indent=0)
input_data_name = ', '.join(self.inputs) input_data_name = ', '.join(self.inputs)
self.init_func.extend( self.init_func.extend(
gen_codes( gen_codes(
["def __init__(self, params):"], indent=1)) ["def __init__(self):"], indent=1))
self.init_func.extend( self.init_func.extend(
gen_codes( gen_codes(
["super({}, self).__init__()".format(self.name)], indent=2)) ["super({}, self).__init__()".format(self.name)], indent=2))
...@@ -415,6 +468,31 @@ class PaddleGraph(object): ...@@ -415,6 +468,31 @@ class PaddleGraph(object):
gen_codes( gen_codes(
["def forward(self, {}):".format(input_data_name)], ["def forward(self, {}):".format(input_data_name)],
indent=1)) indent=1))
def gen_main_code(code_dir):
input_data_name = ', '.join(self.inputs)
self.run_func = gen_codes(
[
"",
"def main({}):".format(input_data_name),
],
indent=0)
comment_list = list()
comment_list.append("# 共{}个输入".format(len(self.inputs_info)))
for k, v in self.inputs_info.items():
comment_list.append("# {}: 形状为{},类型为{}。".format(k, v[0], v[1]))
self.run_func.extend(
gen_codes(
comment_list,
indent=1))
self.run_func.extend(
gen_codes(["paddle.disable_static()",
"params, _ = fluid.load_dygraph('{}/model')".format(code_dir),
"model = {}()".format(self.name),
"model.set_dict(params)",
"model.eval()",
"out = model({})".format(input_data_name),
"return out"], indent=1))
def write_code(code_dir): def write_code(code_dir):
f = open(os.path.join(code_dir, 'x2paddle_code.py'), 'w') f = open(os.path.join(code_dir, 'x2paddle_code.py'), 'w')
...@@ -431,6 +509,8 @@ class PaddleGraph(object): ...@@ -431,6 +509,8 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([return_code], indent=2)) self.forward_func.extend(gen_codes([return_code], indent=2))
for code_line in self.forward_func: for code_line in self.forward_func:
f.write(code_line) f.write(code_line)
for code_line in self.run_func:
f.write(code_line)
f.close() f.close()
self.init_func = [] self.init_func = []
...@@ -440,21 +520,25 @@ class PaddleGraph(object): ...@@ -440,21 +520,25 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
) or layer.kernel == "fluid.dygraph.base.to_variable" or \ ) or layer.kernel == "paddle.to_tensor" or \
"paddle.fluid.dygraph" in layer.kernel: "paddle.fluid.dygraph" in layer.kernel or \
layer.kernel.startswith("custom_layer"):
line = "{}".format( line = "{}".format(
layer.outputs[0] layer.outputs[0]
) if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[ ) if layer.kernel == "paddle.to_tensor" and not layer.attrs[
"value"].startswith("params[") else "self.{}".format( "data"].startswith("params[") else "self.{}".format(
layer.outputs[0]) layer.outputs[0])
line += " = {}(".format(layer.kernel) if layer.kernel.startswith("custom_layer"):
line += "= x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
else:
line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[ if layer.kernel == "paddle.to_tensor" and not layer.attrs[
"value"].startswith("params["): "data"].startswith("params["):
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
continue continue
else: else:
...@@ -466,8 +550,8 @@ class PaddleGraph(object): ...@@ -466,8 +550,8 @@ class PaddleGraph(object):
line = layer.outputs[1] line = layer.outputs[1]
else: else:
line = ','.join(layer.outputs[1:]) line = ','.join(layer.outputs[1:])
if layer.kernel == "fluid.dygraph.base.to_variable" and layer.attrs[ if layer.kernel == "paddle.to_tensor" and layer.attrs[
"value"].startswith("params["): "data"].startswith("params["):
line += " = self.{}".format(layer.outputs[0]) line += " = self.{}".format(layer.outputs[0])
else: else:
line += " = self.{}(".format(layer.outputs[0]) line += " = self.{}(".format(layer.outputs[0])
...@@ -475,10 +559,10 @@ class PaddleGraph(object): ...@@ -475,10 +559,10 @@ class PaddleGraph(object):
line += "{}, ".format(v) line += "{}, ".format(v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
func_name = layer.kernel.replace(".", "_") func_name = layer.kernel.replace(".", "_")
from x2paddle.op_mapper.pytorch2paddle import prim2code from x2paddle.op_mapper.dygraph import prim2code
if hasattr(prim2code, func_name): if hasattr(prim2code, func_name):
func = getattr(prim2code, func_name) func = getattr(prim2code, func_name)
func( func(
...@@ -502,8 +586,14 @@ class PaddleGraph(object): ...@@ -502,8 +586,14 @@ class PaddleGraph(object):
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
self.forward_func.extend(gen_codes([line], indent=indent)) if layer.kernel == "self.create_parameter":
self.init_func.extend(gen_codes(["self." + line], indent=indent))
self.forward_func.extend(gen_codes(["{} = self.{}".format(layer.outputs[0],
layer.outputs[0])], indent=indent))
else:
self.forward_func.extend(gen_codes([line], indent=indent))
if indent == 2: if indent == 2:
gen_main_code(code_dir)
write_code(code_dir) write_code(code_dir)
else: else:
return self.init_func, self.forward_func return self.init_func, self.forward_func
...@@ -513,23 +603,22 @@ class PaddleGraph(object): ...@@ -513,23 +603,22 @@ class PaddleGraph(object):
pickle.dump(self.parameters, params_output) pickle.dump(self.parameters, params_output)
params_output.close() params_output.close()
def dygraph2static(self, save_dir, input_shapes=[]): def dygraph2static(self, save_dir, input_shapes=[], input_types=[]):
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
sepc_list = list() sepc_list = list()
for i, name in enumerate(self.inputs): for i, name in enumerate(self.inputs):
input_shapes[i][0] = -1
sepc_list.append( sepc_list.append(
paddle.static.InputSpec( paddle.static.InputSpec(
shape=input_shapes[i], name=name)) shape=input_shapes[i], name=name, dtype=input_types[i]))
import sys import sys
path = osp.abspath(save_dir) path = osp.abspath(save_dir)
sys.path.insert(0, save_dir) sys.path.insert(0, save_dir)
import x2paddle_code import x2paddle_code
place = fluid.CPUPlace() paddle.disable_static()
with fluid.dygraph.guard(place): restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) model = getattr(x2paddle_code, self.name)()
model = getattr(x2paddle_code, self.name)(restore) model.set_dict(restore)
model.set_dict(restore) model.eval()
model.eval() static_model = paddle.jit.to_static(model, input_spec=sepc_list)
model.forward = declarative(model.forward, sepc_list) paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model"))
fluid.dygraph.jit.save( \ No newline at end of file
layer=model, model_path=osp.join(save_dir, "inference"))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,3 +18,11 @@ import os ...@@ -18,3 +18,11 @@ import os
def string(param): def string(param):
return "\'{}\'".format(param) return "\'{}\'".format(param)
def name_generator(nn_name, nn_name2id):
if nn_name in nn_name2id:
nn_name2id[nn_name] += 1
else:
nn_name2id[nn_name] = 0
real_nn_name = nn_name + str(nn_name2id[nn_name])
return real_nn_name
\ No newline at end of file
...@@ -18,7 +18,6 @@ from google.protobuf import text_format ...@@ -18,7 +18,6 @@ from google.protobuf import text_format
import numpy as np import numpy as np
from x2paddle.core.graph import GraphNode, Graph from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode from x2paddle.core.fluid_code import FluidCode
from x2paddle.op_mapper import caffe_shape
class CaffeResolver(object): class CaffeResolver(object):
...@@ -50,10 +49,10 @@ class CaffeGraphNode(GraphNode): ...@@ -50,10 +49,10 @@ class CaffeGraphNode(GraphNode):
def __init__(self, layer, type_str, layer_name=None): def __init__(self, layer, type_str, layer_name=None):
if layer_name is None: if layer_name is None:
super(CaffeGraphNode, self).__init__( super(CaffeGraphNode, self).__init__(
layer, layer.name.replace('/', '_').replace('-', '_')) layer, layer.name.replace('/', '_').replace('-', '_').lower())
else: else:
super(CaffeGraphNode, self).__init__( super(CaffeGraphNode, self).__init__(
layer, layer_name.replace('/', '_').replace('-', '_')) layer, layer_name.replace('/', '_').replace('-', '_').lower())
self.layer_type = type_str self.layer_type = type_str
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.data = None self.data = None
...@@ -66,6 +65,13 @@ class CaffeGraph(Graph): ...@@ -66,6 +65,13 @@ class CaffeGraph(Graph):
def __init__(self, model, params, caffe_pb): def __init__(self, model, params, caffe_pb):
self.params = params self.params = params
self.caffe_pb = caffe_pb self.caffe_pb = caffe_pb
if hasattr(model, "name"):
if model.name == "":
self.graph_name = "CaffeModel"
else:
self.graph_name = model.name
else:
self.graph_name = "CaffeModel"
super(CaffeGraph, self).__init__(model) super(CaffeGraph, self).__init__(model)
def filter_layers(self, layers): def filter_layers(self, layers):
...@@ -242,7 +248,7 @@ class CaffeDecoder(object): ...@@ -242,7 +248,7 @@ class CaffeDecoder(object):
with open(proto_path, 'rb') as proto_file: with open(proto_path, 'rb') as proto_file:
proto_str = proto_file.read() proto_str = proto_file.read()
text_format.Merge(proto_str, self.net) text_format.Merge(proto_str, self.net)
self.load_using_pb() self.load_using_pb()
self.caffe_graph = CaffeGraph(self.net, self.params, self.caffe_graph = CaffeGraph(self.net, self.params,
......
...@@ -12,15 +12,14 @@ ...@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys
import torch import torch
import numpy as np
class PyTorchDecoder(object): class Decoder(object):
def __init__(self, script_path): def _optimize_graph(self, graph):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
def _optimize_graph(self, graph):
torch._C._jit_pass_constant_propagation(graph) torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph) torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph) torch._C._jit_pass_lint(graph)
...@@ -31,4 +30,37 @@ class PyTorchDecoder(object): ...@@ -31,4 +30,37 @@ class PyTorchDecoder(object):
torch._C._jit_pass_canonicalize(graph) torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph) torch._C._jit_pass_lint(graph)
torch._C._jit_pass_constant_propagation(graph) torch._C._jit_pass_constant_propagation(graph)
return graph return graph
class ScriptDecoder(Decoder):
""" 当script_path非None,直接load ScriptModule;
当model_path非None,load PyTorchModule后使用script方式转换为ScriptModule。
Args:
script_path (str): ScriptModule保存路径。
model_path (str): PyTorchModule保存路径。
"""
def __init__(self, script_path=None):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
class TraceDecoder(Decoder):
""" PyTorchModule后使用trace方式转换为ScriptModule。
Args:
model_path (str): PyTorchModule保存路径。
input_files (list): 输入网络的numpy,每个numpy保存成.npy文件,
文件路径存储在input_files中。
"""
def __init__(self, model_path, input_files=list()):
# TODO(syf): 传入pytorch的Module(即import),否则出错
model = torch.load(model_path)
model.eval()
input_list = list()
for npy_file in input_files:
input_list.append(torch.tensor(np.load(npy_file)))
self.script = torch.jit.trace(model, input_list, strict=False)
self.graph = self._optimize_graph(self.script.inlined_graph)
# print(self.graph)
# print(getattr(getattr(self.script.decoder.block, "5").layer, "2"))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .detectionoutput import DetectionOutput
from .normalize import Normalize
from .priorbox import PriorBox
from .roipooling import ROIPooling
from .select import Select
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class DetectionOutput(object):
def __init__(self, nms_threshold, nms_top_k, keep_top_k, nms_eta, score_threshold, background_label):
self.detection_output_layer_attrs = {
"background_label": background_label,
"nms_threshold": nms_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"score_threshold": score_threshold,
"nms_eta": nms_eta}
def __call__(self, x0, x1, x2):
priorbox_list = paddle.split(x2, num_or_sections=2, axis=1)
pb = priorbox_list[0]
pbv = priorbox_list[1]
pb = paddle.reshape(x=pb, shape=[-1, 4])
pbv = paddle.reshape(x=pbv, shape=[-1, 4])
pb_dim = fluid.layers.shape(pb)[0]
loc = paddle.reshape(x0, shape=[-1, pb_dim, 4])
conf_flatten = paddle.reshape(x1, shape=[0, pb_dim, -1])
out = fluid.layers.detection_output(loc=loc,
scores=conf_flatten,
prior_box=pb,
prior_box_var=pbv,
**self.detection_output_layer_attrs)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Normalize(object):
def __init__(self, axis, param_name, param_shape):
self.axis = axis
self.param_name = param_name
self.param_shape = param_shape
def __call__(self, x):
l2 = fluid.layers.prior_box(x=x, p=2, axis=1)
attr = fluid.ParamAttr(name=self.param_name, trainable=False)
param = paddle.nn.Layer.create_parameter(shape=self.param_shape,
attr=atr)
out = paddle.multiply(x=l2, y=param, axis=self.axis)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class PriorBox(object):
def __init__(self, min_sizes, max_sizes,
aspect_ratios, variance, flip,
clip, steps, offset,
min_max_aspect_ratios_order):
self.priorbox_layer_attrs = {
"min_sizes": min_sizes,
"max_sizes": max_sizes,
"aspect_ratios": aspect_ratios,
"variance": variance,
"flip": flip,
"clip": clip,
"steps": steps,
"offset": offset,
"min_max_aspect_ratios_order": min_max_aspect_ratios_order}
def __call__(self, x0, x1):
box, var = fluid.layers.prior_box(input=x0,
image=x1,
**self.priorbox_layer_attrs)
box = paddle.reshape(x=box, shape=[1, 1, -1])
var = paddle.reshape(x=var, shape=[1, 1, -1])
out = paddle.concat(x=[box, var], axis=1)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class ROIPooling(object):
def __init__(self, pooled_height, pooled_width, spatial_scale):
self.roipooling_layer_attrs = {
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale}
def __call__(self, x0, x1):
slice_x1 = paddle.slice(input=x1, axes=[1],
starts=[1], ends=[5])
out = fluid.layers.roi_pool(input=x0,
rois=slice_x1,
**self.roipooling_layer_attrs)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Select(object):
def __init__(self, input_shape, point, axis):
self.point = point
self.input_shape = input_shape
self.axis = axis
def __call__(self, x):
start = self.point[0]
if len(self.point) == 2:
end = self.point[1]
else:
end = self.input_shape[self.axis]
out = paddle.slice(x=x,
start=start,
end=end,
axes=[self.axis])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
from functools import reduce
def get_kernel_parameters(params):
[k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2
elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1]
elif params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
elif len(params.pad) > 0:
p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
p_w = params.pad_w if params.pad_w > 0 else params.pad[len(params.pad) -
1]
elif params.pad_h > 0 or params.pad_w > 0:
p_h = params.pad_h
p_w = params.pad_w
dila_h = dila_w = 1
if hasattr(params, 'dilation'):
dila_len = len(params.dilation)
if dila_len == 2:
dila_h = params.dilation[0]
dila_w = params.dilation[1]
elif dila_len == 1:
dila_h = dila_w = params.dilation[0]
else:
assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
dila_len)
return dila_h, dila_w, p_h, p_w, k_h, k_w, s_h, s_w
def get_strided_kernel_output_shape(params, input_shape, round_func):
i_h = input_shape[2]
i_w = input_shape[3]
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params)
o_h = (i_h + 2 * pad_h - (dila_h *
(kernel_h - 1) + 1)) / float(stride_h) + 1
o_w = (i_w + 2 * pad_w - (dila_w *
(kernel_w - 1) + 1)) / float(stride_w) + 1
o_h = int(round_func(o_h))
o_w = int(round_func(o_w))
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape[1]
return [[input_shape[0], c, o_h, o_w]]
def shape_convolution(layer, input_shape):
params = layer.convolution_param
return get_strided_kernel_output_shape(params, input_shape[0], math.floor)
def shape_deconvolution(layer, input_shape):
h_i = input_shape[0][2]
w_i = input_shape[0][3]
params = layer.convolution_param
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params)
h_o = (h_i - 1) * stride_h - 2 * pad_h + dila_h * (kernel_h - 1) + 1
w_o = (w_i - 1) * stride_w - 2 * pad_w + dila_w * (kernel_w - 1) + 1
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape.channels
return [[input_shape[0][0], c, h_o, w_o]]
def shape_pooling(layer, input_shape):
params = layer.pooling_param
global_pool = getattr(params, 'global_pooling', False)
if global_pool:
return [[input_shape[0][0], input_shape[0][1], 1, 1]]
ceil_mode = getattr(params, 'ceil_mode', True)
if ceil_mode is True:
method = math.ceil
else:
method = math.floor
return get_strided_kernel_output_shape(params, input_shape[0], method)
def shape_convolutiondepthwise(layer, input_shape):
params = layer.convolution_param
return get_strided_kernel_output_shape(params, input_shape[0], math.floor)
def shape_innerproduct(layer, input_shape):
params = layer.inner_product_param
return [[input_shape[0][0], params.num_output]]
def shape_lrn(layer, input_shape):
return input_shape
def shape_relu(layer, input_shape):
return input_shape
def shape_softmax(layer, input_shape):
return input_shape
def shape_input(layer, input_shape):
return [list(layer.input_param.shape[0].dim)]
def shape_memorydata(layer, input_shape):
params = layer.memory_data_param
shape = []
shape.append(int(params.batch_size))
shape.append(int(params.channels))
shape.append(int(params.height))
shape.append(int(params.width))
return [shape]
def shape_concat(layer, input_shape):
params = layer.concat_param
axis = params.axis
output_shape = None
for shape in input_shape:
if output_shape is None:
output_shape = []
for i in range(len(shape)):
output_shape.append(shape[i])
else:
output_shape[axis] += shape[axis]
return [output_shape]
def shape_slice(layer, input_shape):
inshape = input_shape[0]
top_len = len(layer.top)
params = layer.slice_param
axis = params.axis
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point)
count = inshape[axis]
if len(points) == 0:
assert count % top_len == 0, "the parameter of Slice is wrong"
part = count / top_len
t = part
while t < count:
points.append(int(t))
t += part
points = [0] + points + [count]
output_shape = []
for i in range(len(points)):
shape = []
for ii in range(len(inshape)):
shape.append(inshape[ii])
size = points[i + 1] - points[i]
shape[axis] = size
output_shape.append(shape)
if i == len(points) - 2:
break
return output_shape
def shape_prelu(layer, input_shape):
return input_shape
def shape_sigmoid(layer, input_shape):
return input_shape
def shape_absval(layer, input_shape):
return input_shape
def shape_accuracy(layer, input_shape):
return [[1]]
def shape_tanh(layer, input_shape):
return input_shape
def shape_eltwise(layer, input_shape):
return [input_shape[0]]
def shape_batchnorm(layer, input_shape):
return input_shape
def shape_scale(layer, input_shape):
return input_shape
def shape_reshape(layer, input_shape):
def count(num_list):
return reduce(lambda a, b: a * b, num_list)
inshape = input_shape[0]
params = layer.reshape_param
axis = params.axis if hasattr(params, 'axis') else 0
num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
if inshape[0] == -1:
inshape[0] = 1
input_count = count(inshape)
input_num_axes = len(inshape)
input_start_axis = axis
start_axis = input_start_axis if input_start_axis >= 0 \
else input_num_axes + input_start_axis + 1
assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\
% (input_start_axis, input_num_axes)
assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"
end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\
% (end_axis, start_axis, num_axes)
num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(list(params.shape.dim))
output_shape = []
for i in range(start_axis):
output_shape.append(inshape[i])
for i in range(num_new_axes):
output_shape.append(params.shape.dim[i])
for i in range(end_axis, input_num_axes):
output_shape.append(inshape[i])
assert len(output_shape) == num_axes_retained + num_new_axes,\
"[Reshape]invalid dims of output shape[%s]" % (str(output_shape))
inferred_axis = -1
copy_axes = []
constant_count = 1
for i in range(num_new_axes):
top_dim = params.shape.dim[i]
if top_dim == 0:
copy_axes.append(i)
copy_axis_index = start_axis + i
output_shape[copy_axis_index] = inshape[copy_axis_index]
elif top_dim == -1:
assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
inferred_axis = i
else:
constant_count *= top_dim
if inferred_axis >= 0:
explicit_count = constant_count
l = inshape[0:start_axis]
if len(l) > 0:
explicit_count *= count(l)
l = inshape[end_axis:]
if len(l) > 0:
explicit_count *= count(l)
for i in range(len(copy_axes)):
explicit_count *= output_shape[start_axis + copy_axes[i]]
assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count)
output_shape[start_axis + inferred_axis] = int(input_count / explicit_count)
output_count = count(output_shape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
output_shape[0] = -1
return [output_shape]
def shape_argmax(layer, input_shape):
inshape = input_shape[0]
params = layer.argmax_param
out_max_val = params.out_max_val if hasattr(params, out_max_val) else False
top_k = params.top_k if hasattr(params, top_k) else 1
axis = parmas.axis if hasattr(params, axis) else -1
if axis < 0:
axis += len(inshape)
assert (axis + 1 == len(inshape)
), 'only can be applied on the last dimension[axis:%d, %s] now,'\
'make sure you have set axis param in xxx.prototxt file' \
% (axis, str(inshape))
output_shape = inshape
output_shape[-1] = top_k
if out_max_val is True:
output_shape[-1] *= 2
return [output_shape]
def shape_crop(layer, input_shape):
assert len(input_shape) == 2, "the number of crop's inputs must be 2"
return [input_shape[1]]
def shape_flatten(layer, input_shape):
assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
inshape = input_shape[0]
params = layer.flatten_param
start_axis = params.axis
end_axis = params.end_axis
if start_axis < 0:
start_axis += len(inshape)
if end_axis < 0:
end_axis += len(inshape) + 1
assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
% (start_axis, end_axis)
output_shape = inshape[0:start_axis]
if len(inshape[start_axis:end_axis]) != 0:
flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
output_shape += [flat_sz]
output_shape += inshape[end_axis:len(inshape)]
output_shape[0] = -1
return [output_shape]
def shape_power(layer, input_shape):
return input_shape
def shape_reduction(layer, input_shape):
params = layer.reduction_param
axis = params.axis
if axis < 0:
axis += len(input_shape[0]) + 1
assert axis <= len(input_shape[0]), 'invalid axis[%d] error' % (axis)
return [input_shape[0:axis]]
def shape_axpy(layer, input_shape):
assert len(input_shapes) == 3, "not valid input shape for axpy layer"
assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
output_shape = input_shapes[1]
assert (input_shapes[2] == output_shape),\
"shape not consistent for axpy[%s <--> %s]" \
% (str(output_shape), str(input_shapes[2]))
return [output_shape]
def shape_detectionoutput(layer, input_shape):
return [[-1, 6]]
def shape_normalize(layer, input_shape):
return input_shape
def shape_permute(layer, input_shape):
order = layer.permute_param.order
inshape = input_shape[0]
output_shape = []
order = list(order)
for ii in order:
assert ii < len(inshape), "invalid order for permute[%s]" % (name)
output_shape.append(inshape[ii])
return [output_shape]
def shape_priorbox(layer, input_shape):
max_size = layer.prior_box_param.max_size
aspect_ratio = layer.prior_box_param.aspect_ratio
fc_shape = input_shape[0]
N = 1
if not max_size == None:
N += 1
if not aspect_ratio == None:
N += 2 * len(aspect_ratio)
N_bbx = fc_shape[2] * fc_shape[3] * N
output_shape = [1, 2, 4 * N_bbx]
return [output_shape]
def shape_relu6(layer, input_shape):
return input_shape
def shape_roipooling(layer, input_shape):
pooled_w = layer.roi_pooling_param.pooled_w
pooled_h = layer.roi_pooling_param.pooled_h
base_fea_shape = input_shapes[0]
rois_shape = input_shapes[1]
output_shape = base_fea_shape
output_shape[0] = rois_shape[0]
output_shape[2] = pooled_h
output_shape[3] = pooled_w
return [output_shape]
def shape_shufflechannel(layer, input_shape):
return input_shape
def shape_upsample(layer, input_shape):
scale = layer.upsample_param.scale
assert len(input_shapes) == 1, "not valid input shape for upsample layer"
assert type(scale) is int
input_shape = input_shapes[0]
new_h = scale * input_shape[2]
new_w = scale * input_shape[3]
output_shape = [input_shape[0], input_shape[1], new_h, new_w]
return [output_shape]
def shape_select(layer, input_shape):
slice_point = layer.select_param.slice_point
axis = layer.select_param.axis
input_shape = input_shapes[0]
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
assert end > start, "invalid slice_point with [start:%d, end:%d]"\
% (start, end)
output_shape = input_shape
output_shape[axis] = end - start
return [output_shape]
此差异已折叠。
...@@ -14,18 +14,6 @@ def detectionoutput_layer(inputs, ...@@ -14,18 +14,6 @@ def detectionoutput_layer(inputs,
confidence_threshold=0.1, confidence_threshold=0.1,
input_shape=None, input_shape=None,
name=None): name=None):
nms_param_str = nms_param
nms_param = {}
part = nms_param_str.split(',')
for s in part:
if s == '':
break
else:
name, obj = s.split(': ')
if name == 'top_k':
nms_param[name] = int(obj)
else:
nms_param[name] = float(obj)
if nms_param is None: if nms_param is None:
nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0} nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
mbox_conf_flatten = inputs[1] mbox_conf_flatten = inputs[1]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.decoder.caffe_decoder import CaffeGraph
from x2paddle.core.util import *
class CaffeOptimizer(object):
layers_with_act = ['Convolution', 'Deconvolution', 'InnerProduct']
activation_ops = ['ReLU', 'Sigmoid']
def __init__(self, mapper):
self.graph = mapper.graph
def merge_bn_scale(self):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node.layer_type == 'Scale':
parent_node = self.graph.get_bottom_node(node, idx=0)
if parent_node.layer_type == 'BatchNorm':
is_delete_node = True if len(
parent_node.outputs) == 1 else False
parent_fluid_layer = parent_node.fluid_code.layers[0]
input = parent_fluid_layer.inputs
parent_param_attr = parent_fluid_layer.param_attr
parent_param_attr['param_attr'] = string(node.layer_name +
'_scale')
parent_param_attr['bias_attr'] = string(node.layer_name +
'_offset')
if is_delete_node:
parent_node.fluid_code.clear()
node.fluid_code.clear()
node.fluid_code.add_layer(
"batch_norm",
inputs=input,
output=node,
param_attr=parent_param_attr)
def merge_op_activation(self):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node.layer_type in self.activation_ops:
parent_node = self.graph.get_bottom_node(node, idx=0)
if parent_node.layer_type in self.layers_with_act:
is_delete_node = True if len(
parent_node.outputs) == 1 else False
parent_fluid_layer = parent_node.fluid_code.layers[0]
input = parent_fluid_layer.inputs
parent_param_attr = parent_fluid_layer.param_attr
parent_param_attr['act'] = string(node.layer_type.lower())
op = parent_fluid_layer.op
if is_delete_node:
parent_node.fluid_code.clear()
node.fluid_code.clear()
node.fluid_code.add_layer(
op,
inputs=input,
output=node,
param_attr=parent_param_attr)
...@@ -12,17 +12,19 @@ ...@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .adaptive_pool2d_fuser import AdaptivePool2dFuser from .adaptive_pool2d_fuser import Dygraph_AdaptivePool2dFuser
from .adaptive_pool2d_fuse_pass import AdaptivePool2dFusePass from .adaptive_pool2d_fuse_pass import Dygraph_AdaptivePool2dFusePass
from .batchnorm2d_fuser import BatchNorm2dFuser from .batchnorm2d_fuser import Dygraph_BatchNorm2dFuser
from .batchnorm2d_fuse_pass import BatchNorm2dFusePass from .batchnorm2d_fuse_pass import Dygraph_BatchNorm2dFusePass
from .constant_fuser import ConstantFuser from .bn_scale_fuser import Dygraph_BNScaleFuser
from .constant_fuse_pass import ConstantFusePass from .bn_scale_fuse_pass import Dygraph_BNScaleFusePass
from .dropout_fuser import DropoutFuser from .constant_fuser import Dygraph_ConstantFuser
from .dropout_fuse_pass import DropoutFusePass from .constant_fuse_pass import Dygraph_ConstantFusePass
from .fc_fuser import FcFuser from .dropout_fuser import Dygraph_DropoutFuser
from .fc_fuse_pass import FcFusePass from .dropout_fuse_pass import Dygraph_DropoutFusePass
from .interpolate_bilinear_fuser import InterpolateBilinearFuser from .fc_fuser import Dygraph_FcFuser
from .interpolate_bilinear_fuse_pass import InterpolateBilinearFusePass from .fc_fuse_pass import Dygraph_FcFusePass
from .reshape_fuser import ReshapeFuser from .interpolate_bilinear_fuser import Dygraph_InterpolateBilinearFuser
from .reshape_fuse_pass import ReshapeFusePass from .interpolate_bilinear_fuse_pass import Dygraph_InterpolateBilinearFusePass
from .reshape_fuser import Dygraph_ReshapeFuser
from .reshape_fuse_pass import Dygraph_ReshapeFusePass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_AdaptivePool2dFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_AdaptivePool2dFusePass(Pass):
name = "dygraph_adaptive_pool2d_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Dygraph_AdaptivePool2dFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
adaptive_pool2d_fuse_pass = Dygraph_AdaptivePool2dFusePass()
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class AdaptivePool2dFuser(FuseBase): class Dygraph_AdaptivePool2dFuser(FuseBase):
def __init__(self): def __init__(self):
super(AdaptivePool2dFuser, self).__init__(graph_type="dygraph") super(Dygraph_AdaptivePool2dFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的adaptive pool2d图结构。 """ 描述需要替换的adaptive pool2d图结构。
......
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import BatchNorm2dFuser from x2paddle.optimizer.fusion.dygraph import Dygraph_BatchNorm2dFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class BatchNorm2dFusePass(Pass): class Dygraph_BatchNorm2dFusePass(Pass):
name = "batchnorm2d_fuse_pass" name = "dygraph_batchnorm2d_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = BatchNorm2dFuser() fuser = Dygraph_BatchNorm2dFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
batchnorm2d_fuse_pass = BatchNorm2dFusePass() batchnorm2d_fuse_pass = Dygraph_BatchNorm2dFusePass()
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class BatchNorm2dFuser(FuseBase): class Dygraph_BatchNorm2dFuser(FuseBase):
def __init__(self): def __init__(self):
super(BatchNorm2dFuser, self).__init__(graph_type="dygraph") super(Dygraph_BatchNorm2dFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。 """ 描述需要替换的batchnorm2d图结构。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_BNScaleFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_BNScaleFusePass(Pass):
name = "dygraph_bn_scale_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Dygraph_BNScaleFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
bn_scale_fuse_pass = Dygraph_BNScaleFusePass()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Dygraph_BNScaleFuser(FuseBase):
def __init__(self):
super(Dygraph_BNScaleFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例:
bn_conv1 = self.batchnorm0(conv1)
scale_conv1_cparam1 = self.scale_conv1_cparam1
scale_conv1_mul = paddle.multiply(x=bn_conv1, y=scale_conv1_cparam1, axis=1)
scale_conv1_cparam2 = self.scale_conv1_cparam2
scale_conv1 = fluid.layers.elementwise_add(x=scale_conv1_mul, y=scale_conv1_cparam2, axis=1)
"""
def gen_name(id):
return "x" + str(id)
self.pattern.add_layer(
"paddle.nn.BatchNorm2D",
inputs={"input": "bn-input-0"},
outputs=[gen_name(0)])
self.pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(1)])
inputs_dict = {}
inputs_dict['x'] = gen_name(0)
inputs_dict['y'] = gen_name(1)
self.pattern.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[gen_name(2)])
self.pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(3)])
inputs_dict = {}
inputs_dict['x'] = gen_name(2)
inputs_dict['y'] = gen_name(3)
self.pattern.add_layer(
"fluid.layers.elementwise_add",
inputs=inputs_dict,
outputs=[gen_name(4)])
self.pattern.build(inputs={"input-0": "bn-input-0"})
def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches)
new_layer_id = list(matches.keys())[0]
graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id)
def gen_new_layer(self, parameters, matches):
layers_id = list(matches.keys())
layer = matches[layers_id[0]]
layer_inputs = layer.inputs
bn_name = layer.outputs[0]
layer_attrs = layer.attrs
layer_attrs.pop("weight_attr")
layer_attrs.pop("bias_attr")
layer = matches[layers_id[4]]
layer_outputs = [bn_name] + layer.outputs
layer = matches[layers_id[1]]
data0_name = layer.outputs[0]
data0_numpy = parameters.pop(data0_name)
parameters["{}.weight".format(layer_outputs[0])] = data0_numpy
layer = matches[layers_id[3]]
data1_name = layer.outputs[0]
data1_numpy = parameters.pop(data1_name)
parameters["{}.bias".format(layer_outputs[0])] = data1_numpy
new_layer = PaddleLayer(
layers_id[0],
"paddle.nn.BatchNorm2D",
inputs=layer_inputs,
outputs=layer_outputs,
**layer_attrs)
return new_layer
\ No newline at end of file
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import ConstantFuser from x2paddle.optimizer.fusion.dygraph import Dygraph_ConstantFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class ConstantFusePass(Pass): class Dygraph_ConstantFusePass(Pass):
name = "constant_fuse_pass" name = "dygraph_constant_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = ConstantFuser() fuser = Dygraph_ConstantFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
constant_fuse_pass = ConstantFuser() constant_fuse_pass = Dygraph_ConstantFuser()
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class ConstantFuser(FuseBase): class Dygraph_ConstantFuser(FuseBase):
def __init__(self): def __init__(self):
super(ConstantFuser, self).__init__(graph_type="dygraph") super(Dygraph_ConstantFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的constant图结构。 """ 描述需要替换的constant图结构。
......
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import DropoutFuser from x2paddle.optimizer.fusion.dygraph import Dygraph_DropoutFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class DropoutFusePass(Pass): class Dygraph_DropoutFusePass(Pass):
name = "dropout_fuse_pass" name = "dygraph_dropout_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = DropoutFuser() fuser = Dygraph_DropoutFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
dropout_fuse_pass = DropoutFuser() dropout_fuse_pass = Dygraph_DropoutFuser()
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class DropoutFuser(FuseBase): class Dygraph_DropoutFuser(FuseBase):
def __init__(self): def __init__(self):
super(DropoutFuser, self).__init__(graph_type="dygraph") super(Dygraph_DropoutFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的constant图结构。 """ 描述需要替换的constant图结构。
......
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import FcFuser from x2paddle.optimizer.fusion.dygraph import Dygraph_FcFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class FcFusePass(Pass): class Dygraph_FcFusePass(Pass):
name = "fc_fuse_pass" name = "dygraph_fc_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = FcFuser() fuser = Dygraph_FcFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
fc_fuse_pass = FcFusePass() fc_fuse_pass = Dygraph_FcFusePass()
...@@ -13,15 +13,15 @@ ...@@ -13,15 +13,15 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class FcFuser(FuseBase): class Dygraph_FcFuser(FuseBase):
def __init__(self): def __init__(self):
self.linear_index = 0 self.linear_index = 0
super(FcFuser, self).__init__(graph_type="dygraph") super(Dygraph_FcFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的fc图结构。 """ 描述需要替换的fc图结构。
......
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import AdaptivePool2dFuser from x2paddle.optimizer.fusion.dygraph import Dygraph_InterpolateBilinearFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class AdaptivePool2dFusePass(Pass): class Dygraph_InterpolateBilinearFusePass(Pass):
name = "adaptive_pool2d_fuse_pass" name = "dygraph_interpolate_bilinear_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = AdaptivePool2dFuser() fuser = Dygraph_InterpolateBilinearFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
adaptive_pool2d_fuse_pass = AdaptivePool2dFusePass() interpolate_bilinear_fuse_pass = Dygraph_InterpolateBilinearFusePass()
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class InterpolateBilinearFuser(FuseBase): class Dygraph_InterpolateBilinearFuser(FuseBase):
def __init__(self): def __init__(self):
super(InterpolateBilinearFuser, self).__init__(graph_type="dygraph") super(Dygraph_InterpolateBilinearFuser, self).__init__(graph_type="dygraph")
import torch import torch
torch_version = torch.__version__ torch_version = torch.__version__
torch_version_part = torch_version.split(".") torch_version_part = torch_version.split(".")
......
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import ReshapeFuser from x2paddle.optimizer.fusion.dygraph import Dygraph_ReshapeFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class ReshapeFusePass(Pass): class Dygraph_ReshapeFusePass(Pass):
name = "reshape_fuse_pass" name = "dygraph_reshape_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = ReshapeFuser() fuser = Dygraph_ReshapeFuser()
fuser.operate(graph, match_kind="edge") fuser.operate(graph, match_kind="edge")
# 用于注册 # 用于注册
reshape_fuse_pass = ReshapeFusePass() reshape_fuse_pass = Dygraph_ReshapeFusePass()
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class ReshapeFuser(FuseBase): class Dygraph_ReshapeFuser(FuseBase):
def __init__(self): def __init__(self):
super(ReshapeFuser, self).__init__(graph_type="dygraph") super(Dygraph_ReshapeFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的reshape图结构。 """ 描述需要替换的reshape图结构。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bn_scale_fuser import Static_BNScaleFuser
from .bn_scale_fuse_pass import Static_BNScaleFusePass
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.static import Static_BNScaleFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Static_BNScaleFusePass(Pass):
name = "static_bn_scale_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Static_BNScaleFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
bn_scale_fuse_pass = Static_BNScaleFusePass()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Static_BNScaleFuser(FuseBase):
def __init__(self):
super(Static_BNScaleFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例:
conv5_bn = fluid.layers.batch_norm(input=conv5, is_test=True, param_attr=None, bias_attr=None, moving_mean_name='conv5_bn_mean', moving_variance_name='conv5_bn_variance', epsilon=9.999999747378752e-06, name='conv5_bn')
conv5_scale_scale = fluid.ParamAttr(name='conv5_scale_scale')
conv5_scale_cparam1 = fluid.layers.create_parameter(attr=conv5_scale_scale, dtype=conv5_bn.dtype, shape=[256], name='conv5_scale_cparam1', is_bias=True, default_initializer=Constant(value=1.0))
conv5_scale_mul = fluid.layers.elementwise_mul(x=conv5_bn, y=conv5_scale_cparam1, axis=1)
conv5_scale_offset = fluid.ParamAttr(name='conv5_scale_offset')
conv5_scale_cparam2 = fluid.layers.create_parameter(attr=conv5_scale_offset, dtype=conv5_bn.dtype, shape=[256], name='conv5_scale_cparam2', is_bias=True, default_initializer=Constant(value=1.0))
conv5_scale = fluid.layers.elementwise_add(x=conv5_scale_mul, y=conv5_scale_cparam2, axis=1)
"""
def gen_name(id):
return "x" + str(id)
self.pattern.add_layer(
"fluid.layers.batch_norm",
inputs={"input": "bn-input-0"},
outputs=[gen_name(0)])
self.pattern.add_layer(
"fluid.ParamAttr",
inputs={},
outputs=[gen_name(1)])
self.pattern.add_layer(
"fluid.layers.create_parameter",
inputs={"attr": gen_name(1)},
outputs=[gen_name(2)])
inputs_dict = {}
inputs_dict['x'] = gen_name(0)
inputs_dict['y'] = gen_name(2)
self.pattern.add_layer(
"fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[gen_name(3)])
self.pattern.add_layer(
"fluid.ParamAttr",
inputs={},
outputs=[gen_name(4)])
self.pattern.add_layer(
"fluid.layers.create_parameter",
inputs={"attr": gen_name(4)},
outputs=[gen_name(5)])
inputs_dict = {}
inputs_dict['x'] = gen_name(3)
inputs_dict['y'] = gen_name(5)
self.pattern.add_layer(
"fluid.layers.elementwise_add",
inputs=inputs_dict,
outputs=[gen_name(6)])
self.pattern.build(inputs={"input-0": "bn-input-0"})
def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches)
new_layer_id = list(matches.keys())[0]
graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id)
def gen_new_layer(self, parameters, matches):
layers_id = list(matches.keys())
layer = matches[layers_id[0]]
layer_inputs = layer.inputs
layer_name = layer.outputs[0]
layer_attrs = layer.attrs
layer_attrs["param_attr"] = string("{}_scale".format(layer_name))
layer_attrs["bias_attr"] = string("{}_offset".format(layer_name))
layer = matches[layers_id[-1]]
layer_outputs = layer.outputs
layer = matches[layers_id[1]]
layer_name = layer.outputs[0]
scale_numpy = parameters.pop(layer_name)
parameters[layer_attrs["param_attr"][1: -1]] = scale_numpy
layer = matches[layers_id[4]]
layer_name = layer.outputs[0]
scale_numpy = parameters.pop(layer_name)
parameters[layer_attrs["bias_attr"][1: -1]] = scale_numpy
new_layer = PaddleLayer(
layers_id[0],
"fluid.layers.batch_norm",
inputs=layer_inputs,
outputs=layer_outputs,
**layer_attrs)
return new_layer
\ No newline at end of file
...@@ -12,22 +12,36 @@ ...@@ -12,22 +12,36 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.fusion import * from x2paddle.optimizer.pass_manager import PassManager
from x2paddle.optimizer.pytorch_optimizer.pass_manager import PassManager from x2paddle.optimizer.fusion.dygraph import *
from x2paddle.optimizer.fusion.static import *
class GraphOptimizer(object): class GraphOptimizer(object):
def __init__(self): def __init__(self, source_frame, paddle_type="dygraph"):
self.passes = [ if source_frame == "pytorch":
"constant_fuse_pass", "batchnorm2d_fuse_pass", self.passes = [
"interpolate_bilinear_fuse_pass", "fc_fuse_pass", "dygraph_constant_fuse_pass", "dygraph_batchnorm2d_fuse_pass",
"adaptive_pool2d_fuse_pass", "reshape_fuse_pass", "dygraph_interpolate_bilinear_fuse_pass", "dygraph_fc_fuse_pass",
"dropout_fuse_pass" "dygraph_adaptive_pool2d_fuse_pass", "dygraph_reshape_fuse_pass",
] "dygraph_dropout_fuse_pass"
]
elif source_frame == "caffe":
if paddle_type == "dygraph":
self.passes = ["dygraph_bn_scale_fuse_pass"]
else:
self.passes = ["static_bn_scale_fuse_pass"]
else:
# TODO
pass
def optimize(self, graph): def optimize(self, graph):
for pass_name in self.passes: for pass_name in self.passes:
pass_ = PassManager.lookup(pass_name)() pass_ = PassManager.lookup(pass_name)()
pass_.apply(graph) while True:
before_len = len(graph.layers)
pass_.apply(graph)
after_len = len(graph.layers)
if before_len == after_len:
break
print("{} done!".format(pass_name)) print("{} done!".format(pass_name))
return graph return graph
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import InterpolateBilinearFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
@pass_register
class InterpolateBilinearFusePass(Pass):
name = "interpolate_bilinear_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = InterpolateBilinearFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
interpolate_bilinear_fuse_pass = InterpolateBilinearFusePass()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册