提交 fb61a10b 编写于 作者: C Channingss

add lstm & mapping weight by rename to paddle's naming rule

......@@ -166,7 +166,6 @@ class PaddleGraph(object):
self.clear_edges()
outputs_from_nodes = dict()
for layer_id, layer in self.layers.items():
print(layer.kernel, layer.outputs ,layer.inputs)
for input_key, input_var in layer.inputs.items():
vs = input_var
if not isinstance(vs, (list, tuple)):
......@@ -211,9 +210,12 @@ class PaddleGraph(object):
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception" \
and layer.kernel != "prim.warnings":
if layer.kernel == "paddle.to_tensor":
and layer.kernel != "prim.warnings" \
and layer.outputs[0] not in self.outputs:
if layer.kernel == "paddle.to_tensor" and layer.outputs[0] in self.inputs_info:
self.inputs_info.pop(layer.outputs[0])
if layer.outputs[0] in self.inputs:
self.inputs.pop(self.inputs.index(layer.outputs[0]))
invalid_list.append(layer_id)
for layer_id in invalid_list:
self.layers.pop(layer_id)
......@@ -323,6 +325,9 @@ class PaddleGraph(object):
if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.static.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn"
elif self.source_type == "onnx":
custom_import = "from x2paddle.op_mapper.static.onnx2paddle " + \
"import onnx_custom_layer as x2paddle_nn"
else:
custom_import = ""
......@@ -352,7 +357,9 @@ class PaddleGraph(object):
remove_default_attrs(layer.kernel, layer.attrs)
edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, [])
if len(edges_in) == 0 and len(edges_out) == 0:
if len(edges_in) == 0 and len(edges_out) == 0 and layer.outputs[0] not in self.outputs:
if layer.outputs[0] in self.inputs:
self.inputs.pop(self.inputs.index(layer.outputs[0]))
continue
line = ""
......@@ -472,6 +479,9 @@ class PaddleGraph(object):
elif self.source_type == "pytorch":
custom_import = "from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn"
elif self.source_type == "onnx":
custom_import = "from x2paddle.op_mapper.dygraph.onnx2paddle " + \
"import onnx_custom_layer as x2paddle_nn"
else:
custom_import = ""
self.head = gen_codes(
......@@ -580,7 +590,7 @@ class PaddleGraph(object):
elif len(layer.outputs) == 2:
line = layer.outputs[1]
else:
if layer.kernel == "paddle.nn.LSTM":
if layer.kernel in ["paddle.nn.LSTM", 'custom_layer:LSTM']:
line = "{}, ({})".format(layer.outputs[1], ', '.join(layer.outputs[-2:]))
else:
line = ','.join(layer.outputs[1:])
......@@ -589,8 +599,13 @@ class PaddleGraph(object):
line += " = self.{}".format(layer.outputs[0])
else:
line += " = self.{}(".format(layer.outputs[0])
for k, v in layer.inputs.items():
line += "{}, ".format(v)
for v in layer.inputs.values():
if isinstance(v, list):
line += "[{}], ".format(", ".join(v))
elif isinstance(v, tuple):
line += "({}), ".format(", ".join(v))
else:
line += "{}, ".format(v)
line = line.strip(", ")
line += ")"
self.forward_func.extend(gen_codes([line], indent=indent))
......
......@@ -31,6 +31,7 @@ import numpy as np
from copy import deepcopy
import logging as _logging
import os
import copy
default_op_domain = 'ai.onnx'
_logger = _logging.getLogger(__name__)
......@@ -98,9 +99,7 @@ class ONNXGraphNode(GraphNode):
def output(self, index=0):
if index >0 and len(self.layer.output) <= index:
raise IndexError('Output numbers of Node:{} is {} <= index:{}'.format(self.layer_name, len(self.layer.output), index))
if index > 0:
return "{}_p{}".format(self.layer_name, index)
return self.layer_name
return self.layer.output[index]
class ONNXGraphDataNode(GraphNode):
......@@ -132,6 +131,17 @@ class ONNXGraphDataNode(GraphNode):
shape.append(dim.dim_value)
out_shapes.append(shape)
return out_shapes
elif isinstance(self.layer, TensorProto):
values = self.layer.dims
out_shapes = list()
shape = list()
for dim in values:
if dim == 0:
shape.append(-1)
else:
shape.append(dim)
out_shapes.append(shape)
return out_shapes
else:
values = self.layer.dims
out_shapes = list()
......@@ -241,11 +251,12 @@ class ONNXGraph(Graph):
"""
generate output_nodes node of ONNX model
"""
inner_nodes = self.get_inner_nodes()
output_nodes = [value.name for value in self.graph.output]
for opt_data in output_nodes:
if opt_data not in inner_nodes:
self.output_nodes.append(opt_data)
#n = super(ONNXGraph, self).get_node(opt_data)
#if n is None:
# self.topo_sort.append(self.node_map[opt_data])
self.output_nodes.append(opt_data)
def is_place_holder_nodes(self, layer):
"""
......@@ -293,7 +304,7 @@ class ONNXGraph(Graph):
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
self.input_nodes = copy.deepcopy(self.place_holder_nodes)
def build_connection(self, layer_name, node):
"""
......@@ -410,10 +421,8 @@ class ONNXDecoder(object):
check_model(onnx_model)
onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_model_strip_initializer(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
self.graph = ONNXGraph(onnx_model)
#self.onnx_model = onnx_model
def build_value_refs(self, nodes):
"""
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .one_hot import OneHot
from .rnn import LSTM
from .pad_two_input import PadWithTwoInput
from .pad_all_dim2 import PadAllDim2
from .pad_all_dim4 import PadAllDim4
from .pad_all_dim4_one_input import PadAllDim4WithOneInput
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class OneHot(object):
def __init__(self, axis):
self.axis = axis
def __call__(self, indices, depth, values):
indices_shape = indices.shape
rank = len(indices.shape)
real_axis = self.axis
if self.axis < 0:
real_axis = self.axis + rank + 1
depth_range = paddle.arange(end=depth)
ls = tuple(indices_shape[0: real_axis])
rs = tuple(indices_shape[real_axis: rank])
targets = paddle.reshape(depth_range, (1,) * (real_axis-0) + tuple(depth_range.shape) + (1,) * (rank-real_axis))
mod = paddle.mod(indices, depth)
v = paddle.reshape(mod, ls + (1,) + rs)
out = targets == v
out = paddle.cast(out, "float32")
on_value = paddle.slice(values, axes=[0], starts=[1], ends=[2])
off_value = paddle.slice(values, axes=[0], starts=[0], ends=[1])
out = out * (on_value - off_value) + off_value
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from x2paddle.core.util import *
class PadAllDim2(object):
def __init__(self, value, mode):
self.layer_attrs = {}
self.layer_attrs['mode'] = mode
self.layer_attrs['data_format'] = 'NCHW'
self.layer_attrs['value'] = value
def __call__(self, x, pad):
pad = paddle.reshape(pad, shape=[2, -1])
pad = paddle.transpose(pad, perm=[1, 0])
pad = paddle.reverse(pad, axis=[0])
pad = paddle.flatten(pad)
pad = paddle.cast(pad, dtype="int32")
x = paddle.unsqueeze(x, axis=[0, 1])
out = paddle.nn.functional.pad(x=x, pad=pad, **self.layer_attrs)
out = paddle.squeeze(out, axis=[0, 1])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from x2paddle.core.util import *
class PadAllDim4(object):
def __init__(self, value, mode):
self.layer_attrs = {}
self.layer_attrs['mode'] = mode
self.layer_attrs['data_format'] = 'NCHW'
self.layer_attrs['value'] = value
def __call__(self, x, pad):
pad = paddle.reshape(pad, shape=[2, -1])
pad = paddle.transpose(pad, perm=[1, 0])
pad = paddle.reverse(pad, axis=[0])
pad = paddle.flatten(pad)
pad = paddle.cast(pad, dtype="int32")
pad1, pad2 = paddle.split(pad, num_or_sections=2, axis=0)
x = paddle.nn.functional.pad(x=x, pad=pad1, **self.layer_attrs)
x = paddle.transpose(x, perm=[2, 3, 0, 1])
x = paddle.nn.functional.pad(x=x, pad=pad2, **self.layer_attrs)
out = paddle.transpose(x, perm=[2, 3, 0, 1])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from x2paddle.core.util import *
class PadAllDim4WithOneInput(object):
def __init__(self, pad, value, mode):
self.layer_attrs = {}
self.layer_attrs['mode'] = mode
self.layer_attrs['data_format'] = 'NCHW'
self.layer_attrs['value'] = value
self.pad1 = pad[0: 4]
self.pad2 = pad[4: 9]
def __call__(self, x):
x = paddle.nn.functional.pad(x=x, pad=self.pad1, **self.layer_attrs)
x = paddle.transpose(x, perm=[2, 3, 0, 1])
x = paddle.nn.functional.pad(x=x, pad=self.pad2, **self.layer_attrs)
out = paddle.transpose(x, perm=[2, 3, 0, 1])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from x2paddle.core.util import *
class PadWithTwoInput(object):
def __init__(self, value, mode, data_format):
self.layer_attrs = {}
self.layer_attrs['mode'] = mode
self.layer_attrs['data_format'] = data_format
self.layer_attrs['value'] = value
def __call__(self, x, pad):
pad = paddle.reshape(pad, shape=[2, -1])
pad = paddle.transpose(pad, perm=[1, 0])
pad = paddle.reverse(pad, axis=[0])
pad = paddle.flatten(pad)
pad = paddle.cast(pad, dtype="int32")
out = paddle.nn.functional.pad(x=x, pad=pad, **self.layer_attrs)
return out
\ No newline at end of file
......@@ -13,8 +13,6 @@
# limitations under the License.
import paddle
from itertools import product
import numpy as np
class Gather(object):
def __init__(self, dim):
......
......@@ -642,27 +642,6 @@ class TFOpMapper(OpMapper):
assert paddings.layer_type == "Const", "Padding should be Const"
paddings = paddings.value.flatten().tolist()
if len(input.out_shapes[0]) == 4:
if paddings[0] + paddings[1] + paddings[6] + paddings[7] == 0:
new_padding = paddings[2:6]
transpose_name = gen_name("pad", "transpose")
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": input.name},
outputs=[transpose_name],
perm=[0, 3, 1, 2])
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": transpose_name},
outputs=[node.name],
pad=new_padding)
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": node.name},
outputs=[node.name],
perm=[0, 2, 3, 1])
return
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": input.name},
......@@ -670,31 +649,11 @@ class TFOpMapper(OpMapper):
pad=paddings)
def MirrorPad(self, node):
op_name = name_generator("pad", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
input = self.graph.get_input_node(node, 0)
paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const"
new_paddings = numpy.flip(paddings.value, 0).flatten().tolist()
dim = int(len(new_paddings) / 2)
transpose_name = gen_name("pad", "transpose")
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": input.name},
outputs=[transpose_name],
perm=[0, 3, 1, 2])
self.paddle_graph.add_layer(
kernel="paddle.nn.Pad{}D".format(dim),
inputs={"x": transpose_name},
outputs=layer_outputs,
pad=new_paddings)
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": node.name},
outputs=[node.name],
perm=[0, 2, 3, 1])
self.Pad(node)
def PadV2(self, node):
self.Pad(node)
def Squeeze(self, node):
input = self.graph.get_input_node(node, 0)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .one_hot import one_hot
from .pad_two_input import pad_with_two_input
from .pad_all_dim2 import pad_all_dim2
from .pad_all_dim4 import pad_all_dim4
from .pad_all_dim4_one_input import pad_all_dim4_one_input
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def one_hot(indices, depth, values, axis):
indices_shape = indices.shape
rank = len(indices.shape)
real_axis = axis
if axis < 0:
real_axis = axis + rank + 1
depth_range = paddle.arange(end=depth)
ls = tuple(indices_shape[0: real_axis])
rs = tuple(indices_shape[real_axis: rank])
targets = paddle.reshape(depth_range, (1,) * (real_axis-0) + tuple(depth_range.shape) + (1,) * (rank-real_axis))
mod = paddle.mod(indices, depth)
v = paddle.reshape(mod, ls + (1,) + rs)
out = targets == v
out = paddle.cast(out, "float32")
on_value = paddle.slice(values, axes=[0], starts=[1], ends=[2])
off_value = paddle.slice(values, axes=[0], starts=[0], ends=[1])
out = out * (on_value - off_value) + off_value
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def pad_all_dim2(x, pad, value, mode):
pad = paddle.reshape(pad, shape=[2, -1])
pad = paddle.transpose(pad, perm=[1, 0])
pad = paddle.reverse(pad, axis=[0])
pad = paddle.flatten(pad)
pad = paddle.cast(pad, dtype="int32")
x = paddle.unsqueeze(x, axis=[0, 1])
out = paddle.nn.functional.pad(x=x,
pad=pad,
mode=mode,
data_format='NCHW',
value=value)
out = paddle.squeeze(out, axis=[0, 1])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def pad_all_dim4(x, pad, value, mode):
pad = paddle.reshape(pad, shape=[2, -1])
pad = paddle.transpose(pad, perm=[1, 0])
pad = paddle.reverse(pad, axis=[0])
pad = paddle.flatten(pad)
pad = paddle.cast(pad, dtype="int32")
pad1, pad2 = paddle.split(pad, num_or_sections=2, axis=0)
x = paddle.nn.functional.pad(x=x,
pad=pad1,
mode=mode,
data_format='NCHW',
value=value)
x = paddle.transpose(x, perm=[2, 3, 0, 1])
x = paddle.nn.functional.pad(x=x,
pad=pad2,
mode=mode,
data_format='NCHW',
value=value)
out = paddle.transpose(x, perm=[2, 3, 0, 1])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def pad_all_dim4_one_input(x, pad, value, mode):
x = paddle.nn.functional.pad(x=x,
pad=pad[0: 4],
mode=mode,
data_format='NCHW',
value=value)
x = paddle.transpose(x, perm=[2, 3, 0, 1])
x = paddle.nn.functional.pad(x=x,
pad=pad[4: 9],
mode=mode,
data_format='NCHW',
value=value)
out = paddle.transpose(x, perm=[2, 3, 0, 1])
return out
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def pad_with_two_input(x, pad, value, mode, data_format):
pad = paddle.reshape(pad, shape=[2, -1])
pad = paddle.transpose(pad, perm=[1, 0])
pad = paddle.reverse(pad, axis=[0])
pad = paddle.flatten(pad)
pad = paddle.cast(pad, dtype="int32")
out = paddle.nn.functional.pad(x=x,
pad=pad,
value=value,
mode=mode,
data_format=data_format)
return out
\ No newline at end of file
......@@ -106,6 +106,9 @@ class OpSet9():
'ReduceMax': ['paddle.max',
dict(axes='axis', keepdims='keepdim'),
dict(keepdim=1)],
'ReduceProd': ['paddle.prod',
dict(axes='axis', keepdims='keepdim'),
dict(keepdim=1)],
# active function
'Relu': ['paddle.nn.functional.relu'],
'LeakyRelu': ['paddle.nn.functional.leaky_relu',
......@@ -203,7 +206,7 @@ class OpSet9():
node = parameter
dtype = node.dtype
shape = node.out_shapes[0]
if len(node.weight.shape) == 0:
if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
......@@ -286,6 +289,10 @@ class OpSet9():
attrs.update({"align_corners": False,
"mode": string(mode),
"align_mode": 1})
val_x_shape = val_x.out_shapes[0]
if mode == "linear" and len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
attrs["align_corners"] = True
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
......@@ -368,61 +375,136 @@ class OpSet9():
def Pad(self, node, op_independent=True):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
pads = node.get_attr('pads')
is_pads_attr = True
if pads is None:
val_pad = self.graph.get_input_node(node, idx=1, copy=True)
pad_shape = val_pad.out_shapes[0]
is_pads_attr = False
pads = _const_weight_or_none(val_pad)
if pads is not None:
is_pads_attr = True
mode = node.get_attr('mode', 'constant')
value = node.get_attr('value', 0.)
data_shape = val_x.out_shapes[0]
output_shape = node.out_shapes[0]
assume_pad2d = False
assume_pad = False
layer_attrs = {}
layer_attrs['mode'] = string(mode)
paddings = []
if len(pads) == 4:
assume_pad2d |= mode != 'constant'
if data_shape:
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
if assume_pad2d:
paddle_op = 'paddle.nn.functional.pad'
layer_attrs['data_format'] = string('NCHW')
layer_attrs['value'] = value
layer_attrs['value'] = value
if not op_independent:
output_name = node.name + '_paded'
else:
paddle_op = 'paddle.fluid.layers.pad'
layer_attrs["pad_value"] = value
if len(pads) == 4:
paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
elif len(pads) == 8:
paddings = np.array(pads).reshape(
(-1, 4)).transpose().flatten().tolist() # SSEE -> SESE
if sum(paddings[:4]) == 0:
paddle_op = 'paddle.nn.functional.pad'
paddings = paddings[4:]
layer_attrs['value'] = value
if 'pad_value' in layer_attrs:
layer_attrs.pop('pad_value')
tmp_paddings = copy.deepcopy(paddings)
paddings[0] = tmp_paddings[2]
paddings[1] = tmp_paddings[3]
paddings[2] = tmp_paddings[0]
paddings[3] = tmp_paddings[1]
if paddle_op == 'paddle.nn.functional.pad':
layer_attrs['pad'] = paddings
else:
layer_attrs['paddings'] = paddings
if op_independent:
output_name = node.name
layer_outputs = [output_name]
if is_pads_attr:
paddings = []
paddle_op = 'paddle.nn.functional.pad'
if len(pads) in [2, 4, 6]:
if data_shape:
assume_pad |= data_shape and 2 * (len(data_shape) - 2) == len(pads) # NCHW
if output_shape:
assume_pad |= output_shape and 2 * (len(output_shape) - 2) == len(pads) # NCHW
if assume_pad:
if len(pads) == 2:
data_format = "NCL"
elif len(pads) == 4:
data_format = "NCHW"
else:
data_format = "NCDHW"
paddings = np.array(pads).reshape(
(2, -1)).transpose().astype("int32")
paddings = np.flip(paddings, axis=0).flatten().tolist()
layer_attrs['pad'] = paddings
layer_attrs['data_format'] = data_format
else:
if data_shape:
assume_pad |= data_shape and 2 * len(data_shape) == len(pads) # NCHW
if output_shape:
assume_pad |= output_shape and 2 * len(output_shape) == len(pads) # NCHW
if assume_pad:
paddings = np.array(pads).reshape(
(2, -1)).transpose().astype("int32").flatten().tolist()
layer_attrs['pad'] = paddings
else:
raise Exception("The padding value {} is wrong!".format(pads))
elif len(pads) == 8:
if data_shape:
assume_pad |= data_shape and 2 * len(data_shape) == len(pads) # NCHW
if output_shape:
assume_pad |= output_shape and 2 * len(output_shape) == len(pads) # NCHW
if assume_pad:
paddings = np.array(pads).reshape(
(2, -1)).transpose().astype("int32")
paddings = np.flip(paddings, axis=0).flatten().tolist()
if sum(paddings[:4]) == 0:
paddings = paddings[4:]
layer_attrs['pad'] = paddings
else:
layer_attrs['pad'] = paddings
paddle_op = "custom_layer:pad_all_dim4_one_input"
else:
raise Exception("The padding value {} is wrong!".format(pads))
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name],
outputs=layer_outputs,
**layer_attrs)
if not op_independent:
return node.name + '_paded'
else:
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x.name},
outputs=[node.name + '_paded'],
**layer_attrs)
return node.name + '_paded'
pads_len = val_pad.out_shapes[0][0]
if pads_len in [2, 4, 6]:
if data_shape:
assume_pad |= data_shape and 2 * (len(data_shape) - 2) == pads_len # NCHW
if output_shape:
assume_pad |= output_shape and 2 * (len(output_shape) - 2) == pads_len # NCHW
if assume_pad:
if pads_len == 2:
data_format = "NCL"
elif pads_len == 4:
data_format = "NCHW"
else:
data_format = "NCDHW"
self.paddle_graph.add_layer(
"custom_layer:pad_with_two_input",
inputs={'x': val_x.name, 'pad': val_pad.name},
outputs=layer_outputs,
value=value,
mode=string(mode),
data_format=string(data_format))
else:
if data_shape:
assume_pad |= data_shape and 2 * len(data_shape) == pads_len # NCHW
if output_shape:
assume_pad |= output_shape and 2 * len(output_shape) == pads_len # NCHW
if assume_pad:
if pads_len == 4:
self.paddle_graph.add_layer(
"custom_layer:pad_all_dim2",
inputs={'x': val_x.name, 'pad': val_pad.name},
outputs=layer_outputs,
value=value,
mode=string(mode))
else:
raise Exception("The padding value is wrong!")
elif pads_len == 8:
if data_shape:
assume_pad |= data_shape and 2 * len(data_shape) == pads_len # NCHW
if output_shape:
assume_pad |= output_shape and 2 * len(output_shape) == pads_len # NCHW
if assume_pad:
self.paddle_graph.add_layer(
"custom_layer:pad_all_dim4",
inputs={'x': val_x.name, 'pad': val_pad.name},
outputs=layer_outputs,
value=value,
mode=string(mode))
else:
print(pads_len)
raise Exception("The padding value is wrong!")
if not op_independent:
return node.name + '_paded'
@print_mapping_info
def Unsqueeze(self, node):
......@@ -622,17 +704,13 @@ class OpSet9():
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": indices.name},
outputs=indices_cast,
outputs=[indices_cast],
dtype=string('int64'))
op_name = name_generator("embedding", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
self.paddle_graph.add_layer(
'paddle.nn.Embedding',
inputs={"x": indices_cast},
outputs=layer_outputs,
param_attr=string(val_x.name),
size=val_x.out_shapes[0])
'paddle.nn.functional.embedding',
inputs={"x": indices_cast,
"weight": val_x.name},
outputs=[node.name])
else:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
......@@ -804,20 +882,27 @@ class OpSet9():
starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True)
starts_value = _const_weight_or_none(starts)
if starts_value is not None:
starts_value = starts_value.tolist()
ends_value = _const_weight_or_none(ends)
if ends_value is not None:
ends_value = ends_value.tolist()
if len(node.inputs) > 2:
s_len = len(val_x.out_shapes[0])
axes = list(range(s_len))
if len(node.inputs) > 3:
axes = self.graph.get_input_node(node, idx=3, copy=True)
axes = _const_weight_or_none(axes, necessary=True)
axes_node = self.graph.get_input_node(node, idx=3, copy=True)
axes = _const_weight_or_none(axes_node, necessary=True).tolist()
if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True)
steps = _const_weight_or_none(steps)
steps = _const_weight_or_none(steps).tolist()
layer_attrs = {
"axes": axes,
"starts": starts.name,
"ends": ends.name
}
if starts_value is not None and ends_value is not None:
if starts_value is not None and ends_value is not None and axes is not None:
starts_value = starts_value.copy()
ends_value = ends_value.copy()
#for idx in range(len(ends_value)):
......@@ -847,6 +932,8 @@ class OpSet9():
layer_attrs['starts'] = starts_cast
if ends.dtype != 'int32':
ends_cast = ends.name + '_cast'
else:
ends_cast = ends.name
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": ends.name},
......@@ -862,6 +949,7 @@ class OpSet9():
ends[idx] = 2**31 - 1
layer_attrs = {"axes": axes, "starts": starts, "ends": ends}
if steps is not None:
layer_attrs['strides'] = steps
self.paddle_graph.add_layer(
......@@ -986,11 +1074,17 @@ class OpSet9():
inputs={'x': val_shape.name},
outputs=[val_shape.name],
shape=val_shape.out_shapes[0])
if val_shape.dtype != "int32":
self.paddle_graph.add_layer(
'paddle.cast',
inputs={'x': val_shape.name},
outputs=[val_shape.name],
dtype=string("int32"))
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name,
'shape': val_shape.name},
outputs=node)
outputs=[node.name])
@print_mapping_info
def Cast(self, node):
......@@ -1221,7 +1315,10 @@ class OpSet9():
@print_mapping_info
def Transpose(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
perm = node.get_attr('perm')
s_len = len(val_x.out_shapes[0])
perm_default = list(range(s_len))
perm_default.reverse()
perm = node.get_attr('perm', perm_default)
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={"x": val_x.name},
......@@ -1230,9 +1327,6 @@ class OpSet9():
@print_mapping_info
def PRelu(self, node):
op_name = name_generator("prelu", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_slope = self.graph.get_input_node(node, idx=1, copy=True)
......@@ -1240,20 +1334,27 @@ class OpSet9():
shape_slope = val_slope.out_shapes[0]
if shape_slope == [1]:
mode = 'all'
elif len(shape_slope) > 2:
raise Exception("The 'element' mode is not supported yet!")
if mode == 'channel' and len(shape_slope) == 1:
# paddle params shape need be [1, channel]
slope_data = _const_weight_or_none(val_slope)
slope_data = np.reshape(slope_data, [1] + shape_slope)
self.params[val_slope.name] = slope_data
self.paddle_graph.add_layer(
"paddle.nn.functional.prelu",
inputs={"x": val_x.name,
"weight": val_slope.name},
outputs=[node.name])
if mode == "element":
self.paddle_graph.add_layer(
"paddle.static.nn.prelu",
inputs={"x": val_x.name,
"param_attr": val_slope.name},
outputs=[node.name],
mode="element")
else:
if mode == 'channel':
if len(shape_slope) > 1:
self.paddle_graph.add_layer(
"paddle.reshape",
inputs={"x": val_slope.name},
outputs=[val_slope.name],
shape=[shape_slope[0]])
self.paddle_graph.add_layer(
"paddle.nn.functional.prelu",
inputs={"x": val_x.name,
"weight": val_slope.name},
outputs=[node.name])
@print_mapping_info
def Squeeze(self, node):
......@@ -1521,6 +1622,16 @@ class OpSet9():
}
if has_bias:
layer_inputs["bias"] = val_b.name
input_shape = val_x.out_shapes[0]
if reduce(lambda x,y:x*y, input_shape) in [1, -1] and 1 not in input_shape:
input_shape[1] = num_in_channels * num_groups
input_shape[0] = 0
input_shape[2] = 0
self.paddle_graph.add_layer(
"paddle.reshape",
inputs={"x": layer_inputs["x"]},
outputs=[layer_inputs["x"]],
shape=input_shape)
self.paddle_graph.add_layer(
paddle_op,
inputs=layer_inputs,
......@@ -1588,4 +1699,62 @@ class OpSet9():
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def Size(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
self.paddle_graph.add_layer(
"paddle.shape",
inputs={"input": val_x.name},
outputs=[node.name])
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": node.name},
outputs=[node.name],
dtype=string('int64'))
self.paddle_graph.add_layer(
"paddle.prod",
inputs={"x": node.name},
outputs=[node.name])
@print_mapping_info
def Sign(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
if node.dtype not in ["float16", "float32", "float64"]:
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": val_x.name},
outputs=[val_x.name],
dtype=string("float32"))
self.paddle_graph.add_layer(
"paddle.sign",
inputs={"x": val_x.name},
outputs=[node.name])
if node.dtype not in ["float16", "float32", "float64"]:
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string(node.dtype))
@print_mapping_info
def OneHot(self, node):
indices = self.graph.get_input_node(node, idx=0, copy=True)
depth = self.graph.get_input_node(node, idx=1, copy=True)
values = self.graph.get_input_node(node, idx=2, copy=True)
axis = node.get_attr('axis', -1)
self.paddle_graph.add_layer(
"custom_layer:one_hot",
inputs={"indices": indices.name,
"depth": depth.name,
"values": values.name},
outputs=[node.name],
axis=axis)
@print_mapping_info
def Reciprocal(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
self.paddle_graph.add_layer(
"paddle.reciprocal",
inputs={"x": val_x.name},
outputs=[node.name])
......@@ -625,32 +625,11 @@ class TFOpMapper(OpMapper):
shape=out_shape.tolist())
def Pad(self, node):
input = self.graph.get_node(node.layer.input[0])
paddings = self.graph.get_node(node.layer.input[1])
input = self.graph.get_input_node(node, 0)
paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const"
paddings = paddings.value.flatten().tolist()
if len(input.out_shapes[0]) == 4:
if paddings[0] + paddings[1] + paddings[6] + paddings[7] == 0:
new_padding = paddings[2:6]
transpose_name = gen_name("pad", "transpose")
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": input.name},
outputs=[transpose_name],
perm=[0, 3, 1, 2])
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": transpose_name},
outputs=[node.name],
pad=new_padding)
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": node.name},
outputs=[node.name],
perm=[0, 2, 3, 1])
return
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": input.name},
......@@ -658,26 +637,11 @@ class TFOpMapper(OpMapper):
pad=paddings)
def MirrorPad(self, node):
input = self.graph.get_input_node(node, 0)
paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const"
new_paddings = numpy.flip(paddings.value, 0).flatten().tolist()
transpose_name = gen_name("pad", "transpose")
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": input.name},
outputs=[transpose_name],
perm=[0, 3, 1, 2])
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad".format(dim),
inputs={"x": transpose_name},
outputs=[node.name],
pad=new_paddings)
self.paddle_graph.add_layer(
kernel="paddle.transpose",
inputs={"x": node.name},
outputs=[node.name],
perm=[0, 2, 3, 1])
self.Pad(node)
def PadV2(self, node):
self.Pad(node)
def Squeeze(self, node):
input = self.graph.get_input_node(node, 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册