未验证 提交 55d5eb24 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #16 from PaddlePaddle/develop

Merge
...@@ -35,15 +35,15 @@ pip install x2paddle==1.0.0rc0 --index https://pypi.Python.org/simple/ ...@@ -35,15 +35,15 @@ pip install x2paddle==1.0.0rc0 --index https://pypi.Python.org/simple/
## 使用方法 ## 使用方法
### TensorFlow ### TensorFlow
``` ```
x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model --paddle_type dygraph
``` ```
### Caffe ### Caffe
``` ```
x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel --save_dir=pd_model x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel --save_dir=pd_model --paddle_type dygraph
``` ```
### ONNX ### ONNX
``` ```
x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model --paddle_type dygraph
``` ```
### PyTorch ### PyTorch
...@@ -64,6 +64,7 @@ x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model ...@@ -64,6 +64,7 @@ x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
|--caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None | |--caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
|--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](./docs/user_guides/FAQ.md) | |--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](./docs/user_guides/FAQ.md) |
|--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ | |--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ |
|--paddle_type | **[可选]** 该参数指定转换为动态图代码(dygraph)或者静态图代码(static),默认为dygraph|
......
...@@ -185,16 +185,8 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False): ...@@ -185,16 +185,8 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False):
from x2paddle.op_mapper.static.onnx2paddle.onnx_op_mapper import ONNXOpMapper from x2paddle.op_mapper.static.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path) model = ONNXDecoder(model_path)
mapper = ONNXOpMapper(model) mapper = ONNXOpMapper(model)
if paddle_type == "dygraph": mapper.paddle_graph.build()
mapper.paddle_graph.build() mapper.paddle_graph.gen_model(save_dir)
mapper.paddle_graph.gen_model(save_dir)
else:
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
print("Model optimized.")
mapper.save_inference_model(save_dir, params_merge)
def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None): def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None):
......
...@@ -117,7 +117,13 @@ class ONNXGraphDataNode(GraphNode): ...@@ -117,7 +117,13 @@ class ONNXGraphDataNode(GraphNode):
if isinstance(self.layer, ValueInfoProto): if isinstance(self.layer, ValueInfoProto):
values = self.layer.type.tensor_type.shape.dim values = self.layer.type.tensor_type.shape.dim
out_shapes = list() out_shapes = list()
out_shapes.append([dim.dim_value for dim in values]) shape = list()
for dim in values:
if dim.dim_value == 0:
shape.append(-1)
else:
shape.append(dim.dim_value)
out_shapes.append(shape)
return out_shapes return out_shapes
else: else:
values = self.layer.dims values = self.layer.dims
......
...@@ -130,7 +130,7 @@ class TFGraph(Graph): ...@@ -130,7 +130,7 @@ class TFGraph(Graph):
def __init__(self, model, data_format="NHWC"): def __init__(self, model, data_format="NHWC"):
super(TFGraph, self).__init__(model) super(TFGraph, self).__init__(model)
self.identity_map = dict() self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2', 'Unpack']
self.tf_data_format = data_format self.tf_data_format = data_format
self.graph_name = "TFModel" self.graph_name = "TFModel"
...@@ -173,6 +173,7 @@ class TFGraph(Graph): ...@@ -173,6 +173,7 @@ class TFGraph(Graph):
self._optimize_dialiation_conv() self._optimize_dialiation_conv()
self._remove_identity_node() self._remove_identity_node()
self._remove_cast_node() self._remove_cast_node()
def get_node(self, node_name, copy=False): def get_node(self, node_name, copy=False):
items = node_name.strip().split(':') items = node_name.strip().split(':')
...@@ -192,6 +193,8 @@ class TFGraph(Graph): ...@@ -192,6 +193,8 @@ class TFGraph(Graph):
def get_input_node(self, node, idx=0, copy=False): def get_input_node(self, node, idx=0, copy=False):
input_node_name = node.layer.input[idx] input_node_name = node.layer.input[idx]
if idx > 0:
copy = True
return self.get_node(input_node_name, copy) return self.get_node(input_node_name, copy)
def remove_node(self, node_name): def remove_node(self, node_name):
...@@ -402,7 +405,7 @@ class TFDecoder(object): ...@@ -402,7 +405,7 @@ class TFDecoder(object):
right_shape_been_input = False right_shape_been_input = False
while not right_shape_been_input: while not right_shape_been_input:
try: try:
shape = input( shape = raw_input(
"Shape of Input(e.g. None,224,224,3): ") "Shape of Input(e.g. None,224,224,3): ")
except: except:
shape = input("Shape of Input(e.g. None,224,224,3): ") shape = input("Shape of Input(e.g. None,224,224,3): ")
......
...@@ -534,7 +534,7 @@ class OpSet9(): ...@@ -534,7 +534,7 @@ class OpSet9():
'bias_attr': string(val_b.name) 'bias_attr': string(val_b.name)
} }
dim = len(val_x.out_shapes[0]) dim = len(val_x.out_shapes[0])
if dim == 2 or dim == 3: if dim == 3:
paddle_op = "paddle.nn.InstanceNorm1D" paddle_op = "paddle.nn.InstanceNorm1D"
elif dim == 4: elif dim == 4:
paddle_op = "paddle.nn.InstanceNorm2D" paddle_op = "paddle.nn.InstanceNorm2D"
...@@ -1539,7 +1539,6 @@ class OpSet9(): ...@@ -1539,7 +1539,6 @@ class OpSet9():
layer_outputs = [op_name, output_name] layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
has_bias = len(node.layer.input) == 3 has_bias = len(node.layer.input) == 3
if has_bias: if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
...@@ -1620,23 +1619,7 @@ class OpSet9(): ...@@ -1620,23 +1619,7 @@ class OpSet9():
output_size[1] = (val_x.out_shapes[0][3] - 1 output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * ( ) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1] kernel_shape[1] - 1) + 1 + out_padding[1]
# layer_attrs = { # Conv2DTranspose缺少output_size,只能在forward里头传进output_size
# 'in_channels': num_in_channels,
# 'out_channels': num_out_channels,
# 'output_size': output_size or None,
# 'kernel_size': kernel_shape,
# 'padding': paddings,
# 'stride': strides,
# 'dilation': dilations,
# 'groups': num_groups,
# 'weight_attr': string(val_w.name),
# 'bias_attr': None if val_b is None else string(val_b.name),
# }
# self.paddle_graph.add_layer(
# paddle_op,
# inputs={"x": val_x.name},
# outputs=layer_outputs,
# **layer_attrs)
inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name, inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name,
"weight": val_w.name} "weight": val_w.name}
layer_attrs = { layer_attrs = {
......
...@@ -73,15 +73,17 @@ class TFOpMapper(OpMapper): ...@@ -73,15 +73,17 @@ class TFOpMapper(OpMapper):
'Sub': 'fluid.layers.elementwise_sub', 'Sub': 'fluid.layers.elementwise_sub',
'Maximum': 'paddle.maximum', 'Maximum': 'paddle.maximum',
'Minimum': 'paddle.minimum', 'Minimum': 'paddle.minimum',
'Mul': 'paddle.multiply',
'FloorDiv': 'paddle.floor_divide',
'FloorMod': 'paddle.floor_mod',
'LogicalAnd': 'logical_and',
}
bool_ops = {
'LessEqual': 'paddle.less_equal', 'LessEqual': 'paddle.less_equal',
'GreaterEqual': 'paddle.greater_equal', 'GreaterEqual': 'paddle.greater_equal',
'Greater': 'paddle.greater_than', 'Greater': 'paddle.greater_than',
'NotEqual': 'paddle.not_equal', 'NotEqual': 'paddle.not_equal',
'Equal': 'paddle.equal', 'Equal': 'paddle.equal',
'Mul': 'paddle.multiply',
'FloorDiv': 'paddle.floor_divide',
'FloorMod': 'paddle.floor_mod',
'LogicalAnd': 'logical_and',
} }
def __init__(self, decoder): def __init__(self, decoder):
...@@ -123,6 +125,8 @@ class TFOpMapper(OpMapper): ...@@ -123,6 +125,8 @@ class TFOpMapper(OpMapper):
self.directly_map(node) self.directly_map(node)
elif op in self.elementwise_ops: elif op in self.elementwise_ops:
self.elementwise_map(node) self.elementwise_map(node)
elif op in self.bool_ops:
self.bool_map(node)
elif hasattr(self, op): elif hasattr(self, op):
func = getattr(self, op) func = getattr(self, op)
func(node) func(node)
...@@ -138,7 +142,8 @@ class TFOpMapper(OpMapper): ...@@ -138,7 +142,8 @@ class TFOpMapper(OpMapper):
op = node.layer_type op = node.layer_type
if not hasattr(self, op) and \ if not hasattr(self, op) and \
op not in self.directly_map_ops and \ op not in self.directly_map_ops and \
op not in self.elementwise_ops: op not in self.elementwise_ops and \
op not in self.bool_ops:
unsupported_ops.add(op) unsupported_ops.add(op)
if len(unsupported_ops) == 0: if len(unsupported_ops) == 0:
return True return True
...@@ -178,8 +183,10 @@ class TFOpMapper(OpMapper): ...@@ -178,8 +183,10 @@ class TFOpMapper(OpMapper):
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
def elementwise_map(self, node): def elementwise_map(self, node, op_type=None):
op_type = self.elementwise_ops[node.layer_type] if op_type is None:
assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type]
x = self.graph.get_input_node(node, 0) x = self.graph.get_input_node(node, 0)
y = self.graph.get_input_node(node, 1) y = self.graph.get_input_node(node, 1)
x_shape = x.out_shapes[0] x_shape = x.out_shapes[0]
...@@ -190,6 +197,11 @@ class TFOpMapper(OpMapper): ...@@ -190,6 +197,11 @@ class TFOpMapper(OpMapper):
"y": y.name}, "y": y.name},
outputs=[node.name]) outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape} self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
def bool_map(self, node):
op_type = self.bool_ops[node.layer_type]
self.elementwise_map(node, op_type)
node.set_dtype("bool")
def Placeholder(self, node): def Placeholder(self, node):
shape = node.out_shapes[0] shape = node.out_shapes[0]
......
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.op_mapper.static.onnx2paddle.opset9 import OpSet9, custom_layers import sys
from x2paddle.op_mapper.static.onnx2paddle.opset9 import OpSet9
from x2paddle.core.op_mapper import OpMapper from x2paddle.core.op_mapper import OpMapper
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode from x2paddle.decoder.onnx_decoder import ONNXGraphNode
from x2paddle.core.program import PaddleGraph
class ONNXOpMapper(OpMapper): class ONNXOpMapper(OpMapper):
...@@ -23,33 +25,36 @@ class ONNXOpMapper(OpMapper): ...@@ -23,33 +25,36 @@ class ONNXOpMapper(OpMapper):
self.support_op_sets = [9, ] self.support_op_sets = [9, ]
self.default_op_set = 9 self.default_op_set = 9
self.graph = decoder.graph self.graph = decoder.graph
self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="onnx")
self.paddle_graph.outputs = self.graph.output_nodes
self.opset = self.create_opset(decoder) self.opset = self.create_opset(decoder)
if not self.op_checker(): if not self.op_checker():
raise Exception("Model are not supported yet.") raise Exception("Model is not supported yet.")
#mapping op
print("Total nodes: {}".format( print("Total nodes: {}".format(
sum([ sum([
isinstance(node, ONNXGraphNode) isinstance(node, ONNXGraphNode)
for name, node in self.graph.node_map.items() for name, node in self.graph.node_map.items()
]))) ])))
print("Nodes converting ...") print("Nodes converting ...")
for node_name in self.graph.topo_sort: for i, node_name in enumerate(self.graph.topo_sort):
sys.stderr.write("\rConverting node {} ... ".format(i + 1))
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
if hasattr(self.opset, op): if hasattr(self.opset, op):
func = getattr(self.opset, op) func = getattr(self.opset, op)
func(node) func(node)
elif op in self.opset.default_op_mapping: elif op in self.opset.directly_map_ops:
self.opset.directly_map(node) self.opset.directly_map(node)
elif op in custom_layers:
self.opset.deal_custom_layer(node)
elif op in self.opset.elementwise_ops: elif op in self.opset.elementwise_ops:
self.opset.elementwise_map(node) self.opset.elementwise_map(node)
print("Nodes converted.") print("\nNodes converted.")
self.weights = self.opset.weights self.paddle_graph.set_name(self.graph.graph_name)
self.omit_nodes = self.opset.omit_nodes self.paddle_graph.set_parameters(self.opset.params)
self.used_custom_layers = self.opset.used_custom_layers self.paddle_graph.set_inputs_info(self.opset.inputs_info)
self.paddle_graph.inputs = self.graph.input_nodes
self.paddle_graph.outputs = self.graph.output_nodes
def op_checker(self): def op_checker(self):
unsupported_ops = set() unsupported_ops = set()
...@@ -57,17 +62,17 @@ class ONNXOpMapper(OpMapper): ...@@ -57,17 +62,17 @@ class ONNXOpMapper(OpMapper):
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
if not hasattr(self.opset, op) and \ if not hasattr(self.opset, op) and \
op not in self.opset.default_op_mapping and \ op not in self.opset.directly_map_ops and \
op not in custom_layers and \
op not in self.opset.elementwise_ops: op not in self.opset.elementwise_ops:
unsupported_ops.add(op) unsupported_ops.add(op)
if len(unsupported_ops) == 0: if len(unsupported_ops) == 0:
return True return True
else: else:
print("There are {} ops not supported yet, list as below".format( if len(unsupported_ops) > 0:
len(unsupported_ops))) print("\n========= {} OPs are not supported yet ===========".format(
len(unsupported_ops)))
for op in unsupported_ops: for op in unsupported_ops:
print(op) print("========== {} ============".format(op))
return False return False
def create_opset(self, decoder): def create_opset(self, decoder):
...@@ -88,4 +93,4 @@ class ONNXOpMapper(OpMapper): ...@@ -88,4 +93,4 @@ class ONNXOpMapper(OpMapper):
'Now, onnx2paddle support convert onnx model opset_verison {},' 'Now, onnx2paddle support convert onnx model opset_verison {},'
'opset_verison of your onnx model is {}, automatically treated as op_set: {}.' 'opset_verison of your onnx model is {}, automatically treated as op_set: {}.'
.format(self.support_op_sets, decoder.op_set, run_op_set)) .format(self.support_op_sets, decoder.op_set, run_op_set))
return eval(opset)(decoder) return eval(opset)(decoder, self.paddle_graph)
from .opset import OpSet9 from .opset import OpSet9
from .custom_layer import custom_layers
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .register import get_registered_layers
custom_layers = get_registered_layers()
def set_args(f, params):
""" set args for function 'f' using the parameters in node.layer.param
Args:
f (function): a python function object
params (object): a object contains attributes needed by f's arguments
Returns:
arg_names (list): a list of argument names
kwargs (dict): a dict contains needed arguments
"""
argc = f.__code__.co_argcount
arg_list = f.__code__.co_varnames[0:argc]
kwargs = {}
for arg_name in arg_list:
if hasattr(params, arg_name) and params is not None:
kwargs[arg_name] = getattr(params, arg_name)
return arg_list, kwargs
def has_layer(layer_type):
""" test whether this layer exists in custom layer
"""
return layer_type in custom_layers
def get_params(layer, layer_type):
import re
if layer_type.lower() == "deconvolution" or layer_type.lower(
) == "convolutiondepthwise":
param_name = '_'.join(('convolution', 'param'))
elif layer_type.lower() == "normalize":
param_name = '_'.join(('norm', 'param'))
elif len(layer_type) - len(re.sub("[A-Z]", "", layer_type)) >= 2:
s = ''
tmp_name = ''
for i, ch in enumerate(layer_type):
if i == 0:
s += ch.lower()
continue
elif ch.isupper() and layer_type[i - 1].islower():
tmp_name += (s + '_')
s = ''
s += ch.lower()
tmp_name += s
param_name = '_'.join((tmp_name, 'param'))
else:
param_name = '_'.join((layer_type.lower(), 'param'))
return getattr(layer, param_name, None)
def compute_output_shape(node):
""" compute the output shape of custom layer
"""
layer_type = node.layer_type
assert layer_type in custom_layers, "layer[%s] not exist in custom layers" % (
layer_type)
shape_func = custom_layers[layer_type]['shape']
layer = node.layer
params = get_params(layer, layer_type)
arg_names, kwargs = set_args(shape_func, params)
input_shape = node.input_shape
return shape_func(input_shape, **kwargs)
def make_custom_layer(node):
""" get the code which implement the custom layer function
"""
layer_type = node.layer_type
assert layer_type in custom_layers, "layer[%s] not exist in custom layers" % (
layer_type)
layer_func = custom_layers[layer_type]['layer']
import inspect
return inspect.getsource(layer_func), layer_func
def make_custom_child_func(node):
""" get the code which implement the custom layer function
"""
layer_type = node.layer_type
child_func = custom_layers[layer_type]['child_func']
if child_func is None:
return None, child_func
import inspect
return inspect.getsource(child_func), child_func
def deal_weights(node, data=None):
""" deal the weights of the custom layer
"""
layer_type = node.layer_type
weights_func = custom_layers[layer_type]['weights']
name = node.layer_name
return weights_func(name, data)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" this module provides 'register' for registering customized layers
"""
g_custom_layers = {}
def register(kind, shape, layer, child_func, weights):
""" register a custom layer or a list of custom layers
Args:
@kind (str or list): type name of the layer
@shape (function): a function to generate the shape of layer's output
@layer (function): a function to generate the paddle code of layer
@weights (function): a function to deal with weights data
Returns:
None
"""
assert type(shape).__name__ == 'function', 'shape should be a function'
assert type(layer).__name__ == 'function', 'layer should be a function'
if type(kind) is str:
kind = [kind]
else:
assert type(
kind) is list, 'invalid param "kind" for register, not a list or str'
for k in kind:
assert type(
k) is str, 'invalid param "kind" for register, not a list of str'
assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
k)
g_custom_layers[k] = {
'shape': shape,
'layer': layer,
'child_func': child_func,
'weights': weights
}
def get_registered_layers():
return g_custom_layers
...@@ -17,7 +17,6 @@ from x2paddle.core.graph import GraphNode ...@@ -17,7 +17,6 @@ from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string from x2paddle.core.util import string
from x2paddle.op_mapper.static.onnx2paddle.opset9.custom_layer import *
from functools import reduce from functools import reduce
import numpy as np import numpy as np
import onnx import onnx
...@@ -27,6 +26,8 @@ import logging as _logging ...@@ -27,6 +26,8 @@ import logging as _logging
from collections import OrderedDict from collections import OrderedDict
import math import math
import os import os
import copy
import sys
import shutil import shutil
_logger = _logging.getLogger(__name__) _logger = _logging.getLogger(__name__)
...@@ -85,182 +86,118 @@ def print_mapping_info(func): ...@@ -85,182 +86,118 @@ def print_mapping_info(func):
class OpSet9(): class OpSet9():
elementwise_ops = { elementwise_ops = {
'Add': 'elementwise_add', 'Add': 'paddle.add',
'Div': 'elementwise_div', 'Div': 'paddle.divide',
'Sub': 'elementwise_sub', 'Sub': 'fluid.layers.elementwise_sub',
'Mul': 'elementwise_mul', 'Mul': 'paddle.multiply',
'Pow': 'elementwise_pow', 'Pow': 'paddle.pow',
} }
default_op_mapping_field_values = OrderedDict() directly_map_ops = {
default_op_mapping_field_values['FLUID_OP'] = '' 'Ceil': ['paddle.ceil'],
default_op_mapping_field_values['FLUID_INPUT_ARGS'] = None # reduce function
default_op_mapping_field_values['FLUID_OUTPUT_ARGS'] = None 'ReduceMean': ['paddle.mean',
default_op_mapping_field_values['ATTR_MAPPING'] = dict() dict(axes='axis', keepdims='keepdim'),
default_op_mapping_field_values['DEFAULTS'] = dict() dict(keepdims=1)],
default_op_mapping_field_values['INPUT_PERM'] = None 'ReduceSum': ['paddle.sum',
default_op_mapping_field_values['OUTPUT_PERM'] = None dict(axes='axis', keepdims='keepdim'),
default_op_mapping_field_values['FILL_NAME_FIELD'] = True dict(keepdims=1)],
'ReduceMin': ['paddle.min',
default_op_mapping = { dict(axes='axis', keepdims='keepdim'),
'Shape': ['shape', ['X'], ['Out']], dict(keepdim=1)],
'Erf': ['erf', ['X'], ['Out']], 'ReduceMax': ['paddle.max',
'Ceil': ['ceil', ['X'], ['Out']], dict(axes='axis', keepdims='keepdim'),
'ReduceMean': [ dict(keepdim=1)],
'reduce_mean', ['X'], ['Out'], dict( # active function
axes='dim', keepdims='keep_dim'), dict(keep_dim=1) 'Relu': ['paddle.nn.functional.relu'],
], 'LeakyRelu': ['paddle.nn.functional.leaky_relu',
'ReduceSum': [ dict(alpha='negative_slope'),
'reduce_sum', ['X'], ['Out'], dict( dict(negative_slope=.01)],
axes='dim', keepdims='keep_dim'), dict(keep_dim=1) 'Elu': ['paddle.nn.functional.elu',
], dict(alpha='alpha'),
'ReduceMin': [ dict(alpha=1.)],
'reduce_min', ['X'], ['Out'], dict( 'ThresholdedRelu': ['paddle.nn.functional.thresholded_relu',
axes='dim', keepdims='keep_dim'), dict(keep_dim=1) dict(alpha='threshold'),
], dict(alpha=1.)],
'ReduceMax': [ 'Tanh': ['paddle.nn.functional.tanh'],
'reduce_max', ['X'], ['Out'], dict( 'Sigmoid': ['paddle.nn.functional.sigmoid'],
axes='dim', keepdims='keep_dim'), dict(keep_dim=1) 'Softsign': ['paddle.nn.functional.softsign'],
], 'Softplus': ['paddle.nn.functional.softplus',
#active function dict(threshold='threshold'),
'Relu': ['relu', ['X'], ['Out']], dict(threshold=float(sys.maxsize))],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)], 'Exp': ['paddle.exp'],
'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)], 'Softmax': ['paddle.nn.functional.softmax',
'ThresholdedRelu': [ dict(axis='axis'),
'thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'), dict(axis=1)],
dict(alpha=1.) 'Sqrt': ['paddle.sqrt'],
], 'Floor': ['paddle.floor'],
'Tanh': ['tanh', ['X'], ['Out']], 'Abs': ['paddle.abs'],
'Sigmoid': ['sigmoid', ['X'], ['Out']], 'Erf': ['paddle.erf'],
'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'], dict(
alpha='slope', beta='offset'), dict(
slope=.2, offset=.5)
],
'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
'Sqrt': ['sqrt', ['X'], ['Out']],
'Floor': ['floor', ['X'], ['Out']],
'Abs': ['abs', ['X'], ['Out']],
} }
default_ioa_constraint = {} def __init__(self, decoder, paddle_graph):
def __init__(self, decoder):
super(OpSet9, self).__init__() super(OpSet9, self).__init__()
self.graph = decoder.graph self.graph = decoder.graph
self.input_shapes = [] self.paddle_graph = paddle_graph
self.weights = dict() self.input_index = 0
self.omit_nodes = list() self.inputs_info = dict()
self.used_custom_layers = dict() self.params = dict()
@print_mapping_info @print_mapping_info
def directly_map(self, node, name='', *args, **kwargs): def directly_map(self, node, *args, **kwargs):
inputs = node.layer.input inputs = node.layer.input
outputs = node.layer.output assert len(inputs) == 1, 'directly_map error with multi inputs'
op_type = node.layer_type input = self.graph.get_input_node(node, idx=0, copy=True)
attrs = node.attr_map onnx_attrs = node.attr_map
info = self.default_op_mapping[op_type] if '' in onnx_attrs:
info.extend( onnx_attrs.pop('')
list(self.default_op_mapping_field_values.values())[len(info):]) if '_' in onnx_attrs:
( onnx_attrs.pop('_')
fluid_op, op_info = self.directly_map_ops[node.layer_type]
fluid_input_args, paddle_op = op_info[0]
fluid_output_args, layer_attrs = dict()
attr_mapping, if len(op_info) > 1:
default_attrs, attrs_name_map_dict = op_info[1]
input_perm, for onnx_attr_name, pd_attr_name in attrs_name_map_dict.items():
output_perm, if onnx_attr_name in onnx_attrs:
fill_name_field, ) = info layer_attrs[pd_attr_name] = onnx_attrs[onnx_attr_name]
else:
if fluid_op in self.default_ioa_constraint: layer_attrs[pd_attr_name] = op_info[2][onnx_attr_name]
for predicate, message in self.default_ioa_constraint[fluid_op]: self.paddle_graph.add_layer(
assert predicate(inputs, outputs, attrs), message kernel=paddle_op,
inputs={"x": input.name},
mapped_attrs = { outputs=[node.name],
attr_mapping.get(key, key): value **layer_attrs)
for key, value in attrs.items()
}
if '' in mapped_attrs:
mapped_attrs.pop('')
if '_' in mapped_attrs:
mapped_attrs.pop('_')
fluid_attrs = default_attrs.copy()
fluid_attrs.update(mapped_attrs)
inputs = inputs if input_perm is None else list(
map(lambda i: inputs[i], input_perm))
val_inps = []
for idx, ipt in enumerate(inputs):
val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))
val_outs = outputs if output_perm is None else list(
map(lambda i: outputs[i], output_perm))
attr = fluid_attrs
assert len(val_inps) == 1, 'directly_map error with multi inputs'
if fluid_op not in ['shape', 'erf']:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_inps[0], output=val_outs[0], param_attr=attr)
if fluid_op in ['shape']:
node.fluid_code.add_layer(
'cast',
inputs=val_outs[0],
output=val_outs[0],
param_attr={'dtype': string('int64')})
@print_mapping_info
def deal_custom_layer(self, node):
op = node.layer_type
custom_code, func = make_custom_layer(node)
child_func_code, child_func = make_custom_child_func(node)
params = get_params(node.layer, node.layer_type)
arg_names, kwargs = set_args(func, params)
kwargs['name'] = string(node.layer_name)
node.fluid_code.add_layer(
func.__code__.co_name,
inputs=node.inputs,
output=node,
param_attr=kwargs,
is_custom_layer=True)
if op not in self.used_custom_layers:
self.used_custom_layers[op] = custom_code
if op + '_child_func' not in self.used_custom_layers:
if child_func_code is not None:
self.used_custom_layers[op +
'_child_func'] = child_func_code
@print_mapping_info @print_mapping_info
def elementwise_map(self, node): def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type] op_type = self.elementwise_ops[node.layer_type]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
inputs = {'x': val_x, 'y': val_y} inputs_dict = {'x': val_x.name,
node.fluid_code.add_layer( 'y': val_y.name}
op_type, inputs=inputs, output=node, param_attr=None) self.paddle_graph.add_layer(
op_type,
inputs=inputs_dict,
outputs=[node.name])
@print_mapping_info @print_mapping_info
def place_holder(self, node): def place_holder(self, node):
self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0] shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape): for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0: if dim_shape == 0 and i == 0:
shape[i] = 1 shape[i] = 1
if dim_shape == 0 and i != 0: if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned' assert 'shape of input is not assigned'
attr = { self.paddle_graph.add_layer(
"dtype": string(node.dtype), kernel="paddle.static.data",
"shape": shape, inputs={},
"name": string(node.layer_name), outputs=[node.name],
"append_batch_size": 'False' dtype=string(node.dtype),
} shape=shape,
name=string(node.name))
node.fluid_code.add_layer( self.inputs_info["x{}".format(self.input_index)] = [shape, node.dtype]
"data", inputs=None, output=node, param_attr=attr) self.input_index += 1
@print_mapping_info @print_mapping_info
def create_parameter(self, node, parameter=None): def create_parameter(self, node, parameter=None):
...@@ -269,30 +206,23 @@ class OpSet9(): ...@@ -269,30 +206,23 @@ class OpSet9():
dtype = node.dtype dtype = node.dtype
shape = node.out_shapes[0] shape = node.out_shapes[0]
if len(node.weight.shape) == 0: if len(node.weight.shape) == 0:
shape = [1] self.paddle_graph.add_layer(
self.weights[node.layer_name] = node.weight "paddle.full",
attr = { inputs={},
'dtype': string(dtype), outputs=[node.name],
'shape': shape, dtype=string(dtype),
'name': string(node.layer_name), shape=[1],
'default_initializer': 'Constant(0.0)' fill_value=node.weight)
}
if dtype == 'bool':
attr['dtype'] = string('int64')
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
node.fluid_code.add_layer(
"cast",
inputs=node,
output=node,
param_attr={'dtype': string('bool')})
elif dtype == 'uint8':
attr['dtype'] = string('float32')
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
else: else:
node.fluid_code.add_layer( self.params[node.name] = node.weight
"create_parameter", inputs=None, output=node, param_attr=attr) self.paddle_graph.add_layer(
kernel="paddle.static.create_parameter",
inputs={},
outputs=[node.name],
dtype=string(dtype),
shape=shape,
name=string(node.name),
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
def _pad_if_asymmetric(self, node, pads, val_name): # pads: SSEE def _pad_if_asymmetric(self, node, pads, val_name): # pads: SSEE
assert len(pads) & 1 == 0 assert len(pads) & 1 == 0
...@@ -309,49 +239,89 @@ class OpSet9(): ...@@ -309,49 +239,89 @@ class OpSet9():
def _interpolate(self, node): def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'input': val_x} inputs = {'x': val_x.name}
if node.layer_type == 'Resize': if node.layer_type == 'Resize':
if len(node.layer.input) == 2: if len(node.layer.input) == 2:
# opset 10 # opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales inputs['scale_factor'] = val_scales.name
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale'] = val_scales inputs['scale_factor'] = val_scales.name
elif len(node.layer.input) == 4: elif len(node.layer.input) == 4:
# opset 11 # opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True) val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
var_nc, var_hw = val_sizes.layer_name + '_nc', val_sizes.layer_name + '_hw' var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'split', 'paddle.split',
inputs=val_sizes, inputs={"x": val_sizes.name},
output=var_nc + ',' + var_hw, outputs=[var_nc, var_hw],
param_attr={ num_or_sections=[2, 2],
'dim': 0, axis=0)
'num_or_sections': [2, 2], self.paddle_graph.add_layer(
}) "paddle.cast",
node.fluid_code.add_layer( inputs={"x": var_hw},
"cast", outputs=[var_hw],
inputs=var_hw, dtype=string('int32'))
output=var_hw, # inputs['size'] = var_hw
param_attr={'dtype': string('int32')})
# TODO(syf): all use
inputs['out_shape'] = var_hw inputs['out_shape'] = var_hw
ipt = inputs.pop("x")
inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False}
self.paddle_graph.add_layer(
kernel="fluid.layers.resize_nearest",
inputs=inputs,
outputs=[node.name],
**attrs)
return
elif node.layer_type == 'Upsample': elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales inputs['scale'] = val_scales
attr = {'name': string(node.layer_name)}
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
fluid_op = 'resize_{}'.format(mode) attrs = {"align_corners": False,
if 'linear' in mode: "mode": string(mode),
print( "align_mode": 1}
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear' self.paddle_graph.add_layer(
) kernel="paddle.nn.functional.interpolate",
fluid_op = 'resize_bilinear' inputs=inputs,
attr['align_corners'] = False outputs=[node.name],
node.fluid_code.add_layer( **attrs)
fluid_op, inputs=inputs, output=node, param_attr=attr)
@print_mapping_info
def HardSigmoid(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
alpha = node.get_attr('alpha', 0.2)
beta = node.get_attr('beta', 0.5)
self.paddle_graph.add_layer(
kernel="paddle.scale",
inputs={"x": val_x.name},
outputs=[node.name + "_val"],
scale=alpha,
bias=beta)
self.paddle_graph.add_layer(
kernel="paddle.clip",
inputs={"x": node.name + "_val"},
outputs=[node.name],
min=0.0,
max=1.0)
@print_mapping_info
def Shape(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
self.paddle_graph.add_layer(
kernel="paddle.shape",
inputs={"input": val_x.name},
outputs=[node.name])
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": node.name},
outputs=[node.name],
dtype=string('int64'))
@print_mapping_info @print_mapping_info
def RoiAlign(self, node): def RoiAlign(self, node):
...@@ -362,18 +332,18 @@ class OpSet9(): ...@@ -362,18 +332,18 @@ class OpSet9():
pooled_width = node.get_attr('output_width') pooled_width = node.get_attr('output_width')
spatial_scale = node.get_attr('spatial_scale') spatial_scale = node.get_attr('spatial_scale')
sampling_ratio = node.get_attr('sampling_ratio') sampling_ratio = node.get_attr('sampling_ratio')
attr = { layer_attrs = {
'pooled_height': pooled_height, 'pooled_height': pooled_height,
'pooled_width': pooled_width, 'pooled_width': pooled_width,
'spatial_scale': spatial_scale, 'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio, 'sampling_ratio': sampling_ratio,
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'roi_align', 'fluid.layers.roi_align',
inputs={'input': val_x, inputs={'input': val_x.name,
'rois': val_rois}, 'rois': val_rois.name},
output=node, outputs=[node.name],
param_attr=attr) **layer_attrs)
@print_mapping_info @print_mapping_info
def MaxRoiPool(self, node): def MaxRoiPool(self, node):
...@@ -382,17 +352,17 @@ class OpSet9(): ...@@ -382,17 +352,17 @@ class OpSet9():
spatial_scale = node.get_attr('spatial_scale') spatial_scale = node.get_attr('spatial_scale')
pooled_height, pooled_width = node.get_attr('pooled_shape') pooled_height, pooled_width = node.get_attr('pooled_shape')
attr = { layer_attrs = {
'pooled_height': pooled_height, 'pooled_height': pooled_height,
'pooled_width': pooled_width, 'pooled_width': pooled_width,
'spatial_scale': spatial_scale, 'spatial_scale': spatial_scale,
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'roi_pool', 'fluid.layers.roi_pool',
inputs={'input': val_x, inputs={'input': val_x.name,
'rois': val_rois}, 'rois': val_rois.name},
output=node, outputs=[node.name],
param_attr=attr) **layer_attrs)
@print_mapping_info @print_mapping_info
def Pad(self, node, op_independent=True): def Pad(self, node, op_independent=True):
...@@ -403,7 +373,8 @@ class OpSet9(): ...@@ -403,7 +373,8 @@ class OpSet9():
data_shape = val_x.out_shapes[0] data_shape = val_x.out_shapes[0]
output_shape = node.out_shapes[0] output_shape = node.out_shapes[0]
assume_pad2d = False assume_pad2d = False
attr = {} layer_attrs = {}
layer_attrs['mode'] = string(mode)
paddings = [] paddings = []
if len(pads) == 4: if len(pads) == 4:
assume_pad2d |= mode != 'constant' assume_pad2d |= mode != 'constant'
...@@ -412,12 +383,12 @@ class OpSet9(): ...@@ -412,12 +383,12 @@ class OpSet9():
if output_shape: if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
if assume_pad2d: if assume_pad2d:
fluid_op = 'pad2d' paddle_op = 'paddle.nn.functional.pad'
attr['data_format'] = string('NCHW') layer_attrs['data_format'] = string('NCHW')
attr['mode'] = string(mode) layer_attrs['value'] = value
else: else:
attr = {'pad_value': value} paddle_op = 'fluid.layers.pad'
fluid_op = 'pad' layer_attrs["pad_value"] = value
if len(pads) == 4: if len(pads) == 4:
paddings = np.array(pads).reshape( paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE (-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
...@@ -425,51 +396,52 @@ class OpSet9(): ...@@ -425,51 +396,52 @@ class OpSet9():
paddings = np.array(pads).reshape( paddings = np.array(pads).reshape(
(-1, 4)).transpose().flatten().tolist() # SSEE -> SESE (-1, 4)).transpose().flatten().tolist() # SSEE -> SESE
if sum(paddings[:4]) == 0: if sum(paddings[:4]) == 0:
fluid_op = 'pad2d' paddle_op = 'paddle.nn.functional.pad'
paddings = paddings[4:] paddings = paddings[4:]
attr['mode'] = string(mode) layer_attrs['value'] = value
attr['paddings'] = paddings if 'pad_value' in layer_attrs:
layer_attrs.pop('pad_value')
tmp_paddings = copy.deepcopy(paddings)
paddings[0] = tmp_paddings[2]
paddings[1] = tmp_paddings[3]
paddings[2] = tmp_paddings[0]
paddings[3] = tmp_paddings[1]
if paddle_op == 'paddle.nn.functional.pad':
layer_attrs['pad'] = paddings
else:
layer_attrs['paddings'] = paddings
if op_independent: if op_independent:
attr['name'] = string(node.layer_name) self.paddle_graph.add_layer(
node.fluid_code.add_layer( paddle_op,
fluid_op, inputs=val_x, output=node, param_attr=attr) inputs={'x': val_x.name},
outputs=[node.name],
**layer_attrs)
else: else:
attr['name'] = string(node.layer_name + '_paded') self.paddle_graph.add_layer(
node.fluid_code.add_layer( paddle_op,
fluid_op, inputs={'x': val_x.name},
inputs=val_x, outputs=[node.name + '_paded'],
output=node.layer_name + '_paded', **layer_attrs)
param_attr=attr) return node.name + '_paded'
return node.layer_name + '_paded'
@print_mapping_info @print_mapping_info
def Unsqueeze(self, node): def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
attr = {'axes': axes, 'name': string(node.layer_name)} layer_attrs = {'axis': axes}
if len(val_x.out_shapes[0]) == 0: if len(val_x.out_shapes[0]) == 0:
if node.layer_name: if node.name:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=val_x, inputs={"x": val_x.name},
output=node, outputs=[node.name],
param_attr={'shape': [1]}) shape=[1])
else: else:
if str(val_x.dtype) == 'bool': self.paddle_graph.add_layer(
val_x_cast = val_x.layer_name + '_cast' 'paddle.unsqueeze',
node.fluid_code.add_layer( inputs={"x": val_x.name},
'cast', outputs=[node.name],
inputs=val_x, **layer_attrs)
output=val_x_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'unsqueeze',
inputs=val_x_cast,
output=node,
param_attr=attr)
else:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
def Shrink(self, node): def Shrink(self, node):
...@@ -477,9 +449,11 @@ class OpSet9(): ...@@ -477,9 +449,11 @@ class OpSet9():
bias = node.get_attr('bias') bias = node.get_attr('bias')
lambd = node.get_attr('lambd') lambd = node.get_attr('lambd')
assert bias == 0.0, 'not support bias!=0' assert bias == 0.0, 'not support bias!=0'
attr = {'threshold': lambd, 'name': node.layer_name} self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.nn.functional.hardshrink',
'hard_shrink', inputs=val_x, output=node, param_attr=attr) inputs={"x": val_x.name},
outputs=[node.name],
threshold=lambd)
@print_mapping_info @print_mapping_info
def Constant(self, node): def Constant(self, node):
...@@ -500,29 +474,28 @@ class OpSet9(): ...@@ -500,29 +474,28 @@ class OpSet9():
_logger.warning('in (Constant -> %s): ' _logger.warning('in (Constant -> %s): '
'attribute "shape" of %s not inferred, ' 'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails', 'using value as 1-D tensor may lead to fails',
val_output.layer_name, val_output.layer_name) val_output.name, val_output.name)
if len(value) == 1: if len(value) == 1:
value = value.tolist() value = value.tolist()
shape = [1]
value = value[0] value = value[0]
if dtype.name == 'int64': self.paddle_graph.add_layer(
dtype = 'int32' "paddle.full",
attr = {'shape': shape, 'dtype': string(dtype), 'value': value} inputs={},
node.fluid_code.add_layer( outputs=[node.name],
'fill_constant', inputs=None, output=node, param_attr=attr) dtype=string(dtype),
shape=[1],
fill_value=value)
else: else:
if dtype.name == 'uint8':
dtype = 'int64'
value = np.reshape(value, shape) value = np.reshape(value, shape)
self.weights[node.layer_name] = value self.params[node.name] = value
attr = { self.paddle_graph.add_layer(
'dtype': string(dtype), kernel="paddle.static.create_parameter",
'shape': shape, inputs={},
'name': string(node.layer_name), outputs=[node.name],
'default_initializer': 'Constant(0.0)' dtype=string(dtype),
} shape=shape,
node.fluid_code.add_layer( name=string(node.name),
"create_parameter", inputs=None, output=node, param_attr=attr) default_initializer="paddle.nn.initializer.Constant(value=0.0)")
@print_mapping_info @print_mapping_info
def Resize(self, node): def Resize(self, node):
...@@ -538,36 +511,50 @@ class OpSet9(): ...@@ -538,36 +511,50 @@ class OpSet9():
val_scale = self.graph.get_input_node(node, idx=1, copy=True) val_scale = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
epsilon = node.get_attr('epsilon', 1e-5) epsilon = node.get_attr('epsilon', 1e-5)
attr = { layer_attrs = {
'epsilon': epsilon, 'eps': epsilon,
'param_attr': string(val_scale.layer_name),
'bias_attr': string(val_b.layer_name)
} }
node.fluid_code.add_layer( dim = len(val_x.out_shapes[0])
"instance_norm", inputs=val_x, output=node, param_attr=attr) if dim ==2 :
layer_attrs["data_format"] = string("NC")
elif dim == 3:
layer_attrs["data_format"] = string("NCL")
elif dim == 4:
layer_attrs["data_format"] = string("NCHW")
elif dim == 5:
layer_attrs["data_format"] = string("NCDHW")
else:
raise Exception("The paddle only support 2D, 3D, 4D or 5D input in InstanceNormalization.")
self.paddle_graph.add_layer(
"paddle.nn.functional.instance_norm",
inputs={"x": val_x.name,
"weight": val_scale.name,
"bias": val_b.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info @print_mapping_info
def Expand(self, node): def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_x_dtype = val_x.dtype val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones' name_ones = node.name + '_ones'
attr_ones = { attr_ones = {
'shape': val_shape.layer_name, 'shape': val_shape.name,
'dtype': string(val_x_dtype), 'dtype': string(val_x_dtype),
'value': 1 'fill_value': 1
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'fill_constant', 'paddle.full',
inputs=None, inputs={},
output=name_ones, outputs=[name_ones],
param_attr=attr_ones) **attr_ones)
inputs = {'x': name_ones, 'y': val_x} inputs_dict = {'x': name_ones,
node.fluid_code.add_layer( 'y': val_x.name}
'elementwise_mul', self.paddle_graph.add_layer(
inputs=inputs, 'paddle.multiply',
output=node.layer_name, inputs=inputs_dict,
param_attr=None) outputs=[node.name])
@print_mapping_info @print_mapping_info
def Gather(self, node): def Gather(self, node):
...@@ -579,147 +566,140 @@ class OpSet9(): ...@@ -579,147 +566,140 @@ class OpSet9():
# indices_shape) <= 2, "Gather op don't support dim of indice >2 " # indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1: if axis == 0 and len(indices_shape) <= 1:
if len(val_x.out_shapes[0]) <= 1: if len(val_x.out_shapes[0]) <= 1:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'gather', 'paddle.gather',
inputs={'input': val_x, inputs={'x': val_x.name,
'index': indices}, 'index': indices.name},
output=node, outputs=[node.name])
param_attr=None)
elif len(val_x.out_shapes[0]) > 1: elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0: if len(indices_shape) == 0:
gather_ = node.layer_name + '_1' gather_ = node.name + '_1'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'gather', 'paddle.gather',
inputs={'input': val_x, inputs={'x': val_x.name,
'index': indices}, 'index': indices.name},
output=gather_, outputs=[gather_])
param_attr=None) self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.squeeze',
'squeeze', inputs={'x': gather_},
inputs={'input': gather_, outputs=[node.name],
'axes': [0]}, axis=[0])
output=node,
param_attr=None)
else: else:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'gather', 'paddle.gather',
inputs={'input': val_x, inputs={'x': val_x.name,
'index': indices}, 'index': indices.name},
output=node, outputs=[node.name])
param_attr=None)
elif axis > 0 and len(indices_shape) <= 1: elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm} name_trans = val_x.name + '_trans'
name_trans = val_x.layer_name + '_trans' self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.transpose',
'transpose', inputs={"x": val_x.name},
inputs=val_x, outputs=[name_trans],
output=name_trans, perm=perm)
param_attr=attr_trans) self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.gather',
'gather', inputs={'x': name_trans,
inputs={'input': name_trans, 'index': indices.name},
'index': indices}, outputs=[node.name])
output=node, self.paddle_graph.add_layer(
param_attr=None) 'paddle.transpose',
node.fluid_code.add_layer( inputs={"x": node.name},
'transpose', inputs=node, output=node, param_attr=attr_trans) outputs=[node.name],
perm=perm)
if len(indices_shape) < 1: if len(indices_shape) < 1:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'squeeze', 'paddle.squeeze',
inputs={'input': node, inputs={'x': node.name},
'axes': [axis]}, outputs=[node.name],
output=node, axis=[axis])
param_attr=None)
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance( if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode): val_x, ONNXGraphDataNode):
indices_cast = indices.layer_name + '_cast' indices_cast = indices.name + '_cast'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'cast', 'paddle.cast',
inputs=indices, inputs={"x": indices.name},
output=indices_cast, outputs=indices_cast,
param_attr={'dtype': string('int64')}) dtype=string('int64'))
node.fluid_code.add_layer( op_name = name_generator("embedding", self.nn_name2id)
'embedding', output_name = node.name
inputs=indices_cast, layer_outputs = [op_name, output_name]
output=node, self.paddle_graph.add_layer(
use_fluid=True, 'paddle.nn.Embedding',
param_attr={ inputs={"x": indices_cast},
'param_attr': string(val_x.layer_name), outputs=layer_outputs,
'size': val_x.out_shapes[0] param_attr=string(val_x.name),
}) size=val_x.out_shapes[0])
else: else:
from functools import reduce from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape) reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape' indices_reshape = indices.name + '_shape'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=indices, inputs={"x": indices.name},
output=indices_reshape, outputs=[indices_reshape],
param_attr={'shape': [reshape_shape, ]}) shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'gather', 'paddle.gather',
inputs={'input': val_x, inputs={'x': val_x.name,
'index': indices_reshape}, 'index': indices_reshape},
output=node, outputs=[node.name])
param_attr=None)
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
reshaped_shape = [] reshaped_shape = []
for i in perm: for i in perm:
reshaped_shape.append(indices_shape[i]) reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]: for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i) reshaped_shape.append(i)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=node, inputs={"x": node.name},
output=node, outputs=[node.name],
param_attr={'shape': reshaped_shape}) shape=reshaped_shape)
elif axis > 0 and len(indices_shape) > 1: elif axis > 0 and len(indices_shape) > 1:
from functools import reduce from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape) reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape' indices_reshape = indices.name + '_shape'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=indices, inputs={"x": indices.name},
output=indices_reshape, outputs=[indices_reshape],
param_attr={'shape': [reshape_shape, ]}) shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm} name_trans = val_x.name + '_transpose'
name_trans = val_x.layer_name + '_transpose' self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.transpose',
'transpose', inputs={"x": val_x.name},
inputs=val_x, outputs=[name_trans],
output=name_trans, perm=perm)
param_attr=attr_trans) self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.gather',
'gather', inputs={'x': name_trans,
inputs={'input': name_trans,
'index': indices_reshape}, 'index': indices_reshape},
output=node, outputs=[node.name])
param_attr=None) input_transpose = node.name + '_transpose'
input_transpose = node.layer_name + '_transpose' self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.transpose',
'transpose', inputs={"x": node.name},
inputs=node, outputs=[input_transpose],
output=input_transpose, perm=perm)
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
reshaped_shape = [] reshaped_shape = []
for i in perm: for i in perm:
reshaped_shape.append(indices_shape[i]) reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]: for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i) reshaped_shape.append(i)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=input_transpose, inputs={"x": input_transpose},
output=node, outputs=[node.name],
param_attr={'shape': reshaped_shape}) shape=reshaped_shape)
@print_mapping_info @print_mapping_info
def ScatterND(self, node): def ScatterND(self, node):
...@@ -727,85 +707,78 @@ class OpSet9(): ...@@ -727,85 +707,78 @@ class OpSet9():
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
updates = self.graph.get_input_node(node, idx=2, copy=True) updates = self.graph.get_input_node(node, idx=2, copy=True)
if len(indices.out_shapes[0]) == 1: if len(indices.out_shapes[0]) == 1:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'scatter', 'paddle.scatter',
inputs={'input': val_x, inputs={'x': val_x.name,
'index': indices, 'index': indices.name,
'updates': updates}, 'updates': updates.name},
output=node, outputs=[node.name])
param_attr=None)
else: else:
input_inner_indices = node.layer_name + '_input_inner_indices' input_inner_indices = node.name + '_input_inner_indices'
shape = val_x.out_shapes[0] shape = val_x.out_shapes[0]
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=indices.layer_name, inputs={"x": indices.name},
output=indices.layer_name, outputs=[indices.name],
param_attr={'shape': indices.out_shapes[0]}) shape=indices.out_shapes[0])
zeros_like_val_x = val_x.layer_name + '_zeros' zeros_like_val_x = val_x.name + '_zeros'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'zeros_like', 'paddle.zeros_like',
inputs=val_x, inputs={"x": val_x.name},
output=zeros_like_val_x, outputs=[zeros_like_val_x])
param_attr=None) self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.scatter_nd_add',
'scatter_nd_add',
inputs={ inputs={
'ref': zeros_like_val_x, 'x': zeros_like_val_x,
'index': indices, 'index': indices.name,
'updates': updates 'updates': updates.name
}, },
output=input_inner_indices, outputs=[input_inner_indices])
param_attr=None) indices_mask = node.name + '_indices_mask'
indices_mask = node.layer_name + '_indices_mask' constant_minus_one = node.name + '_constant_minus_one'
constant_minus_one = node.layer_name + '_constant_minus_one'
# full_like support create tensor shape like input tensor # full_like support create tensor shape like input tensor
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'full_like', 'paddle.full_like',
inputs=updates, inputs={"x": updates.name},
output=constant_minus_one, outputs=[constant_minus_one],
param_attr={'dtype': string(updates.dtype), dtype=string(updates.dtype),
'fill_value': -1}) fill_value=-1)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'scatter_nd_add', 'paddle.scatter_nd_add',
inputs={ inputs={
'ref': zeros_like_val_x, 'x': zeros_like_val_x,
'index': indices, 'index': indices.name,
'updates': constant_minus_one 'updates': constant_minus_one
}, },
output=indices_mask, outputs=[indices_mask])
param_attr=None) constant_one = node.name + '_constant_1'
constant_one = node.layer_name + '_constant_1'
# full_like support create tensor shape like input tensor # full_like support create tensor shape like input tensor
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'full_like', 'paddle.full_like',
inputs=val_x, inputs={"x": val_x.name},
output=constant_one, outputs=[constant_one],
param_attr={'dtype': string(val_x.dtype), dtype=string(val_x.dtype),
'fill_value': 1}) fill_value=1)
input_out_indices_mask = node.layer_name + '_input_out_indices_mask' input_out_indices_mask = node.name + '_input_out_indices_mask'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_add", "paddle.add",
inputs={"x": indices_mask, inputs={"x": indices_mask,
"y": constant_one}, "y": constant_one},
output=input_out_indices_mask, outputs=[input_out_indices_mask])
param_attr=None)
input_out_indices = node.layer_name + '_input_out_indices' input_out_indices = node.name + '_input_out_indices'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_mul", "paddle.multiply",
inputs={"x": val_x, inputs={"x": val_x.name,
"y": input_out_indices_mask}, "y": input_out_indices_mask},
output=input_out_indices, outputs=[input_out_indices])
param_attr=None)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_add", "paddle.add",
inputs={"x": input_inner_indices, inputs={"x": input_inner_indices,
"y": input_out_indices}, "y": input_out_indices},
output=node, outputs=[node.name])
param_attr=None)
@print_mapping_info @print_mapping_info
def Range(self, node): def Range(self, node):
...@@ -813,18 +786,20 @@ class OpSet9(): ...@@ -813,18 +786,20 @@ class OpSet9():
val_limit = self.graph.get_input_node(node, idx=1, copy=True) val_limit = self.graph.get_input_node(node, idx=1, copy=True)
val_delta = self.graph.get_input_node(node, idx=2, copy=True) val_delta = self.graph.get_input_node(node, idx=2, copy=True)
dtype = val_start.dtype dtype = val_start.dtype
inputs = {'start': val_start, 'end': val_limit, 'step': val_delta} inputs = {'start': val_start.name,
node.fluid_code.add_layer( 'end': val_limit.name,
'range', 'step': val_delta.name}
self.paddle_graph.add_layer(
'paddle.arange',
inputs=inputs, inputs=inputs,
output=node, outputs=[node.name],
param_attr={'dtype': string(dtype)}) dtype=string(dtype))
@print_mapping_info @print_mapping_info
def Slice(self, node): def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
starts, ends, axes, steps = None, None, None, None starts, ends, axes, steps = None, None, None, None
attr = {} layer_attrs = {}
if len(node.inputs) > 1: if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True) starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True) ends = self.graph.get_input_node(node, idx=2, copy=True)
...@@ -837,14 +812,12 @@ class OpSet9(): ...@@ -837,14 +812,12 @@ class OpSet9():
if len(node.inputs) > 4: if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True) steps = self.graph.get_input_node(node, idx=4, copy=True)
steps = _const_weight_or_none(steps) steps = _const_weight_or_none(steps)
attr = { layer_attrs = {
"axes": axes, "axes": axes,
"starts": starts.layer_name, "starts": starts.name,
"ends": ends.layer_name "ends": ends.name
} }
if starts_value is not None and ends_value is not None: if starts_value is not None and ends_value is not None:
self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name)
starts_value = starts_value.copy() starts_value = starts_value.copy()
ends_value = ends_value.copy() ends_value = ends_value.copy()
#for idx in range(len(ends_value)): #for idx in range(len(ends_value)):
...@@ -858,28 +831,28 @@ class OpSet9(): ...@@ -858,28 +831,28 @@ class OpSet9():
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1 starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] > 2**31 - 1: elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1 ends_value[idx] = 2**31 - 1
attr = { layer_attrs = {
"axes": axes, "axes": axes,
"starts": starts_value, "starts": starts_value,
"ends": ends_value "ends": ends_value
} }
else: else:
if starts.dtype != 'int32': if starts.dtype != 'int32':
starts_cast = starts.layer_name + '_cast' starts_cast = starts.name + '_cast'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'cast', 'paddle.cast',
inputs=starts, inputs={"x": starts.name},
output=starts_cast, outputs=[starts_cast],
param_attr={'dtype': string('int32')}) dtype=string('int32'))
attr['starts'] = starts_cast layer_attrs['starts'] = starts_cast
if ends.dtype != 'int32': if ends.dtype != 'int32':
ends_cast = ends.layer_name + '_cast' ends_cast = ends.name + '_cast'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'cast', 'paddle.cast',
inputs=ends, inputs={"x": ends.name},
output=ends_cast, outputs=[ends_cast],
param_attr={'dtype': string('int32')}) dtype=string('int32'))
attr['ends'] = ends_cast layer_attrs['ends'] = ends_cast
else: else:
starts = node.get_attr('starts') starts = node.get_attr('starts')
ends = node.get_attr('ends') ends = node.get_attr('ends')
...@@ -887,15 +860,21 @@ class OpSet9(): ...@@ -887,15 +860,21 @@ class OpSet9():
for idx in range(len(ends)): for idx in range(len(ends)):
if ends[idx] > 2**31 - 1: if ends[idx] > 2**31 - 1:
ends[idx] = 2**31 - 1 ends[idx] = 2**31 - 1
attr = {"axes": axes, "starts": starts, "ends": ends} layer_attrs = {"axes": axes, "starts": starts, "ends": ends}
if steps is not None: if steps is not None:
attr['strides'] = steps layer_attrs['strides'] = steps
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'strided_slice', inputs=val_x, output=node, param_attr=attr) 'paddle.strided_slice',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else: else:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr) 'paddle.slice',
inputs={"input": val_x.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info @print_mapping_info
def ConstantOfShape(self, node): def ConstantOfShape(self, node):
...@@ -909,13 +888,16 @@ class OpSet9(): ...@@ -909,13 +888,16 @@ class OpSet9():
'this is not supported') 'this is not supported')
if len(value) == 1: if len(value) == 1:
value = value[0] value = value[0]
attr = { layer_attrs = {
'shape': val_shape.layer_name, 'shape': val_shape.name,
'dtype': string(dtype), 'dtype': string(dtype),
'value': value 'fill_value': value
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr) "paddle.full",
inputs={},
outputs=[node.name],
**layer_attrs)
@print_mapping_info @print_mapping_info
def Clip(self, node): def Clip(self, node):
...@@ -925,104 +907,90 @@ class OpSet9(): ...@@ -925,104 +907,90 @@ class OpSet9():
if len(node.inputs) == 1: if len(node.inputs) == 1:
max_value = node.get_attr('max') max_value = node.get_attr('max')
min_value = node.get_attr('min') min_value = node.get_attr('min')
attr = { layer_attrs = {
'max': max_value, 'max': max_value,
'min': min_value, 'min': min_value,
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr) 'paddle.clip',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else: else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True) max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True) min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt) max_value = _const_weight_or_none(max_ipt)
min_value = _const_weight_or_none(min_ipt) min_value = _const_weight_or_none(min_ipt)
self.omit_nodes.append(max_ipt.layer_name)
self.omit_nodes.append(min_ipt.layer_name)
if max_value.shape == (1, ): if max_value.shape == (1, ):
max_value = max_value[0] max_value = max_value[0]
if min_value.shape == (1, ): if min_value.shape == (1, ):
min_value = min_value[0] min_value = min_value[0]
if max_value is not None and min_value is not None: if max_value is not None and min_value is not None:
attr = {'max': max_value, 'min': min_value} layer_attrs = {'max': max_value, 'min': min_value}
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr) 'paddle.clip',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else: else:
raise raise
@print_mapping_info @print_mapping_info
def Split(self, node): def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True) paddle_op = 'split'
fluid_op = 'split'
split = node.get_attr('split') split = node.get_attr('split')
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
attr = { layer_attrs = {
'num_or_sections': split, 'num_or_sections': split,
'dim': axis, 'axis': axis,
'name': string(node.layer_name)
} }
outputs_list = list()
node.fluid_code.add_layer( if isinstance(split, list) or isinstance(split, tuple):
'split', inputs=val_x, output=val_y, param_attr=attr) for i in range(len(split)):
outputs_list.append("{}_p{}".format(node.layer_name, i))
else:
outputs_list.append(node.name)
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
outputs=outputs_list,
**layer_attrs)
@print_mapping_info @print_mapping_info
def Reshape(self, node): def Reshape(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_reshaped = self.graph.get_node(node.layer.output[0], copy=True) val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
attr = {}
shape_value = _const_weight_or_none(val_shape) shape_value = _const_weight_or_none(val_shape)
shape_dims = len(val_shape.out_shapes[0]) shape_dims = len(val_shape.out_shapes[0])
if shape_value is not None: if shape_value is not None:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs={'x': val_x}, inputs={'x': val_x.name},
output=node, outputs=[node.name],
param_attr={'shape': shape_value.tolist()}) shape=shape_value.tolist())
elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[ elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]): 0]):
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs={'x': val_x, inputs={'x': val_x.name},
'shape': node.out_shapes[0]}, outputs=[node.name],
output=node, shape=node.out_shapes[0])
param_attr=attr)
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
node.fluid_code.add_layer(
'reshape',
inputs=val_shape_cast,
output=val_shape_cast,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': val_shape_cast},
output=node,
param_attr=attr)
else: else:
# shape may be [], come form Gather by scalar indices # shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0: if len(val_shape.out_shapes[0]) > 0:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs=val_shape, inputs={'x': val_shape.name},
output=val_shape, outputs=[val_shape.name],
param_attr={'shape': val_shape.out_shapes[0]}) shape=val_shape.out_shapes[0])
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'reshape', 'paddle.reshape',
inputs={'x': val_x, inputs={'x': val_x.name,
'shape': val_shape}, 'shape': val_shape.name},
output=node, outputs=node)
param_attr=attr)
@print_mapping_info @print_mapping_info
def Cast(self, node): def Cast(self, node):
...@@ -1036,14 +1004,18 @@ class OpSet9(): ...@@ -1036,14 +1004,18 @@ class OpSet9():
output_dtype = val_output.dtype output_dtype = val_output.dtype
if output_dtype: if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output' assert dtype == output_dtype, 'dtype of to unmatches output'
attr = {'dtype': string(dtype)} self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.cast',
'cast', inputs=val_input, output=node, param_attr=attr) inputs={'x': val_input.name},
outputs=[node.name],
dtype=string(dtype))
@print_mapping_info @print_mapping_info
def Not(self, node): def Not(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True) val_input = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer('logical_not', inputs=val_input, output=node) self.paddle_graph.add_layer('paddle.logical_not',
inputs={'x': val_input.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def AveragePool(self, node): def AveragePool(self, node):
...@@ -1056,8 +1028,6 @@ class OpSet9(): ...@@ -1056,8 +1028,6 @@ class OpSet9():
pad_mode = node.get_attr("pads") pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0)) ceil_mode = bool(node.get_attr('ceil_mode', 0))
pads = node.get_attr('pads', [0] * (poolnd * 2)) pads = node.get_attr('pads', [0] * (poolnd * 2))
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x) paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
...@@ -1069,44 +1039,60 @@ class OpSet9(): ...@@ -1069,44 +1039,60 @@ class OpSet9():
strides[1]) strides[1])
paddings = pad_h + pad_w paddings = pad_h + pad_w
attr = { paddle_op = 'fluid.layers.pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d are supported'
layer_attrs = {
"pool_size": kernel_shape, "pool_size": kernel_shape,
"pool_type": string('avg'), "pool_type": string('avg'),
"pool_stride": strides, "pool_stride": strides,
"pool_padding": paddings, "pool_padding": paddings,
"ceil_mode": ceil_mode, "ceil_mode": ceil_mode,
"exclusive": 'True', "exclusive": 'True',
"name": string(node.layer_name) "name": string(node.name)
} }
self.paddle_graph.add_layer(
node.fluid_code.add_layer( paddle_op,
fluid_op, inputs=val_x, output=node, param_attr=attr) inputs={'input': val_x if isinstance(val_x, str) else val_x.name},
outputs=[node.name],
**layer_attrs)
# TODO(syf): op has diff
@print_mapping_info @print_mapping_info
def Concat(self, node): def Concat(self, node):
inputs = [] inputs_list = []
dtypes = set() dtypes = set()
for i in range(len(node.layer.input)): for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True) ipt = self.graph.get_input_node(node, idx=i, copy=True)
if isinstance(ipt, str): inputs_list.append(ipt.name)
inputs.append(ipt) dtypes.add(ipt.dtype)
else:
inputs.append(ipt.layer_name)
dtypes.add(ipt.dtype)
if len(dtypes) > 1: if len(dtypes) > 1:
assert 'Unspported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.' assert 'Unspported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.'
axis = node.get_attr('axis') axis = node.get_attr('axis')
attr = {'axis': axis} self.paddle_graph.add_layer(
node.fluid_code.add_layer( 'paddle.concat',
'concat', inputs=inputs, output=node, param_attr=attr) inputs={"x": inputs_list},
outputs=[node.name],
axis=axis)
@print_mapping_info @print_mapping_info
def Flatten(self, node): def Flatten(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
output_shape = node.out_shapes[0]
axis = node.get_attr('axis', 1) axis = node.get_attr('axis', 1)
attr = {"axis": str(axis), "name": string(node.layer_name)} shape_list = [1, 1]
node.fluid_code.add_layer( if axis == 0:
'flatten', inputs=val_x, output=node, param_attr=attr) for s in output_shape:
shape_list[1] *= s
else:
for s in output_shape[:axis]:
shape_list[0] *= s
for s in output_shape[axis:]:
shape_list[1] *= s
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
@print_mapping_info @print_mapping_info
def Gemm(self, node): def Gemm(self, node):
...@@ -1118,65 +1104,68 @@ class OpSet9(): ...@@ -1118,65 +1104,68 @@ class OpSet9():
beta = node.get_attr('beta', 1.) # optional beta = node.get_attr('beta', 1.) # optional
trans_a = bool(node.get_attr('transA', 0)) # optional trans_a = bool(node.get_attr('transA', 0)) # optional
trans_b = bool(node.get_attr('transB', 0)) # optional trans_b = bool(node.get_attr('transB', 0)) # optional
val_mm = node.layer_name + '_mm' val_mm = node.name + '_mm'
matmul_inputs = {"x": val_a, "y": val_b} matmul_inputs = {"x": val_a.name,
"y": val_b.name}
attr_matmul = { attr_matmul = {
"transpose_x": trans_a, "transpose_x": trans_a,
"transpose_y": trans_b, "transpose_y": trans_b,
"alpha": alpha,
"name": string(val_mm)
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
'matmul', 'paddle.matmul',
inputs=matmul_inputs, inputs=matmul_inputs,
output=val_mm, outputs=[val_mm],
param_attr=attr_matmul) **attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[val_mm],
scale=alpha)
if beta != 0: if beta != 0:
if beta == 1.: if beta == 1.:
add_inputs = {"x": val_mm, "y": val_c} add_inputs = {"x": val_mm,
attr = {"name": string(node.layer_name)} "y": val_c.name}
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_add", "paddle.add",
inputs=add_inputs, inputs=add_inputs,
output=node, outputs=[node.name])
param_attr=attr)
else: else:
var_beta = node.layer_name + '_beta' var_beta = node.name + '_beta'
matmul_beta_inputs = {"x": val_c, "y": var_beta} self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.scale",
"Constant", inputs={"x": val_c.name},
inputs=matmul_beta_inputs, outputs=[var_beta],
output=var_beta, scale=beta)
param_attr={'value': beta})
add_inputs = {"x": val_mm, "y": var_beta} add_inputs = {"x": val_mm, "y": var_beta}
attr = {"name": string(node.layer_name)} self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.add",
"elementwise_add",
inputs=add_inputs, inputs=add_inputs,
output=node, outputs=[node.name])
param_attr=attr)
@print_mapping_info @print_mapping_info
def Sum(self, node): def Sum(self, node):
val_inps = node.layer.input val_inps = node.layer.input
inputs = { inputs_dict = {
"x": self.graph.get_input_node( "x": self.graph.get_input_node(
node, idx=0, copy=True), node, idx=0, copy=True).name,
"y": self.graph.get_input_node( "y": self.graph.get_input_node(
node, idx=1, copy=True), node, idx=1, copy=True).name,
} }
node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node) self.paddle_graph.add_layer("paddle.add",
inputs=inputs_dict,
outputs=[node.name])
for idx, ipt in enumerate(val_inps[2:]): for idx, ipt in enumerate(val_inps[2:]):
y = self.graph.get_input_node(node, idx=idx, copy=True) y = self.graph.get_input_node(node, idx=idx, copy=True)
inputs = { inputs_dict = {
"x": node.layer_name, "x": node.name,
"y": y, "y": y.name,
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_add", inputs=inputs, output=node) "paddle.add",
inputs=inputs_dict,
outputs=[node.name])
@print_mapping_info @print_mapping_info
def MatMul(self, node): def MatMul(self, node):
...@@ -1184,21 +1173,26 @@ class OpSet9(): ...@@ -1184,21 +1173,26 @@ class OpSet9():
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y} inputs_dict = {"x": val_x.name,
"y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1: if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze' y_squeeze = val_y.name + '_squeeze'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"squeeze", "paddle.squeeze",
inputs=val_y, inputs={"x": val_y.name},
output=y_squeeze, outputs=[y_squeeze],
param_attr={'axes': [0]}) axis=[0])
inputs['y'] = y_squeeze inputs_dict['y'] = y_squeeze
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None) "paddle.matmul",
inputs=inputs_dict,
outputs=[node.name])
else: else:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None) "paddle.matmul",
inputs=inputs_dict,
outputs=[node.name])
@print_mapping_info @print_mapping_info
def BatchNormalization(self, node): def BatchNormalization(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1207,108 +1201,98 @@ class OpSet9(): ...@@ -1207,108 +1201,98 @@ class OpSet9():
val_mean = self.graph.get_input_node(node, idx=3, copy=True) val_mean = self.graph.get_input_node(node, idx=3, copy=True)
val_var = self.graph.get_input_node(node, idx=4, copy=True) val_var = self.graph.get_input_node(node, idx=4, copy=True)
self.omit_nodes.append(val_scale.layer_name)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_mean.layer_name)
self.omit_nodes.append(val_var.layer_name)
momentum = node.get_attr('momentum', .9) momentum = node.get_attr('momentum', .9)
epsilon = node.get_attr('epsilon', 1e-5) epsilon = node.get_attr('epsilon', 1e-5)
# Attribute: spatial is used in BatchNormalization-1,6,7 # Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial')) spatial = bool(node.get_attr('spatial'))
attr = { layer_attrs = {
"momentum": momentum, "momentum": momentum,
"epsilon": epsilon, "epsilon": epsilon,
"data_layout": string('NCHW'),
"is_test": True,
"param_attr": string(val_scale.layer_name),
"bias_attr": string(val_b.layer_name),
"moving_mean_name": string(val_mean.layer_name),
"moving_variance_name": string(val_var.layer_name),
"use_global_stats": spatial,
"name": string(node.layer_name)
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"batch_norm", inputs=val_x, output=node, param_attr=attr) "paddle.nn.functional.batch_norm",
inputs={"x": val_x.name,
"weight": val_scale.name,
"bias": val_b.name,
"running_mean": val_mean.name,
"running_var": val_var.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info @print_mapping_info
def Transpose(self, node): def Transpose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
perm = node.get_attr('perm') perm = node.get_attr('perm')
attr = {'perm': perm, "name": string(node.layer_name)} self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.transpose",
"transpose", inputs=val_x, output=node, param_attr=attr) inputs={"x": val_x.name},
outputs=[node.name],
@print_mapping_info perm=perm)
def Relu(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"relu", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
def PRelu(self, node): def PRelu(self, node):
op_name = name_generator("prelu", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True) val_slope = self.graph.get_input_node(node, idx=1, copy=True)
mode = 'channel' mode = 'channel'
shape_slope = val_slope.out_shapes[0] shape_slope = val_slope.out_shapes[0]
if shape_slope == [1]: if shape_slope == [1]:
mode = 'all' mode = 'all'
elif len(shape_slope) > 2: elif len(shape_slope) > 2:
mode = 'element' raise Exception("The 'element' mode is not supported yet!")
if mode == 'channel' and len(shape_slope) == 1: if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel] # paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope) slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope) slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.layer_name] = slope_data self.params[val_slope.name] = slope_data
self.omit_nodes.append(val_slope.layer_name) self.paddle_graph.add_layer(
attr = { "paddle.nn.functional.prelu",
"param_attr": string(val_slope.layer_name), inputs={"x": val_x.name,
'mode': string(mode) "weight": val_slope.name},
} outputs=[node.name])
node.fluid_code.add_layer(
"prelu", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
def Squeeze(self, node): def Squeeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
attr = {'axes': axes, "name": string(node.layer_name)}
if len(val_x.out_shapes[0]) == 1: if len(val_x.out_shapes[0]) == 1:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"cast", "paddle.cast",
inputs=val_x, inputs={"x": val_x.name},
output=node, outputs=[node.name],
param_attr={'dtype': string(val_x.dtype)}) dtype=string(val_x.dtype))
else: else:
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"squeeze", inputs=val_x, output=node, param_attr=attr) "paddle.squeeze",
inputs={"x": val_x.name},
outputs=[node.name],
axis=axes)
@print_mapping_info @print_mapping_info
def Equal(self, node): def Equal(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"equal", "paddle.equal",
inputs={'x': val_x, inputs={'x': val_x.name,
'y': val_y}, 'y': val_y.name},
output=node, outputs=[node.name])
param_attr=None)
@print_mapping_info @print_mapping_info
def Greater(self, node): def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"greater_than", "paddle.greater_than",
inputs={'x': val_x, inputs={'x': val_x.name,
'y': val_y}, 'y': val_y.name},
output=node, outputs=node,
param_attr=None) param_attr=None)
@print_mapping_info @print_mapping_info
...@@ -1317,72 +1301,80 @@ class OpSet9(): ...@@ -1317,72 +1301,80 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=1, copy=True) val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True) val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.layer_name + '_not' not_condition = condition.name + '_not'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"logical_not", "paddle.logical_not",
inputs=condition, inputs={"x": condition.name},
output=not_condition, outputs=[not_condition])
param_attr=None)
cast_not_condition = not_condition + '_cast' cast_not_condition = not_condition + '_cast'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"cast", "paddle.cast",
inputs=not_condition, inputs={"x": not_condition},
output=cast_not_condition, outputs=[cast_not_condition],
param_attr={'dtype': string(val_x.dtype)}) dtype=string(val_x.dtype))
cast_condition = condition.layer_name + '_cast' cast_condition = condition.name + '_cast'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"cast", "paddle.cast",
inputs=condition, inputs={"x": condition.name},
output=cast_condition, outputs=[cast_condition],
param_attr={'dtype': string(val_x.dtype)}) dtype=string(val_x.dtype))
mul_val_x = val_x.layer_name + '_mul' mul_val_x = val_x.name + '_mul'
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_mul", "paddle.multiply",
inputs={'x': val_x, inputs={'x': val_x.name,
'y': cast_condition}, 'y': cast_condition},
output=mul_val_x, outputs=[mul_val_x])
param_attr=None) mul_val_y = val_y.name + '_mul'
mul_val_y = val_y.layer_name + '_mul' self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.multiply",
"elementwise_mul", inputs={'x': val_y.name,
inputs={'x': val_y,
'y': cast_not_condition}, 'y': cast_not_condition},
output=mul_val_y, outputs=[mul_val_y])
param_attr=None)
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"elementwise_add", "paddle.add",
inputs={'x': mul_val_x, inputs={'x': mul_val_x,
'y': mul_val_y}, 'y': mul_val_y},
output=node, outputs=[node.name])
param_attr=None)
@print_mapping_info @print_mapping_info
def NonZero(self, node): def NonZero(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_x_dim = len(val_x.out_shapes[0]) val_x_dim = len(val_x.out_shapes[0])
if val_x_dim == 1: if val_x_dim == 1:
node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x) self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.nonzero",
"transpose", inputs={"x": val_x.name},
inputs=val_x, outputs=[val_x.name])
output=node, self.paddle_graph.add_layer(
param_attr={'perm': [1, 0]}) "paddle.transpose",
inputs={"x": val_x.name},
outputs=[node.layer_naem],
perm=[1, 0])
if val_x_dim > 1: if val_x_dim > 1:
node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x) self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.nonzero",
"split", inputs={"x": val_x.name},
inputs=val_x, outputs=[val_x.name])
output=val_x, self.paddle_graph.add_layer(
param_attr={'num_or_sections': 1, "paddle.split",
'dim': val_x_dim}) inputs={"x": val_x.name},
node.fluid_code.add_layer("concat", inputs=val_x, output=node) outputs=[val_x.name],
num_or_sections=1,
axis=val_x_dim)
self.paddle_graph.add_layer(
"paddle.concat",
inputs={"x": val_x.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Identity(self, node): def Identity(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer("assign", inputs=val_x, output=node) self.paddle_graph.add_layer(
"paddle.assign",
inputs={"x": val_x.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Tile(self, node): def Tile(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1390,14 +1382,13 @@ class OpSet9(): ...@@ -1390,14 +1382,13 @@ class OpSet9():
repeats = _const_weight_or_none(val_repeats) repeats = _const_weight_or_none(val_repeats)
if repeats is None: if repeats is None:
repeats = val_repeats.layer_name repeats = val_repeats.name
if val_repeats.dtype != 'int32': if val_repeats.dtype != 'int32':
attr = {"dtype": string("int32")} self.paddle_graph.add_layer(
node.fluid_code.add_layer( "paddle.cast",
"cast", inputs={"x": repeats},
inputs=repeats, outputs=["{}.tmp".format(repeats)],
output="{}.tmp".format(repeats), dtype=string("int32"))
param_attr=attr)
repeats = "{}.tmp".format(repeats) repeats = "{}.tmp".format(repeats)
elif isinstance(repeats, int): elif isinstance(repeats, int):
...@@ -1405,10 +1396,13 @@ class OpSet9(): ...@@ -1405,10 +1396,13 @@ class OpSet9():
attr = { attr = {
'expand_times': repeats, 'expand_times': repeats,
"name": string(node.layer_name), "name": string(node.name),
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
"expand", inputs=val_x, output=node, param_attr=attr) "paddle.tile",
inputs={"x": val_x.name},
outputs=[node.name],
repeat_times=repeats)
@print_mapping_info @print_mapping_info
def MaxPool(self, node): def MaxPool(self, node):
...@@ -1423,8 +1417,8 @@ class OpSet9(): ...@@ -1423,8 +1417,8 @@ class OpSet9():
pad_mode = node.get_attr("pads") pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0)) # optional ceil_mode = bool(node.get_attr('ceil_mode', 0)) # optional
pads = node.get_attr('pads', [0] * (poolnd * 2)) # optional pads = node.get_attr('pads', [0] * (poolnd * 2)) # optional
fluid_op = 'pool{}d'.format(poolnd) paddle_op = 'paddle.nn.functional.max_pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' assert 1 <= poolnd <= 3, 'only max_pool1d, max_pool2d and max_pool3d are supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x) paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
...@@ -1435,64 +1429,72 @@ class OpSet9(): ...@@ -1435,64 +1429,72 @@ class OpSet9():
pad_w = _get_same_padding(input_shape[3], kernel_shape[1], pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1]) strides[1])
paddings = pad_h + pad_w paddings = pad_h + pad_w
attr = { layer_attrs = {
"pool_size": kernel_shape, "kernel_size": kernel_shape,
"pool_type": string("max"), "stride": strides,
"pool_stride": strides, "padding": paddings,
"pool_padding": paddings,
"ceil_mode": ceil_mode, "ceil_mode": ceil_mode,
"name": string(node.layer_name),
"exclusive": False
} }
node.fluid_code.add_layer( self.paddle_graph.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr) paddle_op,
inputs={'x': val_x if isinstance(val_x, str) else val_x.name},
def _global_pool(self, node): outputs=[node.name],
val_x = self.graph.get_input_node(node, idx=0, copy=True) **layer_attrs)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
fluid_op = 'pool2d'
pool_type = None
if node.layer.op_type == 'GlobalMaxPool':
pool_type = 'max'
elif node.layer.op_type == 'GlobalAveragePool':
pool_type = 'avg'
attr = {
"pool_type": string(pool_type),
"global_pooling": True,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
@print_mapping_info @print_mapping_info
def GlobalMaxPool(self, node): def GlobalMaxPool(self, node):
self._global_pool(node) val_x = self.graph.get_input_node(node, idx=0, copy=True)
input_shape = val_x.out_shapes[0]
if len(input_shape) == 4:
poolnd = 2
elif len(input_shape) == 5:
poolnd = 3
elif len(input_shape) == 3:
poolnd = 1
paddle_op = 'paddle.nn.functional.adaptive_max_pool{}d'.format(poolnd)
assert 1 <= poolnd <= 3, 'only adaptive_max_pool1d, adaptive_max_pool2d and adaptive_max_pool3d are supported'
output_shape = node.out_shapes[0]
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name],
output_size=output_shape[2:])
@print_mapping_info @print_mapping_info
def GlobalAveragePool(self, node): def GlobalAveragePool(self, node):
self._global_pool(node) val_x = self.graph.get_input_node(node, idx=0, copy=True)
input_shape = val_x.out_shapes[0]
if len(input_shape) == 4:
poolnd = 2
elif len(input_shape) == 5:
poolnd = 3
elif len(input_shape) == 3:
poolnd = 1
paddle_op = 'paddle.nn.functional.adaptive_avg_pool{}d'.format(poolnd)
assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
output_shape = node.out_shapes[0]
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name],
output_size=output_shape[2:])
@print_mapping_info @print_mapping_info
def Conv(self, node): def Conv(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
self.omit_nodes.append(val_w.layer_name)
has_bias = len(node.layer.input) == 3 has_bias = len(node.layer.input) == 3
if has_bias: if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
auto_pad = node.get_attr('auto_pad', 'NOTSET') auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr('kernel_shape') kernel_shape = node.get_attr('kernel_shape')
convnd = len(kernel_shape) convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported' assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
num_out_channels = val_w.out_shapes[0][0] num_out_channels = val_w.out_shapes[0][0]
fluid_op = 'conv{}d'.format(convnd) num_in_channels = val_w.out_shapes[0][1]
paddle_op = 'paddle.nn.functional.conv{}d'.format(convnd)
num_groups = node.get_attr('group', 1) num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd) strides = node.get_attr('strides', [1] * convnd)
...@@ -1509,22 +1511,23 @@ class OpSet9(): ...@@ -1509,22 +1511,23 @@ class OpSet9():
strides[1]) strides[1])
paddings = pad_h + pad_w paddings = pad_h + pad_w
attr = { layer_attrs = {
"num_filters": num_out_channels,
"filter_size": kernel_shape,
"stride": strides, "stride": strides,
"padding": paddings, "padding": paddings,
"dilation": dilations, "dilation": dilations,
"groups": num_groups, "groups": num_groups,
'param_attr': string(val_w.layer_name), }
"name": string(node.layer_name) layer_inputs = {
"x": val_x.name,
"weight": val_w.name
} }
if has_bias: if has_bias:
attr["bias_attr"] = string(val_b.layer_name) layer_inputs["bias"] = val_b.name
else: self.paddle_graph.add_layer(
attr["bias_attr"] = False paddle_op,
node.fluid_code.add_layer( inputs=layer_inputs,
fluid_op, inputs=val_x, output=node, param_attr=attr) outputs=[node.name],
**layer_attrs)
@print_mapping_info @print_mapping_info
def ConvTranspose(self, node): def ConvTranspose(self, node):
...@@ -1533,19 +1536,15 @@ class OpSet9(): ...@@ -1533,19 +1536,15 @@ class OpSet9():
val_b = None val_b = None
if len(node.layer.input) > 2: if len(node.layer.input) > 2:
val_b = self.graph.get_input_node(node, idx=2, copy=True) val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_w.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET') auto_pad = node.get_attr('auto_pad', 'NOTSET')
out_padding = node.get_attr('output_padding', [0, 0]) out_padding = node.get_attr('output_padding', [0, 0])
kernel_shape = node.get_attr('kernel_shape') kernel_shape = node.get_attr('kernel_shape')
assert kernel_shape, 'kernel_shape not inferred' assert kernel_shape, 'kernel_shape not inferred'
convnd = len(kernel_shape) convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported' assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_in_channels = val_w.out_shapes[0][0]
num_out_channels = val_w.out_shapes[0][1] num_out_channels = val_w.out_shapes[0][1]
fluid_op = 'conv{}d_transpose'.format(convnd) paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
num_groups = node.get_attr('group', 1) num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd) strides = node.get_attr('strides', [1] * convnd)
...@@ -1563,17 +1562,18 @@ class OpSet9(): ...@@ -1563,17 +1562,18 @@ class OpSet9():
output_size[1] = (val_x.out_shapes[0][3] - 1 output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * ( ) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1] kernel_shape[1] - 1) + 1 + out_padding[1]
attr = { layer_inputs = {'x': val_x.name,
'num_filters': num_out_channels, "weight": val_w.name}
'output_size': output_size or None, layer_attrs = {
'filter_size': kernel_shape, "stride": strides,
'padding': paddings, "dilation": dilations,
'stride': strides, "padding": paddings,
'dilation': dilations, "groups": num_groups,
'groups': num_groups, "output_size": node.out_shapes[0][2:]}
'param_attr': string(val_w.layer_name), if val_b is not None:
'bias_attr': None if val_b is None else string(val_b.layer_name), layer_inputs["bias"] = val_b.name
'name': string(node.layer_name), self.paddle_graph.add_layer(
} kernel=paddle_op,
node.fluid_code.add_layer( inputs=layer_inputs,
fluid_op, inputs=val_x, output=node, param_attr=attr) outputs=[node.name],
**layer_attrs)
\ No newline at end of file
...@@ -75,15 +75,17 @@ class TFOpMapper(OpMapper): ...@@ -75,15 +75,17 @@ class TFOpMapper(OpMapper):
'Sub': 'fluid.layers.elementwise_sub', 'Sub': 'fluid.layers.elementwise_sub',
'Maximum': 'paddle.maximum', 'Maximum': 'paddle.maximum',
'Minimum': 'paddle.minimum', 'Minimum': 'paddle.minimum',
'Mul': 'paddle.multiply',
'FloorDiv': 'paddle.floor_divide',
'FloorMod': 'paddle.floor_mod',
'LogicalAnd': 'logical_and',
}
bool_ops = {
'LessEqual': 'paddle.less_equal', 'LessEqual': 'paddle.less_equal',
'GreaterEqual': 'paddle.greater_equal', 'GreaterEqual': 'paddle.greater_equal',
'Greater': 'paddle.greater_than', 'Greater': 'paddle.greater_than',
'NotEqual': 'paddle.not_equal', 'NotEqual': 'paddle.not_equal',
'Equal': 'paddle.equal', 'Equal': 'paddle.equal',
'Mul': 'paddle.multiply',
'FloorDiv': 'paddle.floor_divide',
'FloorMod': 'paddle.floor_mod',
'LogicalAnd': 'logical_and',
} }
def __init__(self, decoder): def __init__(self, decoder):
...@@ -94,6 +96,7 @@ class TFOpMapper(OpMapper): ...@@ -94,6 +96,7 @@ class TFOpMapper(OpMapper):
raise Exception("Model is not supported yet.") raise Exception("Model is not supported yet.")
self.params = dict() self.params = dict()
self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="tf") self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="tf")
self.params_output2id = dict()
not_placeholder = list() not_placeholder = list()
for name in self.graph.input_nodes: for name in self.graph.input_nodes:
...@@ -124,6 +127,8 @@ class TFOpMapper(OpMapper): ...@@ -124,6 +127,8 @@ class TFOpMapper(OpMapper):
self.directly_map(node) self.directly_map(node)
elif op in self.elementwise_ops: elif op in self.elementwise_ops:
self.elementwise_map(node) self.elementwise_map(node)
elif op in self.bool_ops:
self.bool_map(node)
elif hasattr(self, op): elif hasattr(self, op):
func = getattr(self, op) func = getattr(self, op)
func(node) func(node)
...@@ -138,7 +143,8 @@ class TFOpMapper(OpMapper): ...@@ -138,7 +143,8 @@ class TFOpMapper(OpMapper):
op = node.layer_type op = node.layer_type
if not hasattr(self, op) and \ if not hasattr(self, op) and \
op not in self.directly_map_ops and \ op not in self.directly_map_ops and \
op not in self.elementwise_ops: op not in self.elementwise_ops and \
op not in self.bool_ops:
unsupported_ops.add(op) unsupported_ops.add(op)
if len(unsupported_ops) == 0: if len(unsupported_ops) == 0:
return True return True
...@@ -167,9 +173,10 @@ class TFOpMapper(OpMapper): ...@@ -167,9 +173,10 @@ class TFOpMapper(OpMapper):
outputs=[node.name], outputs=[node.name],
**attr) **attr)
def elementwise_map(self, node): def elementwise_map(self, node, op_type=None):
assert node.layer_type in self.elementwise_ops if op_type is None:
op_type = self.elementwise_ops[node.layer_type] assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type]
x = self.graph.get_node(node.layer.input[0]) x = self.graph.get_node(node.layer.input[0])
y = self.graph.get_node(node.layer.input[1]) y = self.graph.get_node(node.layer.input[1])
x_shape = x.out_shapes[0] x_shape = x.out_shapes[0]
...@@ -180,6 +187,11 @@ class TFOpMapper(OpMapper): ...@@ -180,6 +187,11 @@ class TFOpMapper(OpMapper):
"y": y.name}, "y": y.name},
outputs=[node.name]) outputs=[node.name])
self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape} self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
def bool_map(self, node):
op_type = self.bool_ops[node.layer_type]
self.elementwise_map(node, op_type)
node.set_dtype("bool")
def Placeholder(self, node): def Placeholder(self, node):
shape = node.out_shapes[0] shape = node.out_shapes[0]
...@@ -213,7 +225,7 @@ class TFOpMapper(OpMapper): ...@@ -213,7 +225,7 @@ class TFOpMapper(OpMapper):
return return
self.params[node.name] = node.value self.params[node.name] = node.value
self.paddle_graph.add_layer( layer_id = self.paddle_graph.add_layer(
kernel="paddle.static.create_parameter", kernel="paddle.static.create_parameter",
inputs={}, inputs={},
outputs=[node.name], outputs=[node.name],
...@@ -221,6 +233,7 @@ class TFOpMapper(OpMapper): ...@@ -221,6 +233,7 @@ class TFOpMapper(OpMapper):
shape=shape, shape=shape,
name=string(node.name), name=string(node.name),
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)")
self.params_output2id[node.name] = layer_id
def Transpose(self, node): def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
...@@ -763,11 +776,17 @@ class TFOpMapper(OpMapper): ...@@ -763,11 +776,17 @@ class TFOpMapper(OpMapper):
data_format = node.get_attr("data_format").decode() data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode() pad_mode = node.get_attr("padding").decode()
self.paddle_graph.add_layer( if len(kernel.outputs) == 1:
kernel="paddle.transpose", self.params[kernel.name] = numpy.transpose(self.params[kernel.name],
inputs={"x": kernel.name}, (2, 3, 0, 1))
outputs=[kernel.name], layer = self.paddle_graph.layers[self.params_output2id[kernel.name]]
perm=[2, 3, 0, 1]) layer.attrs["shape"] = self.params[kernel.name].shape
else:
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": kernel.name},
outputs=[kernel.name],
perm=[2, 3, 0, 1])
input_name = input.name input_name = input.name
if data_format == "NHWC": if data_format == "NHWC":
......
...@@ -178,13 +178,13 @@ class DygraphTransposeElimination(FuseBase): ...@@ -178,13 +178,13 @@ class DygraphTransposeElimination(FuseBase):
if _graph.layers[ipt].outputs[ if _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'x']: 'x']:
if len(x_shape) <= 1: if list(x_shape)==[1] or len(x_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
elif _graph.layers[ipt].outputs[ elif _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'y']: 'y']:
if len(y_shape) <= 1: if list(y_shape)==[1] or len(y_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
else: else:
...@@ -279,11 +279,6 @@ class DygraphTransposeElimination(FuseBase): ...@@ -279,11 +279,6 @@ class DygraphTransposeElimination(FuseBase):
for layer_id in list(set(optimized_concat_layers)): for layer_id in list(set(optimized_concat_layers)):
axis = graph.layers[layer_id].attrs.get('axis', 0) axis = graph.layers[layer_id].attrs.get('axis', 0)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
for layer_id in list(set(optimized_elementwise_layers)):
axis = graph.layers[layer_id].attrs.get('axis', -1)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
if graph.layers[layer_id].kernel == "paddle.add":
graph.layers[layer_id].kernel = "fluid.layers.elementwise_add"
current_transpose_num = self.get_transpose_num(graph) current_transpose_num = self.get_transpose_num(graph)
print( print(
......
...@@ -178,13 +178,13 @@ class StaticTransposeElimination(FuseBase): ...@@ -178,13 +178,13 @@ class StaticTransposeElimination(FuseBase):
if _graph.layers[ipt].outputs[ if _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'x']: 'x']:
if len(x_shape) <= 1: if list(x_shape)==[1] or len(x_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
elif _graph.layers[ipt].outputs[ elif _graph.layers[ipt].outputs[
output_index] == _graph.layers[current_id].inputs[ output_index] == _graph.layers[current_id].inputs[
'y']: 'y']:
if len(y_shape) <= 1: if list(y_shape)==[1] or len(y_shape) < 1:
elementwise_layers.append(current_id) elementwise_layers.append(current_id)
continue continue
else: else:
...@@ -279,11 +279,6 @@ class StaticTransposeElimination(FuseBase): ...@@ -279,11 +279,6 @@ class StaticTransposeElimination(FuseBase):
for layer_id in list(set(optimized_concat_layers)): for layer_id in list(set(optimized_concat_layers)):
axis = graph.layers[layer_id].attrs.get('axis', 0) axis = graph.layers[layer_id].attrs.get('axis', 0)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
for layer_id in list(set(optimized_elementwise_layers)):
axis = graph.layers[layer_id].attrs.get('axis', -1)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
if graph.layers[layer_id].kernel == "paddle.add":
graph.layers[layer_id].kernel = "fluid.layers.elementwise_add"
current_transpose_num = self.get_transpose_num(graph) current_transpose_num = self.get_transpose_num(graph)
print( print(
......
...@@ -167,7 +167,7 @@ class DygraphAdaptivePool2dFuser(FuseBase): ...@@ -167,7 +167,7 @@ class DygraphAdaptivePool2dFuser(FuseBase):
new_layer = PaddleLayer( new_layer = PaddleLayer(
layers_id[0], layers_id[0],
"paddle.nn.functional.adaptive_avg_pool2d", "paddle.nn.functional.adaptive_avg_pool2d",
inputs={"input": input_name}, inputs={"x": input_name},
outputs=[output_name], outputs=[output_name],
**attrs) **attrs)
else: else:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO useless node remove
class ONNXOptimizer(object):
def __init__(self, op_mapper):
self.op_mapper = op_mapper
self.graph = op_mapper.graph
def delete_redundance_code(self):
for node_name in self.graph.topo_sort:
if node_name in self.op_mapper.omit_nodes:
node = self.graph.get_node(node_name)
omit_freq = self.op_mapper.omit_nodes.count(node_name)
if len(node.outputs) <= omit_freq:
node.fluid_code.clear()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册