Commit 262229f5 authored by SunAhong1993

Modify the static ONNX converter

Parent 9e19ff2b
......@@ -185,16 +185,8 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False):
from x2paddle.op_mapper.static.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path)
mapper = ONNXOpMapper(model)
if paddle_type == "dygraph":
mapper.paddle_graph.build()
mapper.paddle_graph.gen_model(save_dir)
else:
from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
print("Model optimizing ...")
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
print("Model optimized.")
mapper.save_inference_model(save_dir, params_merge)
mapper.paddle_graph.build()
mapper.paddle_graph.gen_model(save_dir)
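For context, after this change both the dygraph and the static paths build the model through the same PaddleGraph build()/gen_model() calls. A minimal driver sketch (the import path is assumed from the file above; the file and directory names are hypothetical):

from x2paddle.convert import onnx2paddle  # assumed location of the function shown above

# "model.onnx" and "pd_model_static" are placeholder paths
onnx2paddle(model_path="model.onnx",
            save_dir="pd_model_static",
            paddle_type="static")  # "dygraph" now goes through the same build()/gen_model() path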
def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None):
......
......@@ -534,7 +534,7 @@ class OpSet9():
'bias_attr': string(val_b.name)
}
dim = len(val_x.out_shapes[0])
if dim == 2 or dim == 3:
if dim == 3:
paddle_op = "paddle.nn.InstanceNorm1D"
elif dim == 4:
paddle_op = "paddle.nn.InstanceNorm2D"
......@@ -1539,7 +1539,6 @@ class OpSet9():
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
has_bias = len(node.layer.input) == 3
if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
......@@ -1589,6 +1588,9 @@ class OpSet9():
@print_mapping_info
def ConvTranspose(self, node):
op_name = name_generator("conv", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_b = None
......@@ -1602,7 +1604,7 @@ class OpSet9():
assert 2 <= convnd <= 3, 'only Conv2DTranspose and Conv3DTranspose supported'
num_in_channels = val_w.out_shapes[0][0]
num_out_channels = val_w.out_shapes[0][1]
paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
paddle_op = 'paddle.nn.Conv{}DTranspose'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd)
......@@ -1620,37 +1622,37 @@ class OpSet9():
output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1]
# layer_attrs = {
# 'in_channels': num_in_channels,
# 'out_channels': num_out_channels,
# 'output_size': output_size or None,
# 'kernel_size': kernel_shape,
# 'padding': paddings,
# 'stride': strides,
# 'dilation': dilations,
# 'groups': num_groups,
# 'weight_attr': string(val_w.name),
# 'bias_attr': None if val_b is None else string(val_b.name),
# }
# self.paddle_graph.add_layer(
# paddle_op,
# inputs={"x": val_x.name},
# outputs=layer_outputs,
# **layer_attrs)
inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name,
"weight": val_w.name}
layer_attrs = {
"stride": strides,
"dilation": dilations,
"padding": paddings,
"groups": num_groups,
"output_size": node.out_shapes[0][2:]}
if val_b is not None:
inputs_dict["bias"] = val_b.name
else:
layer_attrs["bias"] = None
'in_channels': num_in_channels,
'out_channels': num_out_channels,
'output_size': output_size or None,
'kernel_size': kernel_shape,
'padding': paddings,
'stride': strides,
'dilation': dilations,
'groups': num_groups,
'weight_attr': string(val_w.name),
'bias_attr': None if val_b is None else string(val_b.name),
}
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.conv2d_transpose",
inputs=inputs_dict,
outputs=[node.name],
paddle_op,
inputs={"x": val_x.name},
outputs=layer_outputs,
**layer_attrs)
# inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name,
# "weight": val_w.name}
# layer_attrs = {
# "stride": strides,
# "dilation": dilations,
# "padding": paddings,
# "groups": num_groups,
# "output_size": node.out_shapes[0][2:]}
# if val_b is not None:
# inputs_dict["bias"] = val_b.name
# else:
# layer_attrs["bias"] = None
# self.paddle_graph.add_layer(
# kernel="paddle.nn.functional.conv2d_transpose",
# inputs=inputs_dict,
# outputs=[node.name],
# **layer_attrs)
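The output_size values computed above follow the standard transposed-convolution size formula. A small self-contained check with hypothetical numbers:

def conv_transpose_out_size(in_size, stride, padding, dilation, kernel, out_padding):
    # same formula as the output_size computation above
    return (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + out_padding

# e.g. a 4x4 input, kernel 3, stride 2, padding 1, dilation 1, no output padding -> 7x7
assert conv_transpose_out_size(4, 2, 1, 1, 3, 0) == 7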
......@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.op_mapper.static.onnx2paddle.opset9 import OpSet9, custom_layers
import sys
from x2paddle.op_mapper.static.onnx2paddle.opset9 import OpSet9
from x2paddle.core.op_mapper import OpMapper
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
from x2paddle.decoder.onnx_decoder import ONNXGraphNode
from x2paddle.core.program import PaddleGraph
class ONNXOpMapper(OpMapper):
......@@ -23,33 +25,36 @@ class ONNXOpMapper(OpMapper):
self.support_op_sets = [9, ]
self.default_op_set = 9
self.graph = decoder.graph
self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="onnx")
self.paddle_graph.outputs = self.graph.output_nodes
self.opset = self.create_opset(decoder)
if not self.op_checker():
raise Exception("Model are not supported yet.")
#mapping op
raise Exception("Model is not supported yet.")
print("Total nodes: {}".format(
sum([
isinstance(node, ONNXGraphNode)
for name, node in self.graph.node_map.items()
])))
print("Nodes converting ...")
for node_name in self.graph.topo_sort:
for i, node_name in enumerate(self.graph.topo_sort):
sys.stderr.write("\rConverting node {} ... ".format(i + 1))
node = self.graph.get_node(node_name)
op = node.layer_type
if hasattr(self.opset, op):
func = getattr(self.opset, op)
func(node)
elif op in self.opset.default_op_mapping:
elif op in self.opset.directly_map_ops:
self.opset.directly_map(node)
elif op in custom_layers:
self.opset.deal_custom_layer(node)
elif op in self.opset.elementwise_ops:
self.opset.elementwise_map(node)
print("Nodes converted.")
self.weights = self.opset.weights
self.omit_nodes = self.opset.omit_nodes
self.used_custom_layers = self.opset.used_custom_layers
print("\nNodes converted.")
self.paddle_graph.set_name(self.graph.graph_name)
self.paddle_graph.set_parameters(self.opset.params)
self.paddle_graph.set_inputs_info(self.opset.inputs_info)
self.paddle_graph.inputs = self.graph.input_nodes
self.paddle_graph.outputs = self.graph.output_nodes
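The loop above dispatches each ONNX node by its op name: a dedicated method on the opset object wins, otherwise the op falls back to the directly_map or elementwise_map tables. A stripped-down illustration of that dispatch pattern (toy classes, not the real mapper):

class ToyOpSet:
    directly_map_ops = {"Relu": ["paddle.nn.functional.relu"]}
    elementwise_ops = {"Add": "paddle.add"}

    def Conv(self, node):          # dedicated mapper method
        print("custom mapping for", node)

    def directly_map(self, node):
        print("table-driven mapping for", node)

    def elementwise_map(self, node):
        print("elementwise mapping for", node)

opset = ToyOpSet()
for op in ["Conv", "Relu", "Add"]:
    if hasattr(opset, op):
        getattr(opset, op)(op)
    elif op in opset.directly_map_ops:
        opset.directly_map(op)
    elif op in opset.elementwise_ops:
        opset.elementwise_map(op)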
def op_checker(self):
unsupported_ops = set()
......@@ -57,17 +62,17 @@ class ONNXOpMapper(OpMapper):
node = self.graph.get_node(node_name)
op = node.layer_type
if not hasattr(self.opset, op) and \
op not in self.opset.default_op_mapping and \
op not in custom_layers and \
op not in self.opset.directly_map_ops and \
op not in self.opset.elementwise_ops:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below".format(
len(unsupported_ops)))
if len(unsupported_ops) > 0:
print("\n========= {} OPs are not supported yet ===========".format(
len(unsupported_ops)))
for op in unsupported_ops:
print(op)
print("========== {} ============".format(op))
return False
def create_opset(self, decoder):
......@@ -88,4 +93,4 @@ class ONNXOpMapper(OpMapper):
'Now, onnx2paddle supports converting onnx models with opset_version {}, '
'the opset_version of your onnx model is {}, automatically treated as op_set: {}.'
.format(self.support_op_sets, decoder.op_set, run_op_set))
return eval(opset)(decoder)
return eval(opset)(decoder, self.paddle_graph)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .register import get_registered_layers
custom_layers = get_registered_layers()
def set_args(f, params):
""" set args for function 'f' using the parameters in node.layer.param
Args:
f (function): a python function object
params (object): a object contains attributes needed by f's arguments
Returns:
arg_names (list): a list of argument names
kwargs (dict): a dict contains needed arguments
"""
argc = f.__code__.co_argcount
arg_list = f.__code__.co_varnames[0:argc]
kwargs = {}
for arg_name in arg_list:
if hasattr(params, arg_name) and params is not None:
kwargs[arg_name] = getattr(params, arg_name)
return arg_list, kwargs
def has_layer(layer_type):
""" test whether this layer exists in custom layer
"""
return layer_type in custom_layers
def get_params(layer, layer_type):
import re
if layer_type.lower() == "deconvolution" or layer_type.lower(
) == "convolutiondepthwise":
param_name = '_'.join(('convolution', 'param'))
elif layer_type.lower() == "normalize":
param_name = '_'.join(('norm', 'param'))
elif len(layer_type) - len(re.sub("[A-Z]", "", layer_type)) >= 2:
s = ''
tmp_name = ''
for i, ch in enumerate(layer_type):
if i == 0:
s += ch.lower()
continue
elif ch.isupper() and layer_type[i - 1].islower():
tmp_name += (s + '_')
s = ''
s += ch.lower()
tmp_name += s
param_name = '_'.join((tmp_name, 'param'))
else:
param_name = '_'.join((layer_type.lower(), 'param'))
return getattr(layer, param_name, None)
def compute_output_shape(node):
""" compute the output shape of custom layer
"""
layer_type = node.layer_type
assert layer_type in custom_layers, "layer[%s] not exist in custom layers" % (
layer_type)
shape_func = custom_layers[layer_type]['shape']
layer = node.layer
params = get_params(layer, layer_type)
arg_names, kwargs = set_args(shape_func, params)
input_shape = node.input_shape
return shape_func(input_shape, **kwargs)
def make_custom_layer(node):
""" get the code which implement the custom layer function
"""
layer_type = node.layer_type
assert layer_type in custom_layers, "layer[%s] not exist in custom layers" % (
layer_type)
layer_func = custom_layers[layer_type]['layer']
import inspect
return inspect.getsource(layer_func), layer_func
def make_custom_child_func(node):
""" get the code which implement the custom layer function
"""
layer_type = node.layer_type
child_func = custom_layers[layer_type]['child_func']
if child_func is None:
return None, child_func
import inspect
return inspect.getsource(child_func), child_func
def deal_weights(node, data=None):
""" deal the weights of the custom layer
"""
layer_type = node.layer_type
weights_func = custom_layers[layer_type]['weights']
name = node.layer_name
return weights_func(name, data)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" this module provides 'register' for registering customized layers
"""
g_custom_layers = {}
def register(kind, shape, layer, child_func, weights):
""" register a custom layer or a list of custom layers
Args:
@kind (str or list): type name of the layer
@shape (function): a function to generate the shape of layer's output
@layer (function): a function to generate the paddle code of layer
@weights (function): a function to deal with weights data
Returns:
None
"""
assert type(shape).__name__ == 'function', 'shape should be a function'
assert type(layer).__name__ == 'function', 'layer should be a function'
if type(kind) is str:
kind = [kind]
else:
assert type(
kind) is list, 'invalid param "kind" for register, not a list or str'
for k in kind:
assert type(
k) is str, 'invalid param "kind" for register, not a list of str'
assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
k)
g_custom_layers[k] = {
'shape': shape,
'layer': layer,
'child_func': child_func,
'weights': weights
}
def get_registered_layers():
return g_custom_layers
......@@ -27,6 +27,8 @@ import logging as _logging
from collections import OrderedDict
import math
import os
import copy
import sys
import shutil
_logger = _logging.getLogger(__name__)
......@@ -85,182 +87,118 @@ def print_mapping_info(func):
class OpSet9():
elementwise_ops = {
'Add': 'elementwise_add',
'Div': 'elementwise_div',
'Sub': 'elementwise_sub',
'Mul': 'elementwise_mul',
'Pow': 'elementwise_pow',
'Add': 'paddle.add',
'Div': 'paddle.divide',
'Sub': 'fluid.layers.elementwise_sub',
'Mul': 'paddle.multiply',
'Pow': 'paddle.pow',
}
default_op_mapping_field_values = OrderedDict()
default_op_mapping_field_values['FLUID_OP'] = ''
default_op_mapping_field_values['FLUID_INPUT_ARGS'] = None
default_op_mapping_field_values['FLUID_OUTPUT_ARGS'] = None
default_op_mapping_field_values['ATTR_MAPPING'] = dict()
default_op_mapping_field_values['DEFAULTS'] = dict()
default_op_mapping_field_values['INPUT_PERM'] = None
default_op_mapping_field_values['OUTPUT_PERM'] = None
default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']],
'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [
'reduce_mean', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceSum': [
'reduce_sum', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMin': [
'reduce_min', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMax': [
'reduce_max', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
#active function
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
'ThresholdedRelu': [
'thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'),
dict(alpha=1.)
],
'Tanh': ['tanh', ['X'], ['Out']],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'], dict(
alpha='slope', beta='offset'), dict(
slope=.2, offset=.5)
],
'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
'Sqrt': ['sqrt', ['X'], ['Out']],
'Floor': ['floor', ['X'], ['Out']],
'Abs': ['abs', ['X'], ['Out']],
directly_map_ops = {
'Ceil': ['paddle.ceil'],
# reduce function
'ReduceMean': ['paddle.mean',
dict(axes='axis', keepdims='keepdim'),
dict(keepdims=1)],
'ReduceSum': ['paddle.sum',
dict(axes='axis', keepdims='keepdim'),
dict(keepdims=1)],
'ReduceMin': ['paddle.min',
dict(axes='axis', keepdims='keepdim'),
dict(keepdim=1)],
'ReduceMax': ['paddle.max',
dict(axes='axis', keepdims='keepdim'),
dict(keepdim=1)],
# active function
'Relu': ['paddle.nn.functional.relu'],
'LeakyRelu': ['paddle.nn.functional.leaky_relu',
dict(alpha='negative_slope'),
dict(negative_slope=.01)],
'Elu': ['paddle.nn.functional.elu',
dict(alpha='alpha'),
dict(alpha=1.)],
'ThresholdedRelu': ['paddle.nn.functional.thresholded_relu',
dict(alpha='threshold'),
dict(alpha=1.)],
'Tanh': ['paddle.nn.functional.tanh'],
'Sigmoid': ['paddle.nn.functional.sigmoid'],
'Softsign': ['paddle.nn.functional.softsign'],
'Softplus': ['paddle.nn.functional.softplus',
dict(threshold='threshold'),
dict(threshold=float(sys.maxsize))],
'Exp': ['paddle.exp'],
'Softmax': ['paddle.nn.functional.softmax',
dict(axis='axis'),
dict(axis=1)],
'Sqrt': ['paddle.sqrt'],
'Floor': ['paddle.floor'],
'Abs': ['paddle.abs'],
'Erf': ['paddle.erf'],
}
default_ioa_constraint = {}
def __init__(self, decoder):
def __init__(self, decoder, paddle_graph):
super(OpSet9, self).__init__()
self.graph = decoder.graph
self.input_shapes = []
self.weights = dict()
self.omit_nodes = list()
self.used_custom_layers = dict()
self.paddle_graph = paddle_graph
self.input_index = 0
self.inputs_info = dict()
self.params = dict()
@print_mapping_info
def directly_map(self, node, name='', *args, **kwargs):
def directly_map(self, node, *args, **kwargs):
inputs = node.layer.input
outputs = node.layer.output
op_type = node.layer_type
attrs = node.attr_map
info = self.default_op_mapping[op_type]
info.extend(
list(self.default_op_mapping_field_values.values())[len(info):])
(
fluid_op,
fluid_input_args,
fluid_output_args,
attr_mapping,
default_attrs,
input_perm,
output_perm,
fill_name_field, ) = info
if fluid_op in self.default_ioa_constraint:
for predicate, message in self.default_ioa_constraint[fluid_op]:
assert predicate(inputs, outputs, attrs), message
mapped_attrs = {
attr_mapping.get(key, key): value
for key, value in attrs.items()
}
if '' in mapped_attrs:
mapped_attrs.pop('')
if '_' in mapped_attrs:
mapped_attrs.pop('_')
fluid_attrs = default_attrs.copy()
fluid_attrs.update(mapped_attrs)
inputs = inputs if input_perm is None else list(
map(lambda i: inputs[i], input_perm))
val_inps = []
for idx, ipt in enumerate(inputs):
val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))
val_outs = outputs if output_perm is None else list(
map(lambda i: outputs[i], output_perm))
attr = fluid_attrs
assert len(val_inps) == 1, 'directly_map error with multi inputs'
if fluid_op not in ['shape', 'erf']:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_inps[0], output=val_outs[0], param_attr=attr)
if fluid_op in ['shape']:
node.fluid_code.add_layer(
'cast',
inputs=val_outs[0],
output=val_outs[0],
param_attr={'dtype': string('int64')})
@print_mapping_info
def deal_custom_layer(self, node):
op = node.layer_type
custom_code, func = make_custom_layer(node)
child_func_code, child_func = make_custom_child_func(node)
params = get_params(node.layer, node.layer_type)
arg_names, kwargs = set_args(func, params)
kwargs['name'] = string(node.layer_name)
node.fluid_code.add_layer(
func.__code__.co_name,
inputs=node.inputs,
output=node,
param_attr=kwargs,
is_custom_layer=True)
if op not in self.used_custom_layers:
self.used_custom_layers[op] = custom_code
if op + '_child_func' not in self.used_custom_layers:
if child_func_code is not None:
self.used_custom_layers[op +
'_child_func'] = child_func_code
assert len(inputs) == 1, 'directly_map error with multi inputs'
input = self.graph.get_input_node(node, idx=0, copy=True)
onnx_attrs = node.attr_map
if '' in onnx_attrs:
onnx_attrs.pop('')
if '_' in onnx_attrs:
onnx_attrs.pop('_')
op_info = self.directly_map_ops[node.layer_type]
paddle_op = op_info[0]
layer_attrs = dict()
if len(op_info) > 1:
attrs_name_map_dict = op_info[1]
for onnx_attr_name, pd_attr_name in attrs_name_map_dict.items():
if onnx_attr_name in onnx_attrs:
layer_attrs[pd_attr_name] = onnx_attrs[onnx_attr_name]
else:
layer_attrs[pd_attr_name] = op_info[2][onnx_attr_name]
self.paddle_graph.add_layer(
kernel=paddle_op,
inputs={"x": input.name},
outputs=[node.name],
**layer_attrs)
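Each directly_map_ops entry is read as [paddle kernel, optional ONNX-to-Paddle attribute-name map, optional defaults], and directly_map fills layer_attrs from the node attributes, falling back to the defaults. A plain-Python sketch of that lookup with a toy ReduceSum-style entry (no graph objects involved):

op_info = ['paddle.sum',
           dict(axes='axis', keepdims='keepdim'),   # ONNX attr name -> Paddle kwarg name
           dict(keepdims=1)]                        # default used when the ONNX attr is absent

onnx_attrs = {'axes': [1]}                          # 'keepdims' missing -> default kicks in
layer_attrs = {}
for onnx_name, pd_name in op_info[1].items():
    if onnx_name in onnx_attrs:
        layer_attrs[pd_name] = onnx_attrs[onnx_name]
    else:
        layer_attrs[pd_name] = op_info[2][onnx_name]
print(layer_attrs)                                  # {'axis': [1], 'keepdim': 1}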
@print_mapping_info
def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
inputs = {'x': val_x, 'y': val_y}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=None)
inputs_dict = {'x': val_x.name,
'y': val_y.name}
self.paddle_graph.add_layer(
op_type,
inputs=inputs_dict,
outputs=[node.name])
@print_mapping_info
def place_holder(self, node):
self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape == 0 and i == 0:
shape[i] = 1
if dim_shape == 0 and i != 0:
assert 'shape of input is not assigned'
attr = {
"dtype": string(node.dtype),
"shape": shape,
"name": string(node.layer_name),
"append_batch_size": 'False'
}
node.fluid_code.add_layer(
"data", inputs=None, output=node, param_attr=attr)
self.paddle_graph.add_layer(
kernel="paddle.static.data",
inputs={},
outputs=[node.name],
dtype=string(node.dtype),
shape=shape,
name=string(node.name))
self.inputs_info["x{}".format(self.input_index)] = [shape, node.dtype]
self.input_index += 1
@print_mapping_info
def create_parameter(self, node, parameter=None):
......@@ -269,30 +207,23 @@ class OpSet9():
dtype = node.dtype
shape = node.out_shapes[0]
if len(node.weight.shape) == 0:
shape = [1]
self.weights[node.layer_name] = node.weight
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'default_initializer': 'Constant(0.0)'
}
if dtype == 'bool':
attr['dtype'] = string('int64')
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
node.fluid_code.add_layer(
"cast",
inputs=node,
output=node,
param_attr={'dtype': string('bool')})
elif dtype == 'uint8':
attr['dtype'] = string('float32')
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
outputs=[node.name],
dtype=string(dtype),
shape=[1],
fill_value=node.weight)
else:
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
self.params[node.name] = node.weight
self.paddle_graph.add_layer(
kernel="paddle.static.create_parameter",
inputs={},
outputs=[node.name],
dtype=string(dtype),
shape=shape,
name=string(node.name),
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
def _pad_if_asymmetric(self, node, pads, val_name): # pads: SSEE
assert len(pads) & 1 == 0
......@@ -309,49 +240,89 @@ class OpSet9():
def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'input': val_x}
inputs = {'x': val_x.name}
if node.layer_type == 'Resize':
if len(node.layer.input) == 2:
# opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales
inputs['scale_factor'] = val_scales.name
elif len(node.layer.input) == 3:
# opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale'] = val_scales
inputs['scale_factor'] = val_scales.name
elif len(node.layer.input) == 4:
# opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
var_nc, var_hw = val_sizes.layer_name + '_nc', val_sizes.layer_name + '_hw'
node.fluid_code.add_layer(
'split',
inputs=val_sizes,
output=var_nc + ',' + var_hw,
param_attr={
'dim': 0,
'num_or_sections': [2, 2],
})
node.fluid_code.add_layer(
"cast",
inputs=var_hw,
output=var_hw,
param_attr={'dtype': string('int32')})
var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_sizes.name},
outputs=[var_nc, var_hw],
num_or_sections=[2, 2],
axis=0)
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": var_hw},
outputs=[var_hw],
dtype=string('int32'))
# inputs['size'] = var_hw
# TODO(syf): all use
inputs['out_shape'] = var_hw
ipt = inputs.pop("x")
inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False}
self.paddle_graph.add_layer(
kernel="fluid.layers.resize_nearest",
inputs=inputs,
outputs=[node.name],
**attrs)
return
elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales
attr = {'name': string(node.layer_name)}
mode = node.get_attr('mode', 'nearest')
fluid_op = 'resize_{}'.format(mode)
if 'linear' in mode:
print(
'Warning: paddle does not support op:resize with mode: linear, using bilinear instead'
)
fluid_op = 'resize_bilinear'
attr['align_corners'] = False
node.fluid_code.add_layer(
fluid_op, inputs=inputs, output=node, param_attr=attr)
attrs = {"align_corners": False,
"mode": string(mode),
"align_mode": 1}
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
outputs=[node.name],
**attrs)
@print_mapping_info
def HardSigmoid(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
alpha = node.get_attr('alpha', 0.2)
beta = node.get_attr('beta', 0.5)
self.paddle_graph.add_layer(
kernel="paddle.scale",
inputs={"x": val_x.name},
outputs=[node.name + "_val"],
scale=alpha,
bias=beta)
self.paddle_graph.add_layer(
kernel="paddle.clip",
inputs={"x": node.name + "_val"},
outputs=[node.name],
min=0.0,
max=1.0)
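HardSigmoid is lowered here to a scale followed by a clip, i.e. clip(alpha * x + beta, 0, 1). A quick NumPy check of that equivalence (illustration only):

import numpy as np

def hard_sigmoid(x, alpha=0.2, beta=0.5):
    # same decomposition as above: scale by alpha with bias beta, then clip to [0, 1]
    return np.clip(alpha * x + beta, 0.0, 1.0)

x = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(hard_sigmoid(x))  # [0.  0.3 0.5 0.7 1. ]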
@print_mapping_info
def Shape(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
self.paddle_graph.add_layer(
kernel="paddle.shape",
inputs={"input": val_x.name},
outputs=[node.name])
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": node.name},
outputs=[node.name],
dtype=string('int64'))
@print_mapping_info
def RoiAlign(self, node):
......@@ -362,18 +333,18 @@ class OpSet9():
pooled_width = node.get_attr('output_width')
spatial_scale = node.get_attr('spatial_scale')
sampling_ratio = node.get_attr('sampling_ratio')
attr = {
layer_attrs = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
'sampling_ratio': sampling_ratio,
}
node.fluid_code.add_layer(
'roi_align',
inputs={'input': val_x,
'rois': val_rois},
output=node,
param_attr=attr)
self.paddle_graph.add_layer(
'fluid.layers.roi_align',
inputs={'input': val_x.name,
'rois': val_rois.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def MaxRoiPool(self, node):
......@@ -382,17 +353,17 @@ class OpSet9():
spatial_scale = node.get_attr('spatial_scale')
pooled_height, pooled_width = node.get_attr('pooled_shape')
attr = {
layer_attrs = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
}
node.fluid_code.add_layer(
'roi_pool',
inputs={'input': val_x,
'rois': val_rois},
output=node,
param_attr=attr)
self.paddle_graph.add_layer(
'fluid.layers.roi_pool',
inputs={'input': val_x.name,
'rois': val_rois.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def Pad(self, node, op_independent=True):
......@@ -403,7 +374,8 @@ class OpSet9():
data_shape = val_x.out_shapes[0]
output_shape = node.out_shapes[0]
assume_pad2d = False
attr = {}
layer_attrs = {}
layer_attrs['mode'] = string(mode)
paddings = []
if len(pads) == 4:
assume_pad2d |= mode != 'constant'
......@@ -412,12 +384,12 @@ class OpSet9():
if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
if assume_pad2d:
fluid_op = 'pad2d'
attr['data_format'] = string('NCHW')
attr['mode'] = string(mode)
paddle_op = 'paddle.nn.functional.pad'
layer_attrs['data_format'] = string('NCHW')
layer_attrs['value'] = value
else:
attr = {'pad_value': value}
fluid_op = 'pad'
paddle_op = 'fluid.layers.pad'
layer_attrs["pad_value"] = value
if len(pads) == 4:
paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
......@@ -425,51 +397,52 @@ class OpSet9():
paddings = np.array(pads).reshape(
(-1, 4)).transpose().flatten().tolist() # SSEE -> SESE
if sum(paddings[:4]) == 0:
fluid_op = 'pad2d'
paddle_op = 'paddle.nn.functional.pad'
paddings = paddings[4:]
attr['mode'] = string(mode)
attr['paddings'] = paddings
layer_attrs['value'] = value
if 'pad_value' in layer_attrs:
layer_attrs.pop('pad_value')
tmp_paddings = copy.deepcopy(paddings)
paddings[0] = tmp_paddings[2]
paddings[1] = tmp_paddings[3]
paddings[2] = tmp_paddings[0]
paddings[3] = tmp_paddings[1]
if paddle_op == 'paddle.nn.functional.pad':
layer_attrs['pad'] = paddings
else:
layer_attrs['paddings'] = paddings
if op_independent:
attr['name'] = string(node.layer_name)
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name],
**layer_attrs)
else:
attr['name'] = string(node.layer_name + '_paded')
node.fluid_code.add_layer(
fluid_op,
inputs=val_x,
output=node.layer_name + '_paded',
param_attr=attr)
return node.layer_name + '_paded'
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name + '_paded'],
**layer_attrs)
return node.name + '_paded'
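For two padded axes, ONNX stores pads as begin values followed by end values (SSEE); the code above first interleaves them to SESE order and then swaps the height and width pairs so they match the [left, right, top, bottom] order that paddle.nn.functional.pad expects for NCHW input. A NumPy sketch of that reordering with made-up numbers:

import numpy as np

pads = [1, 2, 3, 4]  # ONNX: [h_begin, w_begin, h_end, w_end]
sese = np.array(pads).reshape((-1, 2)).transpose().flatten().tolist()
# -> [1, 3, 2, 4], i.e. [h_begin, h_end, w_begin, w_end]
reordered = [sese[2], sese[3], sese[0], sese[1]]
# -> [2, 4, 1, 3], i.e. [w_begin, w_end, h_begin, h_end] = [left, right, top, bottom]
print(sese, reordered)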
@print_mapping_info
def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, 'name': string(node.layer_name)}
layer_attrs = {'axis': axes}
if len(val_x.out_shapes[0]) == 0:
if node.layer_name:
node.fluid_code.add_layer(
'reshape',
inputs=val_x,
output=node,
param_attr={'shape': [1]})
if node.name:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1])
else:
if str(val_x.dtype) == 'bool':
val_x_cast = val_x.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_x,
output=val_x_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'unsqueeze',
inputs=val_x_cast,
output=node,
param_attr=attr)
else:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def Shrink(self, node):
......@@ -477,9 +450,11 @@ class OpSet9():
bias = node.get_attr('bias')
lambd = node.get_attr('lambd')
assert bias == 0.0, 'not support bias!=0'
attr = {'threshold': lambd, 'name': node.layer_name}
node.fluid_code.add_layer(
'hard_shrink', inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
'paddle.nn.functional.hardshrink',
inputs={"x": val_x.name},
outputs=[node.name],
threshold=lambd)
@print_mapping_info
def Constant(self, node):
......@@ -500,29 +475,28 @@ class OpSet9():
_logger.warning('in (Constant -> %s): '
'attribute "shape" of %s not inferred, '
'using value as a 1-D tensor may lead to failure',
val_output.layer_name, val_output.layer_name)
val_output.name, val_output.name)
if len(value) == 1:
value = value.tolist()
shape = [1]
value = value[0]
if dtype.name == 'int64':
dtype = 'int32'
attr = {'shape': shape, 'dtype': string(dtype), 'value': value}
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
outputs=[node.name],
dtype=string(dtype),
shape=[1],
fill_value=value)
else:
if dtype.name == 'uint8':
dtype = 'int64'
value = np.reshape(value, shape)
self.weights[node.layer_name] = value
attr = {
'dtype': string(dtype),
'shape': shape,
'name': string(node.layer_name),
'default_initializer': 'Constant(0.0)'
}
node.fluid_code.add_layer(
"create_parameter", inputs=None, output=node, param_attr=attr)
self.params[node.name] = value
self.paddle_graph.add_layer(
kernel="paddle.static.create_parameter",
inputs={},
outputs=[node.name],
dtype=string(dtype),
shape=shape,
name=string(node.name),
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
@print_mapping_info
def Resize(self, node):
......@@ -534,40 +508,57 @@ class OpSet9():
@print_mapping_info
def InstanceNormalization(self, node):
op_name = name_generator("instanse_norm", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scale = self.graph.get_input_node(node, idx=1, copy=True)
val_b = self.graph.get_input_node(node, idx=2, copy=True)
epsilon = node.get_attr('epsilon', 1e-5)
attr = {
layer_attrs = {
'epsilon': epsilon,
'param_attr': string(val_scale.layer_name),
'bias_attr': string(val_b.layer_name)
}
node.fluid_code.add_layer(
"instance_norm", inputs=val_x, output=node, param_attr=attr)
dim = len(val_x.out_shapes[0])
if dim ==2 :
layer_attrs["data_format"] = "NC"
elif dim == 3:
layer_attrs["data_format"] = "NCL"
elif dim == 4:
layer_attrs["data_format"] = "NCHW"
elif dim == 5:
layer_attrs["data_format"] = "NCDHW"
else:
raise Exception("The paddle only support 2D, 3D, 4D or 5D input in InstanceNormalization.")
self.paddle_graph.add_layer(
paddle_op,
inputs={"x": val_x.name,
"weight": val_scale.name,
"bias": val_b.name},
outputs=layer_outputs,
**layer_attrs)
@print_mapping_info
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
name_ones = node.name + '_ones'
attr_ones = {
'shape': val_shape.layer_name,
'shape': val_shape.name,
'dtype': string(val_x_dtype),
'value': 1
'fill_value': 1
}
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=name_ones,
param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x}
node.fluid_code.add_layer(
'elementwise_mul',
inputs=inputs,
output=node.layer_name,
param_attr=None)
self.paddle_graph.add_layer(
'paddle.full',
inputs={},
outputs=[name_ones],
**attr_ones)
inputs_dict = {'x': name_ones,
'y': val_x.name}
self.paddle_graph.add_layer(
'paddle.multiply',
inputs=inputs_dict,
outputs=[node.name])
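Expand is emulated above by filling a tensor of ones with the target shape and multiplying, so broadcasting stretches val_x to that shape. The same trick in NumPy (illustration only):

import numpy as np

x = np.array([[1.0], [2.0], [3.0]])                   # shape (3, 1)
target_shape = (3, 4)
expanded = np.ones(target_shape, dtype=x.dtype) * x   # broadcasting performs the expand
print(expanded.shape)                                 # (3, 4)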
@print_mapping_info
def Gather(self, node):
......@@ -579,147 +570,140 @@ class OpSet9():
# indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
if len(val_x.out_shapes[0]) <= 1:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0:
gather_ = node.layer_name + '_1'
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=gather_,
param_attr=None)
node.fluid_code.add_layer(
'squeeze',
inputs={'input': gather_,
'axes': [0]},
output=node,
param_attr=None)
gather_ = node.name + '_1'
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[gather_])
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': gather_},
outputs=[node.name],
axis=[0])
else:
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
'index': indices},
output=node,
param_attr=None)
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer(
'gather',
inputs={'input': name_trans,
'index': indices},
output=node,
param_attr=None)
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
name_trans = val_x.name + '_trans'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": val_x.name},
outputs=[name_trans],
perm=perm)
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': name_trans,
'index': indices.name},
outputs=[node.name])
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[node.name],
perm=perm)
if len(indices_shape) < 1:
node.fluid_code.add_layer(
'squeeze',
inputs={'input': node,
'axes': [axis]},
output=node,
param_attr=None)
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': node.name},
outputs=[node.name],
axis=[axis])
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
indices_cast = indices.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=indices,
output=indices_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'embedding',
inputs=indices_cast,
output=node,
use_fluid=True,
param_attr={
'param_attr': string(val_x.layer_name),
'size': val_x.out_shapes[0]
})
indices_cast = indices.name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": indices.name},
outputs=[indices_cast],
dtype=string('int64'))
op_name = name_generator("embedding", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
self.paddle_graph.add_layer(
'paddle.nn.Embedding',
inputs={"x": indices_cast},
outputs=layer_outputs,
param_attr=string(val_x.name),
size=val_x.out_shapes[0])
else:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer(
'reshape',
inputs=indices,
output=indices_reshape,
param_attr={'shape': [reshape_shape, ]})
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices_reshape],
shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0])))
node.fluid_code.add_layer(
'gather',
inputs={'input': val_x,
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices_reshape},
output=node,
param_attr=None)
outputs=[node.name])
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=node,
output=node,
param_attr={'shape': reshaped_shape})
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": node.name},
outputs=[node.name],
shape=reshaped_shape)
elif axis > 0 and len(indices_shape) > 1:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer(
'reshape',
inputs=indices,
output=indices_reshape,
param_attr={'shape': [reshape_shape, ]})
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices_reshape],
shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_transpose'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer(
'gather',
inputs={'input': name_trans,
name_trans = val_x.name + '_transpose'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": val_x.name},
outputs=[name_trans],
perm=perm)
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': name_trans,
'index': indices_reshape},
output=node,
param_attr=None)
input_transpose = node.layer_name + '_transpose'
node.fluid_code.add_layer(
'transpose',
inputs=node,
output=input_transpose,
param_attr=attr_trans)
outputs=[node.name])
input_transpose = node.name + '_transpose'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[input_transpose],
perm=perm)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=input_transpose,
output=node,
param_attr={'shape': reshaped_shape})
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": input_transpose},
outputs=[node.name],
shape=reshaped_shape)
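For axis > 0 the mapping above moves the gathered axis to the front with a transpose, gathers along axis 0, and then restores the original layout (reshaping when the indices are multi-dimensional). The idea in NumPy for the 1-D-indices case (an illustration that computes the inverse permutation explicitly, not the generated code):

import numpy as np

def gather_along_axis(x, indices, axis):
    perm = list(range(x.ndim))
    perm = [axis] + perm[:axis] + perm[axis + 1:]   # bring `axis` to the front
    gathered = np.transpose(x, perm)[indices]       # gather along the leading axis
    inv_perm = list(np.argsort(perm))               # undo the permutation
    return np.transpose(gathered, inv_perm)

x = np.arange(24).reshape(2, 3, 4)
idx = np.array([2, 0])
assert np.array_equal(gather_along_axis(x, idx, axis=1), np.take(x, idx, axis=1))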
@print_mapping_info
def ScatterND(self, node):
......@@ -727,85 +711,78 @@ class OpSet9():
indices = self.graph.get_input_node(node, idx=1, copy=True)
updates = self.graph.get_input_node(node, idx=2, copy=True)
if len(indices.out_shapes[0]) == 1:
node.fluid_code.add_layer(
'scatter',
inputs={'input': val_x,
'index': indices,
'updates': updates},
output=node,
param_attr=None)
self.paddle_graph.add_layer(
'paddle.scatter',
inputs={'x': val_x.name,
'index': indices.name,
'updates': updates.name},
outputs=[node.name])
else:
input_inner_indices = node.layer_name + '_input_inner_indices'
input_inner_indices = node.name + '_input_inner_indices'
shape = val_x.out_shapes[0]
node.fluid_code.add_layer(
'reshape',
inputs=indices.layer_name,
output=indices.layer_name,
param_attr={'shape': indices.out_shapes[0]})
zeros_like_val_x = val_x.layer_name + '_zeros'
node.fluid_code.add_layer(
'zeros_like',
inputs=val_x,
output=zeros_like_val_x,
param_attr=None)
node.fluid_code.add_layer(
'scatter_nd_add',
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices.name],
shape=indices.out_shapes[0])
zeros_like_val_x = val_x.name + '_zeros'
self.paddle_graph.add_layer(
'paddle.zeros_like',
inputs={"x": val_x.name},
outputs=[zeros_like_val_x])
self.paddle_graph.add_layer(
'paddle.scatter_nd_add',
inputs={
'ref': zeros_like_val_x,
'index': indices,
'updates': updates
'x': zeros_like_val_x,
'index': indices.name,
'updates': updates.name
},
output=input_inner_indices,
param_attr=None)
indices_mask = node.layer_name + '_indices_mask'
constant_minus_one = node.layer_name + '_constant_minus_one'
outputs=[input_inner_indices])
indices_mask = node.name + '_indices_mask'
constant_minus_one = node.name + '_constant_minus_one'
# full_like supports creating a tensor with the same shape as the input tensor
node.fluid_code.add_layer(
'full_like',
inputs=updates,
output=constant_minus_one,
param_attr={'dtype': string(updates.dtype),
'fill_value': -1})
node.fluid_code.add_layer(
'scatter_nd_add',
self.paddle_graph.add_layer(
'paddle.full_like',
inputs={"x": updates.name},
outputs=[constant_minus_one],
dtype=string(updates.dtype),
fill_value=-1)
self.paddle_graph.add_layer(
'paddle.scatter_nd_add',
inputs={
'ref': zeros_like_val_x,
'index': indices,
'x': zeros_like_val_x,
'index': indices.name,
'updates': constant_minus_one
},
output=indices_mask,
param_attr=None)
constant_one = node.layer_name + '_constant_1'
outputs=[indices_mask])
constant_one = node.name + '_constant_1'
# full_like supports creating a tensor with the same shape as the input tensor
node.fluid_code.add_layer(
'full_like',
inputs=val_x,
output=constant_one,
param_attr={'dtype': string(val_x.dtype),
'fill_value': 1})
input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
node.fluid_code.add_layer(
"elementwise_add",
self.paddle_graph.add_layer(
'paddle.full_like',
inputs={"x": val_x.name},
outputs=[constant_one],
dtype=string(val_x.dtype),
fill_value=1)
input_out_indices_mask = node.name + '_input_out_indices_mask'
self.paddle_graph.add_layer(
"paddle.add",
inputs={"x": indices_mask,
"y": constant_one},
output=input_out_indices_mask,
param_attr=None)
outputs=[input_out_indices_mask])
input_out_indices = node.layer_name + '_input_out_indices'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={"x": val_x,
input_out_indices = node.name + '_input_out_indices'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={"x": val_x.name,
"y": input_out_indices_mask},
output=input_out_indices,
param_attr=None)
outputs=[input_out_indices])
node.fluid_code.add_layer(
"elementwise_add",
self.paddle_graph.add_layer(
"paddle.add",
inputs={"x": input_inner_indices,
"y": input_out_indices},
output=node,
param_attr=None)
outputs=[node.name])
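ScatterND with multi-dimensional indices is decomposed above into two scatter_nd_add calls: one writes the updates into a zero tensor, the other builds a 0/1 keep-mask by scattering -1 and adding 1, and the result is updates_part + x * mask. A NumPy sketch of the same idea (illustration only; assumes each index appears once):

import numpy as np

def scatter_nd_add(ref, index, upd):
    out = ref.copy()
    for i, u in zip(index, upd):
        out[tuple(i)] += u
    return out

x = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([[1], [3]])        # positions to overwrite
updates = np.array([10.0, 40.0])

inner = scatter_nd_add(np.zeros_like(x), indices, updates)            # [0, 10, 0, 40]
mask = scatter_nd_add(np.zeros_like(x), indices,
                      np.full_like(updates, -1)) + 1                  # [1, 0, 1, 0]
print(inner + x * mask)                                               # [1, 10, 3, 40]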
@print_mapping_info
def Range(self, node):
......@@ -813,18 +790,20 @@ class OpSet9():
val_limit = self.graph.get_input_node(node, idx=1, copy=True)
val_delta = self.graph.get_input_node(node, idx=2, copy=True)
dtype = val_start.dtype
inputs = {'start': val_start, 'end': val_limit, 'step': val_delta}
node.fluid_code.add_layer(
'range',
inputs = {'start': val_start.name,
'end': val_limit.name,
'step': val_delta.name}
self.paddle_graph.add_layer(
'paddle.arange',
inputs=inputs,
output=node,
param_attr={'dtype': string(dtype)})
outputs=[node.name],
dtype=string(dtype))
@print_mapping_info
def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
starts, ends, axes, steps = None, None, None, None
attr = {}
layer_attrs = {}
if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True)
......@@ -837,14 +816,12 @@ class OpSet9():
if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True)
steps = _const_weight_or_none(steps)
attr = {
layer_attrs = {
"axes": axes,
"starts": starts.layer_name,
"ends": ends.layer_name
"starts": starts.name,
"ends": ends.name
}
if starts_value is not None and ends_value is not None:
self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name)
starts_value = starts_value.copy()
ends_value = ends_value.copy()
#for idx in range(len(ends_value)):
......@@ -858,28 +835,28 @@ class OpSet9():
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1
attr = {
layer_attrs = {
"axes": axes,
"starts": starts_value,
"ends": ends_value
}
else:
if starts.dtype != 'int32':
starts_cast = starts.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=starts,
output=starts_cast,
param_attr={'dtype': string('int32')})
attr['starts'] = starts_cast
starts_cast = starts.name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": starts.name},
outputs=[starts_cast],
dtype=string('int32'))
layer_attrs['starts'] = starts_cast
if ends.dtype != 'int32':
ends_cast = ends.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=ends,
output=ends_cast,
param_attr={'dtype': string('int32')})
attr['ends'] = ends_cast
ends_cast = ends.name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": ends.name},
outputs=[ends_cast],
dtype=string('int32'))
layer_attrs['ends'] = ends_cast
else:
starts = node.get_attr('starts')
ends = node.get_attr('ends')
......@@ -887,15 +864,21 @@ class OpSet9():
for idx in range(len(ends)):
if ends[idx] > 2**31 - 1:
ends[idx] = 2**31 - 1
attr = {"axes": axes, "starts": starts, "ends": ends}
layer_attrs = {"axes": axes, "starts": starts, "ends": ends}
if steps is not None:
attr['strides'] = steps
node.fluid_code.add_layer(
'strided_slice', inputs=val_x, output=node, param_attr=attr)
layer_attrs['strides'] = steps
self.paddle_graph.add_layer(
'paddle.strided_slice',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else:
node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
'paddle.slice',
inputs={"input": val_x.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def ConstantOfShape(self, node):
......@@ -909,13 +892,16 @@ class OpSet9():
'this is not supported')
if len(value) == 1:
value = value[0]
attr = {
'shape': val_shape.layer_name,
layer_attrs = {
'shape': val_shape.name,
'dtype': string(dtype),
'value': value
'fill_value': value
}
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def Clip(self, node):
......@@ -925,104 +911,90 @@ class OpSet9():
if len(node.inputs) == 1:
max_value = node.get_attr('max')
min_value = node.get_attr('min')
attr = {
layer_attrs = {
'max': max_value,
'min': min_value,
}
node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
'paddle.clip',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_value = _const_weight_or_none(min_ipt)
self.omit_nodes.append(max_ipt.layer_name)
self.omit_nodes.append(min_ipt.layer_name)
if max_value.shape == (1, ):
max_value = max_value[0]
if min_value.shape == (1, ):
min_value = min_value[0]
if max_value is not None and min_value is not None:
attr = {'max': max_value, 'min': min_value}
node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr)
layer_attrs = {'max': max_value, 'min': min_value}
self.paddle_graph.add_layer(
'paddle.clip',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else:
raise
@print_mapping_info
def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
fluid_op = 'split'
paddle_op = 'split'
split = node.get_attr('split')
axis = node.get_attr('axis', 0)
attr = {
layer_attrs = {
'num_or_sections': split,
'dim': axis,
'name': string(node.layer_name)
'axis': axis,
}
node.fluid_code.add_layer(
'split', inputs=val_x, output=val_y, param_attr=attr)
outputs_list = list()
if isinstance(split, list) or isinstance(split, tuple):
for i in range(len(split)):
outputs_list.append("{}_p{}".format(node.layer_name, i))
else:
outputs_list.append(node.name)
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
outputs=outputs_list,
**layer_attrs)
@print_mapping_info
def Reshape(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
attr = {}
shape_value = _const_weight_or_none(val_shape)
shape_dims = len(val_shape.out_shapes[0])
if shape_value is not None:
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name},
outputs=[node.name],
shape=shape_value.tolist())
elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]):
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': node.out_shapes[0]},
output=node,
param_attr=attr)
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
node.fluid_code.add_layer(
'reshape',
inputs=val_shape_cast,
output=val_shape_cast,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': val_shape_cast},
output=node,
param_attr=attr)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name},
outputs=[node.name],
shape=node.out_shapes[0])
else:
# shape may be [], coming from Gather with scalar indices
if len(val_shape.out_shapes[0]) > 0:
node.fluid_code.add_layer(
'reshape',
inputs=val_shape,
output=val_shape,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': val_shape},
output=node,
param_attr=attr)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_shape.name},
outputs=[val_shape.name],
shape=val_shape.out_shapes[0])
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name,
'shape': val_shape.name},
outputs=[node.name])
@print_mapping_info
def Cast(self, node):
......@@ -1036,14 +1008,18 @@ class OpSet9():
output_dtype = val_output.dtype
if output_dtype:
assert dtype == output_dtype, 'dtype of to unmatches output'
attr = {'dtype': string(dtype)}
node.fluid_code.add_layer(
'cast', inputs=val_input, output=node, param_attr=attr)
self.paddle_graph.add_layer(
'paddle.cast',
inputs={'x': val_input.name},
outputs=[node.name],
dtype=string(dtype))
@print_mapping_info
def Not(self, node):
val_input = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer('logical_not', inputs=val_input, output=node)
self.paddle_graph.add_layer('paddle.logical_not',
inputs={'x': val_input.name},
outputs=[node.name])
@print_mapping_info
def AveragePool(self, node):
......@@ -1056,8 +1032,6 @@ class OpSet9():
pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0))
pads = node.get_attr('pads', [0] * (poolnd * 2))
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
......@@ -1069,44 +1043,60 @@ class OpSet9():
strides[1])
paddings = pad_h + pad_w
attr = {
paddle_op = 'fluid.layers.pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d are supported'
layer_attrs = {
"pool_size": kernel_shape,
"pool_type": string('avg'),
"pool_stride": strides,
"pool_padding": paddings,
"ceil_mode": ceil_mode,
"exclusive": 'True',
"name": string(node.layer_name)
"name": string(node.name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
paddle_op,
inputs={'input': val_x if isinstance(val_x, str) else val_x.name},
outputs=[node.name],
**layer_attrs)
# TODO(syf): op has diff
@print_mapping_info
def Concat(self, node):
inputs = []
inputs_list = []
dtypes = set()
for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True)
if isinstance(ipt, str):
inputs.append(ipt)
else:
inputs.append(ipt.layer_name)
dtypes.add(ipt.dtype)
inputs_list.append(ipt.name)
dtypes.add(ipt.dtype)
if len(dtypes) > 1:
assert 'Unsupported situation happened, please create an issue on https://github.com/PaddlePaddle/X2Paddle/issues.'
axis = node.get_attr('axis')
attr = {'axis': axis}
node.fluid_code.add_layer(
'concat', inputs=inputs, output=node, param_attr=attr)
self.paddle_graph.add_layer(
'paddle.concat',
inputs={"x": inputs_list},
outputs=[node.name],
axis=axis)
@print_mapping_info
def Flatten(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
output_shape = node.out_shapes[0]
axis = node.get_attr('axis', 1)
attr = {"axis": str(axis), "name": string(node.layer_name)}
node.fluid_code.add_layer(
'flatten', inputs=val_x, output=node, param_attr=attr)
shape_list = [1, 1]
if axis == 0:
for s in output_shape:
shape_list[1] *= s
else:
for s in output_shape[:axis]:
shape_list[0] *= s
for s in output_shape[axis:]:
shape_list[1] *= s
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
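ONNX Flatten collapses every dimension before `axis` into the first output dimension and the rest into the second, which is what the shape_list computation above produces before the reshape. A small worked example with a hypothetical input shape:

from functools import reduce
from operator import mul

in_shape = [2, 3, 4, 5]   # hypothetical 4-D input
axis = 2
flat_shape = [reduce(mul, in_shape[:axis], 1), reduce(mul, in_shape[axis:], 1)]
print(flat_shape)         # [6, 20]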
@print_mapping_info
def Gemm(self, node):
......@@ -1118,65 +1108,68 @@ class OpSet9():
beta = node.get_attr('beta', 1.) # optional
trans_a = bool(node.get_attr('transA', 0)) # optional
trans_b = bool(node.get_attr('transB', 0)) # optional
val_mm = node.layer_name + '_mm'
matmul_inputs = {"x": val_a, "y": val_b}
val_mm = node.name + '_mm'
matmul_inputs = {"x": val_a.name,
"y": val_b.name}
attr_matmul = {
"transpose_x": trans_a,
"transpose_y": trans_b,
"alpha": alpha,
"name": string(val_mm)
}
node.fluid_code.add_layer(
'matmul',
self.paddle_graph.add_layer(
'paddle.matmul',
inputs=matmul_inputs,
output=val_mm,
param_attr=attr_matmul)
outputs=[val_mm],
**attr_matmul)
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_mm},
outputs=[val_mm],
scale=alpha)
if beta != 0:
if beta == 1.:
add_inputs = {"x": val_mm, "y": val_c}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"elementwise_add",
add_inputs = {"x": val_mm,
"y": val_c.name}
self.paddle_graph.add_layer(
"paddle.add",
inputs=add_inputs,
output=node,
param_attr=attr)
outputs=[node.name])
else:
var_beta = node.layer_name + '_beta'
matmul_beta_inputs = {"x": val_c, "y": var_beta}
node.fluid_code.add_layer(
"Constant",
inputs=matmul_beta_inputs,
output=var_beta,
param_attr={'value': beta})
var_beta = node.name + '_beta'
self.paddle_graph.add_layer(
"paddle.scale",
inputs={"x": val_c.name},
outputs=[var_beta],
scale=beta)
add_inputs = {"x": val_mm, "y": var_beta}
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"elementwise_add",
self.paddle_graph.add_layer(
"paddle.add",
inputs=add_inputs,
output=node,
param_attr=attr)
outputs=[node.name])
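ONNX Gemm computes alpha * A' @ B' + beta * C (with optional transposes), and the mapping above emits a matmul, a scale by alpha, a scale of C by beta, and an add. A NumPy check of that decomposition (illustration only):

import numpy as np

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
C = np.ones((2, 4))
alpha, beta = 0.5, 2.0

reference = alpha * (A @ B) + beta * C
decomposed = (A @ B) * alpha        # matmul followed by a scale with factor alpha
decomposed = decomposed + C * beta  # scale C by beta, then add
assert np.allclose(reference, decomposed)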
@print_mapping_info
def Sum(self, node):
val_inps = node.layer.input
inputs = {
inputs_dict = {
"x": self.graph.get_input_node(
node, idx=0, copy=True),
node, idx=0, copy=True).name,
"y": self.graph.get_input_node(
node, idx=1, copy=True),
node, idx=1, copy=True).name,
}
node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)
self.paddle_graph.add_layer("paddle.add",
inputs=inputs_dict,
outputs=[node.name])
for idx, ipt in enumerate(val_inps[2:]):
y = self.graph.get_input_node(node, idx=idx, copy=True)
inputs = {
"x": node.layer_name,
"y": y,
inputs_dict = {
"x": node.name,
"y": y.name,
}
node.fluid_code.add_layer(
"elementwise_add", inputs=inputs, output=node)
self.paddle_graph.add_layer(
"paddle.add",
inputs=inputs_dict,
outputs=[node.name])
@print_mapping_info
def MatMul(self, node):
......@@ -1184,21 +1177,26 @@ class OpSet9():
val_y = self.graph.get_input_node(node, idx=1, copy=True)
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
inputs_dict = {"x": val_x.name,
"y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
inputs=val_y,
output=y_squeeze,
param_attr={'axes': [0]})
inputs['y'] = y_squeeze
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
y_squeeze = val_y.name + '_squeeze'
self.paddle_graph.add_layer(
"paddle.squeeze",
inputs={"x": val_y.name},
outputs=[y_squeeze],
axis=[0])
inputs_dict['y'] = y_squeeze
self.paddle_graph.add_layer(
"paddle.matmul",
inputs=inputs_dict,
outputs=[node.name])
else:
node.fluid_code.add_layer(
"matmul", inputs=inputs, output=node, param_attr=None)
self.paddle_graph.add_layer(
"paddle.matmul",
inputs=inputs_dict,
outputs=[node.name])
@print_mapping_info
def BatchNormalization(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......@@ -1207,108 +1205,98 @@ class OpSet9():
val_mean = self.graph.get_input_node(node, idx=3, copy=True)
val_var = self.graph.get_input_node(node, idx=4, copy=True)
self.omit_nodes.append(val_scale.layer_name)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_mean.layer_name)
self.omit_nodes.append(val_var.layer_name)
momentum = node.get_attr('momentum', .9)
epsilon = node.get_attr('epsilon', 1e-5)
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial'))
attr = {
layer_attrs = {
"momentum": momentum,
"epsilon": epsilon,
"data_layout": string('NCHW'),
"is_test": True,
"param_attr": string(val_scale.layer_name),
"bias_attr": string(val_b.layer_name),
"moving_mean_name": string(val_mean.layer_name),
"moving_variance_name": string(val_var.layer_name),
"use_global_stats": spatial,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
"batch_norm", inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.nn.functional.batch_norm",
inputs={"x": val_x.name,
"weight": val_scale.name,
"bias": val_b.name,
"running_mean": val_mean.name,
"running_var": val_var.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def Transpose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
perm = node.get_attr('perm')
attr = {'perm': perm, "name": string(node.layer_name)}
node.fluid_code.add_layer(
"transpose", inputs=val_x, output=node, param_attr=attr)
@print_mapping_info
def Relu(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
attr = {"name": string(node.layer_name)}
node.fluid_code.add_layer(
"relu", inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={"x": val_x.name},
outputs=[node.name],
perm=perm)
@print_mapping_info
def PRelu(self, node):
op_name = name_generator("prelu", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True)
mode = 'channel'
shape_slope = val_slope.out_shapes[0]
if shape_slope == [1]:
mode = 'all'
elif len(shape_slope) > 2:
mode = 'element'
raise Exception("The 'element' mode is not supported yet!")
if mode == 'channel' and len(shape_slope) == 1:
                # Paddle expects the slope parameter shape to be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.weights[val_slope.layer_name] = slope_data
self.omit_nodes.append(val_slope.layer_name)
attr = {
"param_attr": string(val_slope.layer_name),
'mode': string(mode)
}
node.fluid_code.add_layer(
"prelu", inputs=val_x, output=node, param_attr=attr)
self.params[val_slope.name] = slope_data
self.paddle_graph.add_layer(
"paddle.nn.functional.prelu",
inputs={"x": val_x.name,
"weight": val_slope.name},
outputs=[node.name])
@print_mapping_info
def Squeeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, "name": string(node.layer_name)}
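        # A rank-1 input is passed through with a dtype-preserving cast (an identity copy)
        # instead of being squeezed.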
if len(val_x.out_shapes[0]) == 1:
node.fluid_code.add_layer(
"cast",
inputs=val_x,
output=node,
param_attr={'dtype': string(val_x.dtype)})
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": val_x.name},
outputs=[node.name],
dtype=string(val_x.dtype))
else:
node.fluid_code.add_layer(
"squeeze", inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.squeeze",
inputs={"x": val_x.name},
outputs=[node.name],
axis=axes)
@print_mapping_info
def Equal(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
"equal",
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
self.paddle_graph.add_layer(
"paddle.equal",
inputs={'x': val_x.name,
'y': val_y.name},
outputs=[node.name])
@print_mapping_info
def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
"greater_than",
inputs={'x': val_x,
'y': val_y},
output=node,
self.paddle_graph.add_layer(
"paddle.greater_than",
inputs={'x': val_x.name,
'y': val_y.name},
            outputs=[node.name])
@print_mapping_info
......@@ -1317,72 +1305,80 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.layer_name + '_not'
node.fluid_code.add_layer(
"logical_not",
inputs=condition,
output=not_condition,
param_attr=None)
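        # Where(condition, x, y) is emulated without a select op: the condition and its
        # logical negation are cast to x's dtype and used as multiplicative masks, then the
        # two masked branches are summed.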
not_condition = condition.name + '_not'
self.paddle_graph.add_layer(
"paddle.logical_not",
inputs={"x": condition.name},
outputs=[not_condition])
cast_not_condition = not_condition + '_cast'
node.fluid_code.add_layer(
"cast",
inputs=not_condition,
output=cast_not_condition,
param_attr={'dtype': string(val_x.dtype)})
cast_condition = condition.layer_name + '_cast'
node.fluid_code.add_layer(
"cast",
inputs=condition,
output=cast_condition,
param_attr={'dtype': string(val_x.dtype)})
mul_val_x = val_x.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={'x': val_x,
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": not_condition},
outputs=[cast_not_condition],
dtype=string(val_x.dtype))
cast_condition = condition.name + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": condition.name},
outputs=[cast_condition],
dtype=string(val_x.dtype))
mul_val_x = val_x.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_x.name,
'y': cast_condition},
output=mul_val_x,
param_attr=None)
mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
inputs={'x': val_y,
outputs=[mul_val_x])
mul_val_y = val_y.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_y.name,
'y': cast_not_condition},
output=mul_val_y,
param_attr=None)
outputs=[mul_val_y])
node.fluid_code.add_layer(
"elementwise_add",
self.paddle_graph.add_layer(
"paddle.add",
inputs={'x': mul_val_x,
'y': mul_val_y},
output=node,
param_attr=None)
outputs=[node.name])
@print_mapping_info
def NonZero(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_x_dim = len(val_x.out_shapes[0])
if val_x_dim == 1:
node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x)
node.fluid_code.add_layer(
"transpose",
inputs=val_x,
output=node,
param_attr={'perm': [1, 0]})
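        # ONNX NonZero returns indices with shape [rank, N] while paddle.nonzero yields
        # [N, rank], hence the transpose (1-D case) or split-and-concat (N-D case) below.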
self.paddle_graph.add_layer(
"paddle.nonzero",
inputs={"x": val_x.name},
outputs=[val_x.name])
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={"x": val_x.name},
            outputs=[node.name],
perm=[1, 0])
if val_x_dim > 1:
node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x)
node.fluid_code.add_layer(
"split",
inputs=val_x,
output=val_x,
param_attr={'num_or_sections': 1,
'dim': val_x_dim})
node.fluid_code.add_layer("concat", inputs=val_x, output=node)
self.paddle_graph.add_layer(
"paddle.nonzero",
inputs={"x": val_x.name},
outputs=[val_x.name])
self.paddle_graph.add_layer(
"paddle.split",
inputs={"x": val_x.name},
outputs=[val_x.name],
num_or_sections=1,
axis=val_x_dim)
self.paddle_graph.add_layer(
"paddle.concat",
inputs={"x": val_x.name},
outputs=[node.name])
@print_mapping_info
def Identity(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer("assign", inputs=val_x, output=node)
self.paddle_graph.add_layer(
"paddle.assign",
inputs={"x": val_x.name},
outputs=[node.name])
@print_mapping_info
def Tile(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......@@ -1390,14 +1386,13 @@ class OpSet9():
repeats = _const_weight_or_none(val_repeats)
if repeats is None:
repeats = val_repeats.layer_name
repeats = val_repeats.name
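        # paddle.tile takes repeat_times as plain ints or an int32 tensor, so a non-int32
        # repeats tensor is cast first.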
if val_repeats.dtype != 'int32':
attr = {"dtype": string("int32")}
node.fluid_code.add_layer(
"cast",
inputs=repeats,
output="{}.tmp".format(repeats),
param_attr=attr)
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": repeats},
outputs=["{}.tmp".format(repeats)],
dtype=string("int32"))
repeats = "{}.tmp".format(repeats)
elif isinstance(repeats, int):
......@@ -1405,10 +1400,13 @@ class OpSet9():
attr = {
'expand_times': repeats,
"name": string(node.layer_name),
"name": string(node.name),
}
node.fluid_code.add_layer(
"expand", inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
"paddle.tile",
inputs={"x": val_x.name},
outputs=[node.name],
repeat_times=repeats)
@print_mapping_info
def MaxPool(self, node):
......@@ -1423,8 +1421,8 @@ class OpSet9():
pad_mode = node.get_attr("pads")
ceil_mode = bool(node.get_attr('ceil_mode', 0)) # optional
pads = node.get_attr('pads', [0] * (poolnd * 2)) # optional
fluid_op = 'pool{}d'.format(poolnd)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
paddle_op = 'paddle.nn.functional.max_pool{}d'.format(poolnd)
assert 1 <= poolnd <= 3, 'only max_pool1d, max_pool2d and max_pool3d are supported'
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
......@@ -1435,64 +1433,72 @@ class OpSet9():
pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1])
paddings = pad_h + pad_w
attr = {
"pool_size": kernel_shape,
"pool_type": string("max"),
"pool_stride": strides,
"pool_padding": paddings,
layer_attrs = {
"kernel_size": kernel_shape,
"stride": strides,
"padding": paddings,
"ceil_mode": ceil_mode,
"name": string(node.layer_name),
"exclusive": False
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
def _global_pool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
fluid_op = 'pool2d'
pool_type = None
if node.layer.op_type == 'GlobalMaxPool':
pool_type = 'max'
elif node.layer.op_type == 'GlobalAveragePool':
pool_type = 'avg'
attr = {
"pool_type": string(pool_type),
"global_pooling": True,
"name": string(node.layer_name)
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x if isinstance(val_x, str) else val_x.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def GlobalMaxPool(self, node):
self._global_pool(node)
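        # Global pooling is mapped to the adaptive pooling functional API, with output_size
        # taken from the spatial part of the inferred output shape.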
val_x = self.graph.get_input_node(node, idx=0, copy=True)
input_shape = val_x.out_shapes[0]
if len(input_shape) == 4:
poolnd = 2
elif len(input_shape) == 5:
poolnd = 3
elif len(input_shape) == 3:
poolnd = 1
paddle_op = 'paddle.nn.functional.adaptive_max_pool{}d'.format(poolnd)
assert 1 <= poolnd <= 3, 'only adaptive_max_pool1d, adaptive_max_pool2d and adaptive_max_pool3d are supported'
output_shape = node.out_shapes[0]
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name],
output_size=output_shape[2:])
@print_mapping_info
def GlobalAveragePool(self, node):
self._global_pool(node)
val_x = self.graph.get_input_node(node, idx=0, copy=True)
input_shape = val_x.out_shapes[0]
if len(input_shape) == 4:
poolnd = 2
elif len(input_shape) == 5:
poolnd = 3
elif len(input_shape) == 3:
poolnd = 1
paddle_op = 'paddle.nn.functional.adaptive_avg_pool{}d'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only adaptive_avg_pool1d, adaptive_avg_pool2d and adaptive_avg_pool3d are supported'
output_shape = node.out_shapes[0]
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name],
output_size=output_shape[2:])
@print_mapping_info
def Conv(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
self.omit_nodes.append(val_w.layer_name)
has_bias = len(node.layer.input) == 3
if has_bias:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr('kernel_shape')
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
num_out_channels = val_w.out_shapes[0][0]
fluid_op = 'conv{}d'.format(convnd)
num_in_channels = val_w.out_shapes[0][1]
paddle_op = 'paddle.nn.functional.conv{}d'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd)
......@@ -1509,22 +1515,23 @@ class OpSet9():
strides[1])
paddings = pad_h + pad_w
attr = {
"num_filters": num_out_channels,
"filter_size": kernel_shape,
layer_attrs = {
"stride": strides,
"padding": paddings,
"dilation": dilations,
"groups": num_groups,
'param_attr': string(val_w.layer_name),
"name": string(node.layer_name)
}
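        # The weight (and optional bias) initializers are wired in by name as the
        # weight/bias inputs of the functional conv call instead of ParamAttr attributes.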
layer_inputs = {
"x": val_x.name,
"weight": val_w.name
}
if has_bias:
attr["bias_attr"] = string(val_b.layer_name)
else:
attr["bias_attr"] = False
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
layer_inputs["bias"] = val_b.name
self.paddle_graph.add_layer(
paddle_op,
inputs=layer_inputs,
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def ConvTranspose(self, node):
......@@ -1533,19 +1540,15 @@ class OpSet9():
val_b = None
if len(node.layer.input) > 2:
val_b = self.graph.get_input_node(node, idx=2, copy=True)
self.omit_nodes.append(val_b.layer_name)
self.omit_nodes.append(val_w.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
out_padding = node.get_attr('output_padding', [0, 0])
kernel_shape = node.get_attr('kernel_shape')
assert kernel_shape, 'kernel_shape not inferred'
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_in_channels = val_w.out_shapes[0][0]
num_out_channels = val_w.out_shapes[0][1]
fluid_op = 'conv{}d_transpose'.format(convnd)
paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
num_groups = node.get_attr('group', 1)
strides = node.get_attr('strides', [1] * convnd)
......@@ -1563,17 +1566,18 @@ class OpSet9():
output_size[1] = (val_x.out_shapes[0][3] - 1
) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1]
attr = {
'num_filters': num_out_channels,
'output_size': output_size or None,
'filter_size': kernel_shape,
'padding': paddings,
'stride': strides,
'dilation': dilations,
'groups': num_groups,
'param_attr': string(val_w.layer_name),
'bias_attr': None if val_b is None else string(val_b.layer_name),
'name': string(node.layer_name),
}
node.fluid_code.add_layer(
fluid_op, inputs=val_x, output=node, param_attr=attr)
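        # The spatial output size passed to the transposed convolution is taken from the
        # inferred output shape; the formula-based output_size computed above is not used here.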
layer_inputs = {'x': val_x.name,
"weight": val_w.name}
layer_attrs = {
"stride": strides,
"dilation": dilations,
"padding": paddings,
"groups": num_groups,
"output_size": node.out_shapes[0][2:]}
if val_b is not None:
layer_inputs["bias"] = val_b.name
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.conv2d_transpose",
inputs=layer_inputs,
outputs=[node.name],
**layer_attrs)
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO useless node remove
class ONNXOptimizer(object):
def __init__(self, op_mapper):
self.op_mapper = op_mapper
self.graph = op_mapper.graph
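    # Drops the generated code of nodes the op mapper marked as omitted (typically weight
    # and bias initializers folded into other layers) once every consumer has omitted them.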
def delete_redundance_code(self):
for node_name in self.graph.topo_sort:
if node_name in self.op_mapper.omit_nodes:
node = self.graph.get_node(node_name)
omit_freq = self.op_mapper.omit_nodes.count(node_name)
if len(node.outputs) <= omit_freq:
node.fluid_code.clear()