Unverified commit f326e2c3 authored by Jason, committed by GitHub

Merge pull request #421 from SunAhong1993/paddle-2.0

add caffe2paddle
......@@ -13,6 +13,7 @@
# limitations under the License.
from six import text_type as _text_type
from x2paddle import program
import argparse
import sys
......@@ -88,11 +89,13 @@ def arg_parser():
default=False,
help="define whether merge the params")
parser.add_argument(
"--input_shapes",
"-is",
action='append',
default=None,
help="define the inputs' shape")
"--paddle_type",
"-pt",
type=_text_type,
default="dygraph",
help="define the paddle model type after converting(dygraph/static)"
)
return parser
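For context, a hedged usage sketch of the parser above; apart from --paddle_type and --input_shapes, the flag spellings are assumed from the attribute names used later in main():

    parser = arg_parser()
    args = parser.parse_args(["--framework=caffe",
                              "--prototxt=deploy.prototxt",
                              "--weight=deploy.caffemodel",
                              "--save_dir=pd_model",
                              "--paddle_type=dygraph"])
    assert args.paddle_type in ["dygraph", "static"]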
......@@ -117,7 +120,7 @@ def tf2paddle(model_path,
"[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
)
return
from x2paddle import program
from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt
......@@ -126,6 +129,7 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model)
program.build()
bias_opt = BiasOpt()
......@@ -137,10 +141,13 @@ def tf2paddle(model_path,
program.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
def caffe2paddle(proto, weight, save_dir, caffe_proto,
paddle_type, params_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.caffe2paddle.caffe_op_mapper import CaffeOpMapper
else:
from x2paddle.op_mapper.static.caffe2paddle.caffe_op_mapper import CaffeOpMapper
import google.protobuf as gpb
ver_part = gpb.__version__.split('.')
version_satisfy = False
......@@ -151,10 +158,13 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model)
optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale()
optimizer.merge_op_activation()
mapper.save_inference_model(save_dir, params_merge)
mapper.paddle_graph.build()
print("Model optimizing ...")
from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer(source_frame="caffe", paddle_type=paddle_type)
graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.")
mapper.paddle_graph.gen_model(save_dir)
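A minimal invocation sketch for the updated entry point; paths are illustrative, and caffe_proto=None is assumed to fall back to the bundled caffe protobuf:

    caffe2paddle(proto="deploy.prototxt",
                 weight="deploy.caffemodel",
                 save_dir="pd_model",
                 caffe_proto=None,
                 paddle_type="dygraph")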
def onnx2paddle(model_path, save_dir, params_merge=False):
......@@ -185,7 +195,7 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
print("Paddle model and code generated.")
def pytorch2paddle(model_path, save_dir, input_shapes):
def pytorch2paddle(model_path, save_dir, jit_type, input_files):
# check pytorch installation and version
try:
import torch
......@@ -202,9 +212,12 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
return
print("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import PyTorchDecoder
from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper
model = PyTorchDecoder(model_path)
if jit_type == "trace":
model = TraceDecoder(model_path, input_files)
else:
model = ScriptDecoder(model_path)
mapper = pytorch_op_mapper.PyTorchOpMapper(model)
mapper.graph.build()
print("Model optimizing ...")
......@@ -212,16 +225,7 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
graph_opt = GraphOptimizer()
graph_opt.optimize(mapper.graph)
print("Model optimized.")
if input_shapes is not None:
real_input_shapes = list()
for shape in input_shapes:
sp = shape[1:-1].split(",")
for i, s in enumerate(sp):
sp[i] = int(s)
real_input_shapes.append(sp)
else:
real_input_shapes = None
mapper.graph.gen_model(save_dir, real_input_shapes)
mapper.graph.gen_model(save_dir, jit_type, input_files)
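A hedged sketch of trace-mode conversion: each network input is saved as a .npy file and its path passed via input_files (file names illustrative):

    import numpy as np
    np.save("input0.npy", np.random.rand(1, 3, 224, 224).astype("float32"))
    pytorch2paddle(model_path="model.pt", save_dir="pd_model",
                   jit_type="trace", input_files=["input0.npy"])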
def paddle2onnx(model_path, save_dir, opset_version=10):
......@@ -260,6 +264,7 @@ def main():
assert args.framework is not None, "--framework is not defined (support tensorflow/caffe/onnx/pytorch)"
assert args.save_dir is not None, "--save_dir is not defined"
assert args.paddle_type in ["dygraph", "static"], "--paddle_type must be 'dygraph' or 'static'"
try:
import paddle
......@@ -267,8 +272,8 @@ def main():
print("paddle.__version__ = {}".format(paddle.__version__))
if v0 == '0' and v1 == '0' and v2 == '0':
print("[WARNING] You are use develop version of paddlepaddle")
elif int(v0) != 1 or int(v1) < 6:
print("[ERROR] paddlepaddle>=1.6.0 is required")
elif int(v0) != 2 or int(v1) < 0:
print("[ERROR] paddlepaddle>=2.0.0 is required")
return
except:
print(
......@@ -296,7 +301,7 @@ def main():
if args.params_merge:
params_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto, params_merge)
args.caffe_proto, args.paddle_type, params_merge)
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
params_merge = False
......@@ -304,7 +309,10 @@ def main():
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
elif args.framework == "pytorch":
assert args.model is not None, "--model should be defined while translating pytorch model"
pytorch2paddle(args.model, args.save_dir, args.jit_type, args.input_files)
elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
......
# -*- coding:UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
......@@ -21,15 +20,15 @@ import paddle
from paddle.fluid.proto import framework_pb2
from collections import OrderedDict
import numpy
import collections
import sys
import os
import six
import pickle
import numpy as np
class PaddleLayer(object):
def __init__(self, id, kernel, inputs, outputs, **kwargs):
def __init__(self, id, kernel, inputs, outputs, scope_name="", **kwargs):
assert isinstance(
inputs,
dict), "parameter 'inputs' for PaddleLayer should be type of dict"
......@@ -53,16 +52,18 @@ class PaddleLayer(object):
self.kernel = kernel
self.inputs = inputs
self.outputs = outputs
self.scope_name = scope_name
self.attrs = kwargs
self.id = id
self.blocks = list()
def add_block(self, block):
self.blocks.append(block)
class PaddleGraph(object):
def __init__(self, parent_layer=None, graph_type="static"):
def __init__(self, source_type=None, parent_layer=None, graph_type="static"):
self.layers = OrderedDict()
self.edges_out = dict()
self.edges_in = dict()
......@@ -71,12 +72,24 @@ class PaddleGraph(object):
self.parameters = dict()
self.parent_layer = parent_layer
self.graph_type = graph_type
self.source_type = source_type
self.custom_code = None
self.inputs_info = None
def set_name(self, name):
self.name = name
self.name = name.replace("-", "_").replace("/", "_")
def set_parameters(self, parameters):
self.parameters = parameters
def set_custom(self, custom_code):
self.custom_code = custom_code
def set_inputs_info(self, inputs_info):
self.inputs_info = inputs_info
def set_script(self, script):
self.script = script
def clear(self):
self.layers = OrderedDict()
......@@ -90,13 +103,13 @@ class PaddleGraph(object):
self.edges_out = dict()
self.edges_in = dict()
def add_layer(self, kernel, inputs, outputs, **kwargs):
def add_layer(self, kernel, inputs, outputs, scope_name="", **kwargs):
layer_id = str(len(self.layers))
if self.parent_layer is not None:
layer_id = "{}.{}.{}".format(self.parent_layer.id,
len(self.parent_layer.blocks),
layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, **kwargs)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs)
self.layers[layer_id] = layer
return layer_id
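An illustrative use of the recording API above; the kernel name and attributes follow the paddle.nn convention seen elsewhere in this diff, where outputs[0] names the module attribute and the remaining entries name tensors:

    graph = PaddleGraph(source_type="caffe", graph_type="dygraph")
    graph.set_name("CaffeModel")
    graph.add_layer("paddle.nn.Conv2D",
                    inputs={"input": "x0"},
                    outputs=["conv0", "x1"],
                    scope_name="conv1",
                    in_channels=3, out_channels=16, kernel_size=3)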
......@@ -215,10 +228,67 @@ class PaddleGraph(object):
block_global_layers = update(block.layers)
global_layers.update(block_global_layers)
return global_layers
return update(self.layers)
def gen_code(self, code_dir):
def gen_model(self, save_dir, jit_type=None):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if self.graph_type == "static":
self.gen_static_model(save_dir)
else:
self.gen_dygraph_model(save_dir, jit_type)
def gen_static_model(self, save_dir):
code_dir = os.path.join(save_dir, 'model_with_code')
infer_dir = os.path.join(save_dir, 'inference_model')
self.gen_static_code(code_dir)
sys.path.append(code_dir)
import x2paddle_model
paddle.enable_static()
scope = paddle.static.Scope()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.scope_guard(scope):
with paddle.static.program_guard(main_program, startup_program):
inputs, outputs = x2paddle_model.x2paddle_net()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
param_dir = os.path.join(code_dir, 'weights')
for k, v in self.parameters.items():
if scope.find_var(k):
self.dump_parameter(k, v, param_dir)
def if_exist(var):
b = os.path.exists(os.path.join(param_dir, var.name))
return b
fluid.io.load_vars(
exe, param_dir, main_program, predicate=if_exist)
fluid.io.save_inference_model(
dirname=infer_dir,
feeded_var_names=[i.name for i in inputs],
target_vars=outputs,
executor=exe)
def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace":
from x2paddle.optimizer.code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items():
hierarchical_tree.insert(layer)
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
input_shapes = list()
input_types = list()
for input_name in self.inputs:
input_shapes.append(self.inputs_info[input_name][0])
input_types.append(self.inputs_info[input_name][1])
# If input_files is non-empty, export the inference model; its value looks like [[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes, input_types)
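For reference, the inputs_info consumed above (populated via set_inputs_info) maps each input name to a [shape, dtype] pair; a hedged illustration:

    graph.set_inputs_info({"x0": [[None, 3, 224, 224], "float32"]})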
def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0):
indent_blank = " " * indent
for code_line in code_list:
......@@ -235,10 +305,24 @@ class PaddleGraph(object):
f, [
"from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"import paddle.fluid as fluid", "import math", "",
"def x2paddle_net():"
"import paddle.fluid as fluid",
"import paddle", "import math", "",
],
indent=0)
if self.custom_code is not None:
write_code(
f,
list(self.custom_code.values()),
indent=0)
write_code(f,
["", "def x2paddle_net():"],
indent=0)
write_code(
f, [
"paddle.enable_static()"
],
indent=1)
for layer_id, layer in self.layers.items():
edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, [])
......@@ -253,8 +337,10 @@ class PaddleGraph(object):
for output in layer.outputs:
line += "{}, ".format(output)
line = line.strip(", ")
line += " = {}(".format(layer.kernel)
if layer.kernel.startswith("custom_layer"):
line += " = {}(".format(layer.kernel.split(":")[-1].lower() + "_layer")
else:
line += " = {}(".format(layer.kernel)
for k, v in layer.inputs.items():
if isinstance(v, list):
line += "{}=[{}], ".format(k, ", ".join(v))
......@@ -274,47 +360,6 @@ class PaddleGraph(object):
indent=1)
f.close()
def gen_model(self, save_dir, input_shapes=None):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if self.graph_type == "static":
code_dir = os.path.join(save_dir, 'model_with_code')
infer_dir = os.path.join(save_dir, 'inference_model')
self.gen_code(code_dir)
sys.path.append(code_dir)
import x2paddle_model
scope = fluid.Scope()
startup_program = fluid.Program()
main_program = fluid.Program()
with fluid.scope_guard(scope):
with fluid.program_guard(main_program, startup_program):
inputs, outputs = x2paddle_model.x2paddle_net()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
param_dir = os.path.join(code_dir, 'weights')
for k, v in self.parameters.items():
if scope.find_var(k):
self.dump_parameter(k, v, param_dir)
def if_exist(var):
b = os.path.exists(os.path.join(param_dir, var.name))
return b
fluid.io.load_vars(
exe, param_dir, main_program, predicate=if_exist)
fluid.io.save_inference_model(
dirname=infer_dir,
feeded_var_names=[i.name for i in inputs],
target_vars=outputs,
executor=exe)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
if input_shapes is not None:
# If input_shapes is non-empty, export the inference model; its value looks like [[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes)
def dump_parameter(self, param_name, param, save_dir):
if not os.path.exists(save_dir):
......@@ -356,10 +401,10 @@ class PaddleGraph(object):
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0:
continue
if layer.kernel == "fluid.dygraph.base.to_variable":
value = layer.attrs["value"]
if not value.startswith("params["):
self.inputs.append(value)
if layer.kernel == "paddle.to_tensor":
data = layer.attrs["data"]
if not data.startswith("params["):
self.inputs.append(data)
if len(layer.blocks) > 0:
for block in layer.blocks:
block.get_dygraph_inputs()
......@@ -376,11 +421,15 @@ class PaddleGraph(object):
layer_id, 0) == 0:
continue
if self.edges_out.get(layer_id, 0) == 0:
for output_name in layer.outputs:
if not output_name.startswith("x"):
continue
self.outputs.append(output_name)
self.outputs = list(set(self.outputs))
for i, output_name in enumerate(layer.outputs):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
(layer.kernel == "paddle.to_tensor" and layer.attrs["data"].startswith("params["))or \
"paddle.fluid.dygraph" in layer.kernel:
if i == 0:
continue
if output_name not in self.outputs:
self.outputs.append(output_name)
def gen_dygraph_code(self, code_dir=None, indent=2):
def gen_codes(code_list, indent=0):
......@@ -394,20 +443,24 @@ class PaddleGraph(object):
return codes
def gen_head():
if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn"
self.head = gen_codes(
[
"from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"import paddle",
"import paddle.fluid as fluid",
custom_import,
"",
"class {}(fluid.dygraph.Layer):".format(self.name),
"class {}(paddle.nn.Layer):".format(self.name),
],
indent=0)
input_data_name = ', '.join(self.inputs)
self.init_func.extend(
gen_codes(
["def __init__(self, params):"], indent=1))
["def __init__(self):"], indent=1))
self.init_func.extend(
gen_codes(
["super({}, self).__init__()".format(self.name)], indent=2))
......@@ -415,6 +468,31 @@ class PaddleGraph(object):
gen_codes(
["def forward(self, {}):".format(input_data_name)],
indent=1))
def gen_main_code(code_dir):
input_data_name = ', '.join(self.inputs)
self.run_func = gen_codes(
[
"",
"def main({}):".format(input_data_name),
],
indent=0)
comment_list = list()
comment_list.append("# 共{}个输入".format(len(self.inputs_info)))
for k, v in self.inputs_info.items():
comment_list.append("# {}: 形状为{},类型为{}。".format(k, v[0], v[1]))
self.run_func.extend(
gen_codes(
comment_list,
indent=1))
self.run_func.extend(
gen_codes(["paddle.disable_static()",
"params, _ = fluid.load_dygraph('{}/model')".format(code_dir),
"model = {}()".format(self.name),
"model.set_dict(params)",
"model.eval()",
"out = model({})".format(input_data_name),
"return out"], indent=1))
def write_code(code_dir):
f = open(os.path.join(code_dir, 'x2paddle_code.py'), 'w')
......@@ -431,6 +509,8 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([return_code], indent=2))
for code_line in self.forward_func:
f.write(code_line)
for code_line in self.run_func:
f.write(code_line)
f.close()
self.init_func = []
......@@ -440,21 +520,25 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items():
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
) or layer.kernel == "fluid.dygraph.base.to_variable" or \
"paddle.fluid.dygraph" in layer.kernel:
) or layer.kernel == "paddle.to_tensor" or \
"paddle.fluid.dygraph" in layer.kernel or \
layer.kernel.startswith("custom_layer"):
line = "{}".format(
layer.outputs[0]
) if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
"value"].startswith("params[") else "self.{}".format(
) if layer.kernel == "paddle.to_tensor" and not layer.attrs[
"data"].startswith("params[") else "self.{}".format(
layer.outputs[0])
line += " = {}(".format(layer.kernel)
if layer.kernel.startswith("custom_layer"):
line += "= x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
else:
line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v)
line = line.strip(", ")
line += ")"
if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
"value"].startswith("params["):
if layer.kernel == "paddle.to_tensor" and not layer.attrs[
"data"].startswith("params["):
self.forward_func.extend(gen_codes([line], indent=indent))
continue
else:
......@@ -466,8 +550,8 @@ class PaddleGraph(object):
line = layer.outputs[1]
else:
line = ','.join(layer.outputs[1:])
if layer.kernel == "fluid.dygraph.base.to_variable" and layer.attrs[
"value"].startswith("params["):
if layer.kernel == "paddle.to_tensor" and layer.attrs[
"data"].startswith("params["):
line += " = self.{}".format(layer.outputs[0])
else:
line += " = self.{}(".format(layer.outputs[0])
......@@ -475,10 +559,10 @@ class PaddleGraph(object):
line += "{}, ".format(v)
line = line.strip(", ")
line += ")"
self.forward_func.extend(gen_codes([line], indent=indent))
self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel:
func_name = layer.kernel.replace(".", "_")
from x2paddle.op_mapper.pytorch2paddle import prim2code
from x2paddle.op_mapper.dygraph import prim2code
if hasattr(prim2code, func_name):
func = getattr(prim2code, func_name)
func(
......@@ -502,8 +586,14 @@ class PaddleGraph(object):
line += "{}={}, ".format(k, v)
line = line.strip(", ")
line += ")"
self.forward_func.extend(gen_codes([line], indent=indent))
if layer.kernel == "self.create_parameter":
self.init_func.extend(gen_codes(["self." + line], indent=indent))
self.forward_func.extend(gen_codes(["{} = self.{}".format(layer.outputs[0],
layer.outputs[0])], indent=indent))
else:
self.forward_func.extend(gen_codes([line], indent=indent))
if indent == 2:
gen_main_code(code_dir)
write_code(code_dir)
else:
return self.init_func, self.forward_func
......@@ -513,23 +603,22 @@ class PaddleGraph(object):
pickle.dump(self.parameters, params_output)
params_output.close()
def dygraph2static(self, save_dir, input_shapes=[]):
def dygraph2static(self, save_dir, input_shapes=[], input_types=[]):
from paddle.fluid.dygraph.jit import declarative
spec_list = list()
for i, name in enumerate(self.inputs):
input_shapes[i][0] = -1
spec_list.append(
paddle.static.InputSpec(
shape=input_shapes[i], name=name))
shape=input_shapes[i], name=name, dtype=input_types[i]))
import sys
path = osp.abspath(save_dir)
sys.path.insert(0, save_dir)
import x2paddle_code
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)(restore)
model.set_dict(restore)
model.eval()
model.forward = declarative(model.forward, spec_list)
fluid.dygraph.jit.save(
layer=model, model_path=osp.join(save_dir, "inference"))
paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)()
model.set_dict(restore)
model.eval()
static_model = paddle.jit.to_static(model, input_spec=spec_list)
paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model"))
\ No newline at end of file
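A standalone sketch of the export pattern used by dygraph2static, with a toy layer standing in for the generated x2paddle_code class:

    import paddle

    class Net(paddle.nn.Layer):
        def forward(self, x):
            return x * 2

    model = Net()
    model.eval()
    spec = [paddle.static.InputSpec(shape=[-1, 3, 224, 224],
                                    name="x0", dtype="float32")]
    static_model = paddle.jit.to_static(model, input_spec=spec)
    paddle.jit.save(static_model, "inference_model/model")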
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -18,3 +18,11 @@ import os
def string(param):
return "\'{}\'".format(param)
def name_generator(nn_name, nn_name2id):
if nn_name in nn_name2id:
nn_name2id[nn_name] += 1
else:
nn_name2id[nn_name] = 0
real_nn_name = nn_name + str(nn_name2id[nn_name])
return real_nn_name
\ No newline at end of file
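Usage sketch: name_generator keeps a per-prefix counter in nn_name2id, so repeated layer types receive unique module names:

    nn_name2id = dict()
    name_generator("conv", nn_name2id)  # 'conv0'
    name_generator("conv", nn_name2id)  # 'conv1'
    name_generator("bn", nn_name2id)    # 'bn0'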
......@@ -18,7 +18,6 @@ from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from x2paddle.op_mapper import caffe_shape
class CaffeResolver(object):
......@@ -50,10 +49,10 @@ class CaffeGraphNode(GraphNode):
def __init__(self, layer, type_str, layer_name=None):
if layer_name is None:
super(CaffeGraphNode, self).__init__(
layer, layer.name.replace('/', '_').replace('-', '_'))
layer, layer.name.replace('/', '_').replace('-', '_').lower())
else:
super(CaffeGraphNode, self).__init__(
layer, layer_name.replace('/', '_').replace('-', '_'))
layer, layer_name.replace('/', '_').replace('-', '_').lower())
self.layer_type = type_str
self.fluid_code = FluidCode()
self.data = None
......@@ -66,6 +65,13 @@ class CaffeGraph(Graph):
def __init__(self, model, params, caffe_pb):
self.params = params
self.caffe_pb = caffe_pb
if hasattr(model, "name"):
if model.name == "":
self.graph_name = "CaffeModel"
else:
self.graph_name = model.name
else:
self.graph_name = "CaffeModel"
super(CaffeGraph, self).__init__(model)
def filter_layers(self, layers):
......@@ -242,7 +248,7 @@ class CaffeDecoder(object):
with open(proto_path, 'rb') as proto_file:
proto_str = proto_file.read()
text_format.Merge(proto_str, self.net)
self.load_using_pb()
self.caffe_graph = CaffeGraph(self.net, self.params,
......
......@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
import numpy as np
class PyTorchDecoder(object):
def __init__(self, script_path):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
def _optimize_graph(self, graph):
class Decoder(object):
def _optimize_graph(self, graph):
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
......@@ -31,4 +30,37 @@ class PyTorchDecoder(object):
torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_constant_propagation(graph)
return graph
return graph
class ScriptDecoder(Decoder):
""" 当script_path非None,直接load ScriptModule;
当model_path非None,load PyTorchModule后使用script方式转换为ScriptModule。
Args:
script_path (str): ScriptModule保存路径。
model_path (str): PyTorchModule保存路径。
"""
def __init__(self, script_path=None):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
class TraceDecoder(Decoder):
""" PyTorchModule后使用trace方式转换为ScriptModule。
Args:
model_path (str): PyTorchModule保存路径。
input_files (list): 输入网络的numpy,每个numpy保存成.npy文件,
文件路径存储在input_files中。
"""
def __init__(self, model_path, input_files=list()):
# TODO(syf): the PyTorch Module definition must be importable here, otherwise torch.load fails
model = torch.load(model_path)
model.eval()
input_list = list()
for npy_file in input_files:
input_list.append(torch.tensor(np.load(npy_file)))
self.script = torch.jit.trace(model, input_list, strict=False)
self.graph = self._optimize_graph(self.script.inlined_graph)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .detectionoutput import DetectionOutput
from .normalize import Normalize
from .priorbox import PriorBox
from .roipooling import ROIPooling
from .select import Select
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class DetectionOutput(object):
def __init__(self, nms_threshold, nms_top_k, keep_top_k, nms_eta, score_threshold, background_label):
self.detection_output_layer_attrs = {
"background_label": background_label,
"nms_threshold": nms_threshold,
"nms_top_k": nms_top_k,
"keep_top_k": keep_top_k,
"score_threshold": score_threshold,
"nms_eta": nms_eta}
def __call__(self, x0, x1, x2):
priorbox_list = paddle.split(x2, num_or_sections=2, axis=1)
pb = priorbox_list[0]
pbv = priorbox_list[1]
pb = paddle.reshape(x=pb, shape=[-1, 4])
pbv = paddle.reshape(x=pbv, shape=[-1, 4])
pb_dim = fluid.layers.shape(pb)[0]
loc = paddle.reshape(x0, shape=[-1, pb_dim, 4])
conf_flatten = paddle.reshape(x1, shape=[0, pb_dim, -1])
out = fluid.layers.detection_output(loc=loc,
scores=conf_flatten,
prior_box=pb,
prior_box_var=pbv,
**self.detection_output_layer_attrs)
return out
\ No newline at end of file
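A hedged construction sketch for the layer above; the attribute values are typical SSD settings, not values mandated by this diff:

    det = DetectionOutput(nms_threshold=0.45, nms_top_k=400, keep_top_k=200,
                          nms_eta=1.0, score_threshold=0.01, background_label=0)
    # __call__ expects x0 = box locations, x1 = flattened class confidences,
    # x2 = concatenated prior boxes and variances.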
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Normalize(object):
def __init__(self, axis, param_name, param_shape):
self.axis = axis
self.param_name = param_name
self.param_shape = param_shape
def __call__(self, x):
l2 = fluid.layers.l2_normalize(x=x, axis=1)
attr = fluid.ParamAttr(name=self.param_name, trainable=False)
param = paddle.nn.Layer.create_parameter(shape=self.param_shape,
attr=attr)
out = paddle.multiply(x=l2, y=param, axis=self.axis)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class PriorBox(object):
def __init__(self, min_sizes, max_sizes,
aspect_ratios, variance, flip,
clip, steps, offset,
min_max_aspect_ratios_order):
self.priorbox_layer_attrs = {
"min_sizes": min_sizes,
"max_sizes": max_sizes,
"aspect_ratios": aspect_ratios,
"variance": variance,
"flip": flip,
"clip": clip,
"steps": steps,
"offset": offset,
"min_max_aspect_ratios_order": min_max_aspect_ratios_order}
def __call__(self, x0, x1):
box, var = fluid.layers.prior_box(input=x0,
image=x1,
**self.priorbox_layer_attrs)
box = paddle.reshape(x=box, shape=[1, 1, -1])
var = paddle.reshape(x=var, shape=[1, 1, -1])
out = paddle.concat(x=[box, var], axis=1)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class ROIPooling(object):
def __init__(self, pooled_height, pooled_width, spatial_scale):
self.roipooling_layer_attrs = {
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale}
def __call__(self, x0, x1):
slice_x1 = paddle.slice(input=x1, axes=[1],
starts=[1], ends=[5])
out = fluid.layers.roi_pool(input=x0,
rois=slice_x1,
**self.roipooling_layer_attrs)
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
class Select(object):
def __init__(self, input_shape, point, axis):
self.point = point
self.input_shape = input_shape
self.axis = axis
def __call__(self, x):
start = self.point[0]
if len(self.point) == 2:
end = self.point[1]
else:
end = self.input_shape[self.axis]
out = paddle.slice(input=x,
axes=[self.axis],
starts=[start],
ends=[end])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numbers
from functools import reduce
def get_kernel_parameters(params):
[k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2
elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1]
elif params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
elif len(params.pad) > 0:
p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
p_w = params.pad_w if params.pad_w > 0 else params.pad[len(params.pad) -
1]
elif params.pad_h > 0 or params.pad_w > 0:
p_h = params.pad_h
p_w = params.pad_w
dila_h = dila_w = 1
if hasattr(params, 'dilation'):
dila_len = len(params.dilation)
if dila_len == 2:
dila_h = params.dilation[0]
dila_w = params.dilation[1]
elif dila_len == 1:
dila_h = dila_w = params.dilation[0]
else:
assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
dila_len)
return dila_h, dila_w, p_h, p_w, k_h, k_w, s_h, s_w
def get_strided_kernel_output_shape(params, input_shape, round_func):
i_h = input_shape[2]
i_w = input_shape[3]
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params)
o_h = (i_h + 2 * pad_h - (dila_h *
(kernel_h - 1) + 1)) / float(stride_h) + 1
o_w = (i_w + 2 * pad_w - (dila_w *
(kernel_w - 1) + 1)) / float(stride_w) + 1
o_h = int(round_func(o_h))
o_w = int(round_func(o_w))
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape[1]
return [[input_shape[0], c, o_h, o_w]]
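A worked instance of the formula above (illustrative values): a 224x224 input with a 7x7 kernel, stride 2, padding 3, dilation 1:

    # floor((224 + 2*3 - (1*(7-1) + 1)) / 2 + 1) = floor(112.5) = 112
    o_h = int(math.floor((224 + 2 * 3 - (1 * (7 - 1) + 1)) / float(2) + 1))
    assert o_h == 112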
def shape_convolution(layer, input_shape):
params = layer.convolution_param
return get_strided_kernel_output_shape(params, input_shape[0], math.floor)
def shape_deconvolution(layer, input_shape):
h_i = input_shape[0][2]
w_i = input_shape[0][3]
params = layer.convolution_param
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params)
h_o = (h_i - 1) * stride_h - 2 * pad_h + dila_h * (kernel_h - 1) + 1
w_o = (w_i - 1) * stride_w - 2 * pad_w + dila_w * (kernel_w - 1) + 1
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape[0][1]
return [[input_shape[0][0], c, h_o, w_o]]
def shape_pooling(layer, input_shape):
params = layer.pooling_param
global_pool = getattr(params, 'global_pooling', False)
if global_pool:
return [[input_shape[0][0], input_shape[0][1], 1, 1]]
ceil_mode = getattr(params, 'ceil_mode', True)
if ceil_mode is True:
method = math.ceil
else:
method = math.floor
return get_strided_kernel_output_shape(params, input_shape[0], method)
def shape_convolutiondepthwise(layer, input_shape):
params = layer.convolution_param
return get_strided_kernel_output_shape(params, input_shape[0], math.floor)
def shape_innerproduct(layer, input_shape):
params = layer.inner_product_param
return [[input_shape[0][0], params.num_output]]
def shape_lrn(layer, input_shape):
return input_shape
def shape_relu(layer, input_shape):
return input_shape
def shape_softmax(layer, input_shape):
return input_shape
def shape_input(layer, input_shape):
return [list(layer.input_param.shape[0].dim)]
def shape_memorydata(layer, input_shape):
params = layer.memory_data_param
shape = []
shape.append(int(params.batch_size))
shape.append(int(params.channels))
shape.append(int(params.height))
shape.append(int(params.width))
return [shape]
def shape_concat(layer, input_shape):
params = layer.concat_param
axis = params.axis
output_shape = None
for shape in input_shape:
if output_shape is None:
output_shape = []
for i in range(len(shape)):
output_shape.append(shape[i])
else:
output_shape[axis] += shape[axis]
return [output_shape]
def shape_slice(layer, input_shape):
inshape = input_shape[0]
top_len = len(layer.top)
params = layer.slice_param
axis = params.axis
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point)
count = inshape[axis]
if len(points) == 0:
assert count % top_len == 0, "the parameter of Slice is wrong"
part = count / top_len
t = part
while t < count:
points.append(int(t))
t += part
points = [0] + points + [count]
output_shape = []
for i in range(len(points)):
shape = []
for ii in range(len(inshape)):
shape.append(inshape[ii])
size = points[i + 1] - points[i]
shape[axis] = size
output_shape.append(shape)
if i == len(points) - 2:
break
return output_shape
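Worked example: a [1, 12, 4, 4] bottom split into 3 tops along axis 1 with no explicit slice_point yields points [0, 4, 8, 12] and three [1, 4, 4, 4] output shapes:

    points = [0, 4, 8, 12]
    sizes = [points[i + 1] - points[i] for i in range(len(points) - 1)]
    assert sizes == [4, 4, 4]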
def shape_prelu(layer, input_shape):
return input_shape
def shape_sigmoid(layer, input_shape):
return input_shape
def shape_absval(layer, input_shape):
return input_shape
def shape_accuracy(layer, input_shape):
return [[1]]
def shape_tanh(layer, input_shape):
return input_shape
def shape_eltwise(layer, input_shape):
return [input_shape[0]]
def shape_batchnorm(layer, input_shape):
return input_shape
def shape_scale(layer, input_shape):
return input_shape
def shape_reshape(layer, input_shape):
def count(num_list):
return reduce(lambda a, b: a * b, num_list)
inshape = input_shape[0]
params = layer.reshape_param
axis = params.axis if hasattr(params, 'axis') else 0
num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
if inshape[0] == -1:
inshape[0] = 1
input_count = count(inshape)
input_num_axes = len(inshape)
input_start_axis = axis
start_axis = input_start_axis if input_start_axis >= 0 \
else input_num_axes + input_start_axis + 1
assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\
% (input_start_axis, input_num_axes)
assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"
end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\
% (end_axis, start_axis, num_axes)
num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(list(params.shape.dim))
output_shape = []
for i in range(start_axis):
output_shape.append(inshape[i])
for i in range(num_new_axes):
output_shape.append(params.shape.dim[i])
for i in range(end_axis, input_num_axes):
output_shape.append(inshape[i])
assert len(output_shape) == num_axes_retained + num_new_axes,\
"[Reshape]invalid dims of output shape[%s]" % (str(output_shape))
inferred_axis = -1
copy_axes = []
constant_count = 1
for i in range(num_new_axes):
top_dim = params.shape.dim[i]
if top_dim == 0:
copy_axes.append(i)
copy_axis_index = start_axis + i
output_shape[copy_axis_index] = inshape[copy_axis_index]
elif top_dim == -1:
assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
inferred_axis = i
else:
constant_count *= top_dim
if inferred_axis >= 0:
explicit_count = constant_count
l = inshape[0:start_axis]
if len(l) > 0:
explicit_count *= count(l)
l = inshape[end_axis:]
if len(l) > 0:
explicit_count *= count(l)
for i in range(len(copy_axes)):
explicit_count *= output_shape[start_axis + copy_axes[i]]
assert input_count % explicit_count == 0, "[Reshape]bottom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count)
output_shape[start_axis + inferred_axis] = int(input_count / explicit_count)
output_count = count(output_shape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
output_shape[0] = -1
return [output_shape]
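A worked example of the Caffe Reshape conventions handled above: dim 0 copies the corresponding input axis and dim -1 is inferred, so an input of shape [1, 2, 3, 4] with shape.dim = [0, -1] yields [1, 24] (the batch axis is then reset to -1):

    inshape = [1, 2, 3, 4]
    inferred = reduce(lambda a, b: a * b, inshape) // inshape[0]  # 24
    assert [inshape[0], inferred] == [1, 24]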
def shape_argmax(layer, input_shape):
inshape = input_shape[0]
params = layer.argmax_param
out_max_val = params.out_max_val if hasattr(params, 'out_max_val') else False
top_k = params.top_k if hasattr(params, 'top_k') else 1
axis = params.axis if hasattr(params, 'axis') else -1
if axis < 0:
axis += len(inshape)
assert (axis + 1 == len(inshape)
), 'only can be applied on the last dimension[axis:%d, %s] now,'\
'make sure you have set axis param in xxx.prototxt file' \
% (axis, str(inshape))
output_shape = inshape
output_shape[-1] = top_k
if out_max_val is True:
output_shape[-1] *= 2
return [output_shape]
def shape_crop(layer, input_shape):
assert len(input_shape) == 2, "the number of crop's inputs must be 2"
return [input_shape[1]]
def shape_flatten(layer, input_shape):
assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
inshape = input_shape[0]
params = layer.flatten_param
start_axis = params.axis
end_axis = params.end_axis
if start_axis < 0:
start_axis += len(inshape)
if end_axis < 0:
end_axis += len(inshape) + 1
assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
% (start_axis, end_axis)
output_shape = inshape[0:start_axis]
if len(inshape[start_axis:end_axis]) != 0:
flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
output_shape += [flat_sz]
output_shape += inshape[end_axis:len(inshape)]
output_shape[0] = -1
return [output_shape]
def shape_power(layer, input_shape):
return input_shape
def shape_reduction(layer, input_shape):
params = layer.reduction_param
axis = params.axis
if axis < 0:
axis += len(input_shape[0]) + 1
assert axis <= len(input_shape[0]), 'invalid axis[%d] error' % (axis)
return [input_shape[0][0:axis]]
def shape_axpy(layer, input_shape):
assert len(input_shape) == 3, "not valid input shape for axpy layer"
assert len(input_shape[0]) == len(input_shape[1]), 'should have same dims'
output_shape = input_shape[1]
assert (input_shape[2] == output_shape),\
"shape not consistent for axpy[%s <--> %s]" \
% (str(output_shape), str(input_shape[2]))
return [output_shape]
def shape_detectionoutput(layer, input_shape):
return [[-1, 6]]
def shape_normalize(layer, input_shape):
return input_shape
def shape_permute(layer, input_shape):
order = layer.permute_param.order
inshape = input_shape[0]
output_shape = []
order = list(order)
for ii in order:
assert ii < len(inshape), "invalid order for permute[%s]" % (layer.name)
output_shape.append(inshape[ii])
return [output_shape]
def shape_priorbox(layer, input_shape):
max_size = layer.prior_box_param.max_size
aspect_ratio = layer.prior_box_param.aspect_ratio
fc_shape = input_shape[0]
N = 1
if max_size is not None:
N += 1
if aspect_ratio is not None:
N += 2 * len(aspect_ratio)
N_bbx = fc_shape[2] * fc_shape[3] * N
output_shape = [1, 2, 4 * N_bbx]
return [output_shape]
def shape_relu6(layer, input_shape):
return input_shape
def shape_roipooling(layer, input_shape):
pooled_w = layer.roi_pooling_param.pooled_w
pooled_h = layer.roi_pooling_param.pooled_h
base_fea_shape = input_shape[0]
rois_shape = input_shape[1]
output_shape = base_fea_shape
output_shape[0] = rois_shape[0]
output_shape[2] = pooled_h
output_shape[3] = pooled_w
return [output_shape]
def shape_shufflechannel(layer, input_shape):
return input_shape
def shape_upsample(layer, input_shape):
scale = layer.upsample_param.scale
assert len(input_shape) == 1, "not valid input shape for upsample layer"
assert type(scale) is int
inshape = input_shape[0]
new_h = scale * inshape[2]
new_w = scale * inshape[3]
output_shape = [inshape[0], inshape[1], new_h, new_w]
return [output_shape]
def shape_select(layer, input_shape):
slice_point = layer.select_param.slice_point
axis = layer.select_param.axis
input_shape = input_shape[0]
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
assert end > start, "invalid slice_point with [start:%d, end:%d]"\
% (start, end)
output_shape = input_shape
output_shape[axis] = end - start
return [output_shape]
......@@ -14,18 +14,6 @@ def detectionoutput_layer(inputs,
confidence_threshold=0.1,
input_shape=None,
name=None):
nms_param_str = nms_param
nms_param = {}
part = nms_param_str.split(',')
for s in part:
if s == '':
break
else:
name, obj = s.split(': ')
if name == 'top_k':
nms_param[name] = int(obj)
else:
nms_param[name] = float(obj)
if nms_param is None:
nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
mbox_conf_flatten = inputs[1]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.decoder.caffe_decoder import CaffeGraph
from x2paddle.core.util import *
class CaffeOptimizer(object):
layers_with_act = ['Convolution', 'Deconvolution', 'InnerProduct']
activation_ops = ['ReLU', 'Sigmoid']
def __init__(self, mapper):
self.graph = mapper.graph
def merge_bn_scale(self):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node.layer_type == 'Scale':
parent_node = self.graph.get_bottom_node(node, idx=0)
if parent_node.layer_type == 'BatchNorm':
is_delete_node = True if len(
parent_node.outputs) == 1 else False
parent_fluid_layer = parent_node.fluid_code.layers[0]
input = parent_fluid_layer.inputs
parent_param_attr = parent_fluid_layer.param_attr
parent_param_attr['param_attr'] = string(node.layer_name +
'_scale')
parent_param_attr['bias_attr'] = string(node.layer_name +
'_offset')
if is_delete_node:
parent_node.fluid_code.clear()
node.fluid_code.clear()
node.fluid_code.add_layer(
"batch_norm",
inputs=input,
output=node,
param_attr=parent_param_attr)
def merge_op_activation(self):
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
if node.layer_type in self.activation_ops:
parent_node = self.graph.get_bottom_node(node, idx=0)
if parent_node.layer_type in self.layers_with_act:
is_delete_node = True if len(
parent_node.outputs) == 1 else False
parent_fluid_layer = parent_node.fluid_code.layers[0]
input = parent_fluid_layer.inputs
parent_param_attr = parent_fluid_layer.param_attr
parent_param_attr['act'] = string(node.layer_type.lower())
op = parent_fluid_layer.op
if is_delete_node:
parent_node.fluid_code.clear()
node.fluid_code.clear()
node.fluid_code.add_layer(
op,
inputs=input,
output=node,
param_attr=parent_param_attr)
......@@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .adaptive_pool2d_fuser import AdaptivePool2dFuser
from .adaptive_pool2d_fuse_pass import AdaptivePool2dFusePass
from .batchnorm2d_fuser import BatchNorm2dFuser
from .batchnorm2d_fuse_pass import BatchNorm2dFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
from .dropout_fuser import DropoutFuser
from .dropout_fuse_pass import DropoutFusePass
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
from .interpolate_bilinear_fuser import InterpolateBilinearFuser
from .interpolate_bilinear_fuse_pass import InterpolateBilinearFusePass
from .reshape_fuser import ReshapeFuser
from .reshape_fuse_pass import ReshapeFusePass
from .adaptive_pool2d_fuser import Dygraph_AdaptivePool2dFuser
from .adaptive_pool2d_fuse_pass import Dygraph_AdaptivePool2dFusePass
from .batchnorm2d_fuser import Dygraph_BatchNorm2dFuser
from .batchnorm2d_fuse_pass import Dygraph_BatchNorm2dFusePass
from .bn_scale_fuser import Dygraph_BNScaleFuser
from .bn_scale_fuse_pass import Dygraph_BNScaleFusePass
from .constant_fuser import Dygraph_ConstantFuser
from .constant_fuse_pass import Dygraph_ConstantFusePass
from .dropout_fuser import Dygraph_DropoutFuser
from .dropout_fuse_pass import Dygraph_DropoutFusePass
from .fc_fuser import Dygraph_FcFuser
from .fc_fuse_pass import Dygraph_FcFusePass
from .interpolate_bilinear_fuser import Dygraph_InterpolateBilinearFuser
from .interpolate_bilinear_fuse_pass import Dygraph_InterpolateBilinearFusePass
from .reshape_fuser import Dygraph_ReshapeFuser
from .reshape_fuse_pass import Dygraph_ReshapeFusePass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_AdaptivePool2dFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_AdaptivePool2dFusePass(Pass):
name = "dygraph_adaptive_pool2d_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Dygraph_AdaptivePool2dFuser()
fuser.operate(graph, match_kind="topo")
# for pass registration
adaptive_pool2d_fuse_pass = Dygraph_AdaptivePool2dFusePass()
......@@ -13,14 +13,14 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class AdaptivePool2dFuser(FuseBase):
class Dygraph_AdaptivePool2dFuser(FuseBase):
def __init__(self):
super(AdaptivePool2dFuser, self).__init__(graph_type="dygraph")
super(Dygraph_AdaptivePool2dFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的adaptive pool2d图结构。
......
......@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import BatchNorm2dFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_BatchNorm2dFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class BatchNorm2dFusePass(Pass):
name = "batchnorm2d_fuse_pass"
class Dygraph_BatchNorm2dFusePass(Pass):
name = "dygraph_batchnorm2d_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = BatchNorm2dFuser()
fuser = Dygraph_BatchNorm2dFuser()
fuser.operate(graph, match_kind="topo")
# for pass registration
batchnorm2d_fuse_pass = BatchNorm2dFusePass()
batchnorm2d_fuse_pass = Dygraph_BatchNorm2dFusePass()
......@@ -13,14 +13,14 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class BatchNorm2dFuser(FuseBase):
class Dygraph_BatchNorm2dFuser(FuseBase):
def __init__(self):
super(BatchNorm2dFuser, self).__init__(graph_type="dygraph")
super(Dygraph_BatchNorm2dFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_BNScaleFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Dygraph_BNScaleFusePass(Pass):
name = "dygraph_bn_scale_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Dygraph_BNScaleFuser()
fuser.operate(graph, match_kind="topo")
# for pass registration
bn_scale_fuse_pass = Dygraph_BNScaleFusePass()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Dygraph_BNScaleFuser(FuseBase):
def __init__(self):
super(Dygraph_BNScaleFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例:
bn_conv1 = self.batchnorm0(conv1)
scale_conv1_cparam1 = self.scale_conv1_cparam1
scale_conv1_mul = paddle.multiply(x=bn_conv1, y=scale_conv1_cparam1, axis=1)
scale_conv1_cparam2 = self.scale_conv1_cparam2
scale_conv1 = fluid.layers.elementwise_add(x=scale_conv1_mul, y=scale_conv1_cparam2, axis=1)
"""
def gen_name(id):
return "x" + str(id)
self.pattern.add_layer(
"paddle.nn.BatchNorm2D",
inputs={"input": "bn-input-0"},
outputs=[gen_name(0)])
self.pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(1)])
inputs_dict = {}
inputs_dict['x'] = gen_name(0)
inputs_dict['y'] = gen_name(1)
self.pattern.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[gen_name(2)])
self.pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(3)])
inputs_dict = {}
inputs_dict['x'] = gen_name(2)
inputs_dict['y'] = gen_name(3)
self.pattern.add_layer(
"fluid.layers.elementwise_add",
inputs=inputs_dict,
outputs=[gen_name(4)])
self.pattern.build(inputs={"input-0": "bn-input-0"})
def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches)
new_layer_id = list(matches.keys())[0]
graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id)
def gen_new_layer(self, parameters, matches):
layers_id = list(matches.keys())
layer = matches[layers_id[0]]
layer_inputs = layer.inputs
bn_name = layer.outputs[0]
layer_attrs = layer.attrs
layer_attrs.pop("weight_attr")
layer_attrs.pop("bias_attr")
layer = matches[layers_id[4]]
layer_outputs = [bn_name] + layer.outputs
layer = matches[layers_id[1]]
data0_name = layer.outputs[0]
data0_numpy = parameters.pop(data0_name)
parameters["{}.weight".format(layer_outputs[0])] = data0_numpy
layer = matches[layers_id[3]]
data1_name = layer.outputs[0]
data1_numpy = parameters.pop(data1_name)
parameters["{}.bias".format(layer_outputs[0])] = data1_numpy
new_layer = PaddleLayer(
layers_id[0],
"paddle.nn.BatchNorm2D",
inputs=layer_inputs,
outputs=layer_outputs,
**layer_attrs)
return new_layer
\ No newline at end of file
......@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import ConstantFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_ConstantFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class ConstantFusePass(Pass):
name = "constant_fuse_pass"
class Dygraph_ConstantFusePass(Pass):
name = "dygraph_constant_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = ConstantFuser()
fuser = Dygraph_ConstantFuser()
fuser.operate(graph, match_kind="topo")
# for pass registration
constant_fuse_pass = ConstantFuser()
constant_fuse_pass = Dygraph_ConstantFusePass()
......@@ -13,14 +13,14 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class ConstantFuser(FuseBase):
class Dygraph_ConstantFuser(FuseBase):
def __init__(self):
super(ConstantFuser, self).__init__(graph_type="dygraph")
super(Dygraph_ConstantFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的constant图结构。
......
......@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import DropoutFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_DropoutFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class DropoutFusePass(Pass):
name = "dropout_fuse_pass"
class Dygraph_DropoutFusePass(Pass):
name = "dygraph_dropout_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = DropoutFuser()
fuser = Dygraph_DropoutFuser()
fuser.operate(graph, match_kind="topo")
# For registration
dropout_fuse_pass = DropoutFuser()
dropout_fuse_pass = Dygraph_DropoutFusePass()
......@@ -13,14 +13,14 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class DropoutFuser(FuseBase):
class Dygraph_DropoutFuser(FuseBase):
def __init__(self):
super(DropoutFuser, self).__init__(graph_type="dygraph")
super(Dygraph_DropoutFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的constant图结构。
......
......@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import FcFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_FcFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class FcFusePass(Pass):
name = "fc_fuse_pass"
class Dygraph_FcFusePass(Pass):
name = "dygraph_fc_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = FcFuser()
fuser = Dygraph_FcFuser()
fuser.operate(graph, match_kind="topo")
# For registration
fc_fuse_pass = FcFusePass()
fc_fuse_pass = Dygraph_FcFusePass()
......@@ -13,15 +13,15 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class FcFuser(FuseBase):
class Dygraph_FcFuser(FuseBase):
def __init__(self):
self.linear_index = 0
super(FcFuser, self).__init__(graph_type="dygraph")
super(Dygraph_FcFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的fc图结构。
......
......@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import AdaptivePool2dFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_InterpolateBilinearFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class AdaptivePool2dFusePass(Pass):
name = "adaptive_pool2d_fuse_pass"
class Dygraph_InterpolateBilinearFusePass(Pass):
name = "dygraph_interpolate_bilinear_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = AdaptivePool2dFuser()
fuser = Dygraph_InterpolateBilinearFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
adaptive_pool2d_fuse_pass = AdaptivePool2dFusePass()
interpolate_bilinear_fuse_pass = Dygraph_InterpolateBilinearFusePass()
......@@ -13,14 +13,14 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class InterpolateBilinearFuser(FuseBase):
class Dygraph_InterpolateBilinearFuser(FuseBase):
def __init__(self):
super(InterpolateBilinearFuser, self).__init__(graph_type="dygraph")
super(Dygraph_InterpolateBilinearFuser, self).__init__(graph_type="dygraph")
import torch
torch_version = torch.__version__
torch_version_part = torch_version.split(".")
......
......@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import ReshapeFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_ReshapeFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class ReshapeFusePass(Pass):
name = "reshape_fuse_pass"
class Dygraph_ReshapeFusePass(Pass):
name = "dygraph_reshape_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = ReshapeFuser()
fuser = Dygraph_ReshapeFuser()
fuser.operate(graph, match_kind="edge")
# For registration
reshape_fuse_pass = ReshapeFusePass()
reshape_fuse_pass = Dygraph_ReshapeFusePass()
......@@ -13,14 +13,14 @@
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pytorch_optimizer.pattern_matcher import FuseBase
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class ReshapeFuser(FuseBase):
class Dygraph_ReshapeFuser(FuseBase):
def __init__(self):
super(ReshapeFuser, self).__init__(graph_type="dygraph")
super(Dygraph_ReshapeFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的reshape图结构。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bn_scale_fuser import Static_BNScaleFuser
from .bn_scale_fuse_pass import Static_BNScaleFusePass
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.static import Static_BNScaleFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class Static_BNScaleFusePass(Pass):
name = "static_bn_scale_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = Static_BNScaleFuser()
fuser.operate(graph, match_kind="topo")
# For registration
bn_scale_fuse_pass = Static_BNScaleFusePass()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class Static_BNScaleFuser(FuseBase):
def __init__(self):
        super(Static_BNScaleFuser, self).__init__(graph_type="static")
def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。
batchnorm2d层模式python实现代码示例:
conv5_bn = fluid.layers.batch_norm(input=conv5, is_test=True, param_attr=None, bias_attr=None, moving_mean_name='conv5_bn_mean', moving_variance_name='conv5_bn_variance', epsilon=9.999999747378752e-06, name='conv5_bn')
conv5_scale_scale = fluid.ParamAttr(name='conv5_scale_scale')
conv5_scale_cparam1 = fluid.layers.create_parameter(attr=conv5_scale_scale, dtype=conv5_bn.dtype, shape=[256], name='conv5_scale_cparam1', is_bias=True, default_initializer=Constant(value=1.0))
conv5_scale_mul = fluid.layers.elementwise_mul(x=conv5_bn, y=conv5_scale_cparam1, axis=1)
conv5_scale_offset = fluid.ParamAttr(name='conv5_scale_offset')
conv5_scale_cparam2 = fluid.layers.create_parameter(attr=conv5_scale_offset, dtype=conv5_bn.dtype, shape=[256], name='conv5_scale_cparam2', is_bias=True, default_initializer=Constant(value=1.0))
conv5_scale = fluid.layers.elementwise_add(x=conv5_scale_mul, y=conv5_scale_cparam2, axis=1)
"""
def gen_name(id):
return "x" + str(id)
self.pattern.add_layer(
"fluid.layers.batch_norm",
inputs={"input": "bn-input-0"},
outputs=[gen_name(0)])
self.pattern.add_layer(
"fluid.ParamAttr",
inputs={},
outputs=[gen_name(1)])
self.pattern.add_layer(
"fluid.layers.create_parameter",
inputs={"attr": gen_name(1)},
outputs=[gen_name(2)])
inputs_dict = {}
inputs_dict['x'] = gen_name(0)
inputs_dict['y'] = gen_name(2)
self.pattern.add_layer(
"fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[gen_name(3)])
self.pattern.add_layer(
"fluid.ParamAttr",
inputs={},
outputs=[gen_name(4)])
self.pattern.add_layer(
"fluid.layers.create_parameter",
inputs={"attr": gen_name(4)},
outputs=[gen_name(5)])
inputs_dict = {}
inputs_dict['x'] = gen_name(3)
inputs_dict['y'] = gen_name(5)
self.pattern.add_layer(
"fluid.layers.elementwise_add",
inputs=inputs_dict,
outputs=[gen_name(6)])
self.pattern.build(inputs={"input-0": "bn-input-0"})
def insert_new_layer(self, graph, parameters, matches):
new_layer = self.gen_new_layer(parameters, matches)
new_layer_id = list(matches.keys())[0]
graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id)
def gen_new_layer(self, parameters, matches):
layers_id = list(matches.keys())
layer = matches[layers_id[0]]
layer_inputs = layer.inputs
layer_name = layer.outputs[0]
layer_attrs = layer.attrs
layer_attrs["param_attr"] = string("{}_scale".format(layer_name))
layer_attrs["bias_attr"] = string("{}_offset".format(layer_name))
layer = matches[layers_id[-1]]
layer_outputs = layer.outputs
layer = matches[layers_id[1]]
layer_name = layer.outputs[0]
scale_numpy = parameters.pop(layer_name)
parameters[layer_attrs["param_attr"][1: -1]] = scale_numpy
layer = matches[layers_id[4]]
layer_name = layer.outputs[0]
scale_numpy = parameters.pop(layer_name)
parameters[layer_attrs["bias_attr"][1: -1]] = scale_numpy
new_layer = PaddleLayer(
layers_id[0],
"fluid.layers.batch_norm",
inputs=layer_inputs,
outputs=layer_outputs,
**layer_attrs)
return new_layer
\ No newline at end of file
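For the docstring example above, the fused layer produced by gen_new_layer() would emit roughly the following (identifier names are illustrative; the scale and offset tensors are re-registered under the '{bn}_scale' / '{bn}_offset' parameter names set in layer_attrs):

conv5_scale = fluid.layers.batch_norm(
    input=conv5,
    is_test=True,
    param_attr='conv5_bn_scale',
    bias_attr='conv5_bn_offset',
    moving_mean_name='conv5_bn_mean',
    moving_variance_name='conv5_bn_variance',
    epsilon=9.999999747378752e-06,
    name='conv5_bn')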
......@@ -12,22 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.fusion import *
from x2paddle.optimizer.pytorch_optimizer.pass_manager import PassManager
from x2paddle.optimizer.pass_manager import PassManager
from x2paddle.optimizer.fusion.dygraph import *
from x2paddle.optimizer.fusion.static import *
class GraphOptimizer(object):
def __init__(self):
self.passes = [
"constant_fuse_pass", "batchnorm2d_fuse_pass",
"interpolate_bilinear_fuse_pass", "fc_fuse_pass",
"adaptive_pool2d_fuse_pass", "reshape_fuse_pass",
"dropout_fuse_pass"
]
def __init__(self, source_frame, paddle_type="dygraph"):
if source_frame == "pytorch":
self.passes = [
"dygraph_constant_fuse_pass", "dygraph_batchnorm2d_fuse_pass",
"dygraph_interpolate_bilinear_fuse_pass", "dygraph_fc_fuse_pass",
"dygraph_adaptive_pool2d_fuse_pass", "dygraph_reshape_fuse_pass",
"dygraph_dropout_fuse_pass"
]
elif source_frame == "caffe":
if paddle_type == "dygraph":
self.passes = ["dygraph_bn_scale_fuse_pass"]
else:
self.passes = ["static_bn_scale_fuse_pass"]
else:
# TODO
pass
def optimize(self, graph):
for pass_name in self.passes:
pass_ = PassManager.lookup(pass_name)()
pass_.apply(graph)
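            # Re-apply the pass until the layer count stops changing:
            # one fusion can expose another match of the same pattern.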
while True:
before_len = len(graph.layers)
pass_.apply(graph)
after_len = len(graph.layers)
if before_len == after_len:
break
print("{} done!".format(pass_name))
return graph
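A minimal usage sketch (the PyTorch front end selects the dygraph pass list; paddle_graph stands in for the intermediate PaddleGraph produced by an op mapper):

optimizer = GraphOptimizer(source_frame="pytorch")
optimized_graph = optimizer.optimize(paddle_graph)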
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pytorch_optimizer.pass_ import Pass
from x2paddle.optimizer.pytorch_optimizer.fusion import InterpolateBilinearFuser
from x2paddle.optimizer.pytorch_optimizer.pass_manager import pass_register
@pass_register
class InterpolateBilinearFusePass(Pass):
name = "interpolate_bilinear_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = InterpolateBilinearFuser()
fuser.operate(graph, match_kind="topo")
# For registration
interpolate_bilinear_fuse_pass = InterpolateBilinearFusePass()