Unverified commit 2c01f3f7, authored by Jason, committed by GitHub

Merge pull request #442 from SunAhong1993/paddle-2.0-new

for same structure
@@ -18,6 +18,7 @@ from google.protobuf import text_format
 import numpy as np
 from x2paddle.core.graph import GraphNode, Graph
 from x2paddle.core.fluid_code import FluidCode
+from x2paddle.decoder import caffe_shape_inference

 class CaffeResolver(object):
@@ -59,6 +60,28 @@ class CaffeGraphNode(GraphNode):
     def set_params(self, params):
         self.data = params

+    @property
+    def name(self):
+        if hasattr(self, 'index'):
+            return "{}_p{}".format(self.layer_name, self.index)
+        return self.layer_name
+
+    @property
+    def out_shapes(self):
+        return self._out_shapes
+
+    @out_shapes.setter
+    def out_shapes(self, value):
+        self._out_shapes = value
+
+    @property
+    def in_shapes(self):
+        return self._in_shapes
+
+    @in_shapes.setter
+    def in_shapes(self, value):
+        self._in_shapes = value

 class CaffeGraph(Graph):
@@ -226,8 +249,11 @@ class CaffeGraph(Graph):
                 layer_name)
         super(CaffeGraph, self).build()
+        for i, node_name in enumerate(self.topo_sort):
+            node = self.get_node(node_name)
+            self.set_node_shape(node)

-    def get_bottom_node(self, node, idx=0, copy=False):
+    def get_input_node(self, node, idx=0, copy=False):
         input_node_name = node.inputs[idx]
         assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format(
             name)
@@ -238,6 +264,19 @@ class CaffeGraph(Graph):
         else:
             name = input_node_name
         return self.get_node(name, copy=copy)

+    def set_node_shape(self, node):
+        inputs = node.inputs
+        input_shape = []
+        for i, nm in enumerate(inputs):
+            last_node = self.get_node(nm)
+            tmp = node.layer.bottom[i]
+            idx = list(last_node.layer.top).index(tmp)
+            input_shape.append(last_node.out_shapes[idx])
+        node.in_shapes = input_shape
+        func_name = 'shape_' + node.layer_type.lower()
+        node.out_shapes = getattr(caffe_shape_inference, func_name)(node.layer,
+                                                                    input_shape)

 class CaffeDecoder(object):
...
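The new set_node_shape hook resolves each node's in_shapes from the nodes that produce its bottoms, then picks the matching shape_<layer_type> function out of caffe_shape_inference by name. Below is a minimal, self-contained sketch of that getattr-based dispatch pattern; the shape_rules namespace and shape_relu rule are illustrative stand-ins, not the real x2paddle module.

    import types

    # Hypothetical stand-in for the caffe_shape_inference module: one
    # `shape_<layer_type>` function per supported layer type.
    shape_rules = types.SimpleNamespace()

    def shape_relu(layer, input_shape):
        # Element-wise layers keep the input shape unchanged.
        return input_shape

    shape_rules.shape_relu = shape_relu

    def infer_out_shapes(layer_type, layer, input_shape):
        # Same dispatch idea as set_node_shape: build the function name from
        # the layer type and look it up with getattr.
        func_name = 'shape_' + layer_type.lower()
        return getattr(shape_rules, func_name)(layer, input_shape)

    print(infer_out_shapes('ReLU', None, [[1, 3, 224, 224]]))  # [[1, 3, 224, 224]]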
@@ -83,6 +83,10 @@ def shape_convolution(layer, input_shape):
     return get_strided_kernel_output_shape(params, input_shape[0], math.floor)

+def shape_depthwiseconvolution(layer, input_shape):
+    return shape_convolution(layer, input_shape)

 def shape_deconvolution(layer, input_shape):
     h_i = input_shape[0][2]
...
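shape_depthwiseconvolution simply delegates to shape_convolution, i.e. the usual strided-kernel output-size rule. A small sketch of that rule for one spatial dimension, assuming Caffe-style symmetric padding; the helper name here is illustrative, not the actual get_strided_kernel_output_shape signature.

    import math

    def strided_kernel_output_size(in_size, kernel, stride, pad, rounding=math.floor):
        # Output size for one spatial dimension of a (depthwise) convolution
        # or pooling layer.
        return int(rounding((in_size + 2 * pad - kernel) / stride) + 1)

    # 224x224 input, 3x3 kernel, stride 2, pad 1 -> 112x112 output.
    print(strided_kernel_output_size(224, kernel=3, stride=2, pad=1))  # 112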
@@ -64,6 +64,12 @@ class ONNXGraphNode(GraphNode):
         if 'value' not in self.attr_map:
             return None
         return self.attr_map['value']

+    @property
+    def name(self):
+        if hasattr(self, 'index'):
+            return "{}_p{}".format(self.layer_name, self.index)
+        return self.layer_name

     def get_attribute_value(self, attr):
         """
@@ -118,6 +124,10 @@ class ONNXGraphDataNode(GraphNode):
             out_shapes = list()
             out_shapes.append(values)
             return out_shapes

+    @property
+    def name(self):
+        return self.layer_name

     @property
     def dtype(self):
@@ -308,6 +318,7 @@ class ONNXGraph(Graph):
             if ipt_node.layer_name in node.which_child:
                 ipt_node.index = node.which_child[ipt_node.layer_name]
             return ipt_node

     def graph_weights(self):
         """
...
@@ -189,6 +189,10 @@ class TFGraph(Graph):
         if len(items) == 1 and node.layer_type in self.multi_out_ops:
             node.index = 0
         return node

+    def get_input_node(self, node, idx=0, copy=False):
+        input_node_name = node.layer.input[idx]
+        return self.get_node(input_node_name, copy)

     def remove_node(self, node_name):
         if node_name not in self.node_map:
@@ -316,7 +320,7 @@ class TFDecoder(object):
             self.sess = tf.compat.v1.Session()
         except:
             self.sess = tf.Session()
-        self.input_info = dict()
+        self.inputs_info = dict()
         self.define_input_shape = define_input_shape
         with open(pb_model, 'rb') as f:
             try:
@@ -426,50 +430,40 @@ class TFDecoder(object):
                 input_map["{}:0".format(layer.name)] = x2paddle_input
                 if shape.count(None) > 0:
                     shape[shape.index(None)] = -1
-                    self.input_info["x2paddle_{}".format(layer.name)] = (shape,
-                                                                         dtype)
+                    self.inputs_info["x2paddle_{}".format(layer.name)] = (shape,
+                                                                          dtype)
             else:
                 value = graph_node.layer.attr["shape"].shape
                 shape = [dim.size for dim in value.dim]
-                self.input_info[layer.name] = (shape, dtype)
+                self.inputs_info[layer.name] = (shape, dtype)
         return input_map

     # trick method
     # should be removed after PaddlePaddle V1.6 been released
-    def infer_tensor(self, graph_node):
+    def infer_tensor(self, graph_node, out_shape=None, use_diff_inputs=True):
         if hasattr(graph_node, "index"):
             tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
         else:
             tensor_name = graph_node.layer.name + ":0"
         feed = dict()
-        for input_name, info in self.input_info.items():
-            (shape, dtype) = cp.deepcopy(info)
-            input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
-            if shape.count(-1) > 0:
-                shape[shape.index(-1)] = 2
-            feed[input_tensor] = numpy.random.random_sample(shape)
-        output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
-        return self.sess.run([output_tensor], feed)[0]
-
-    def infer_shape_tensor(self, graph_node, out_shape=None):
-        if hasattr(graph_node, "index"):
-            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
-        else:
-            tensor_name = graph_node.layer.name + ":0"
-        feed = dict()
-        batch_size = [2, 3, 5]
+        if use_diff_inputs:
+            batch_size = [2, 3, 5]
+        else:
+            batch_size = [2]
         results = list()
         for b in batch_size:
-            for input_name, info in self.input_info.items():
+            for input_name, info in self.inputs_info.items():
                 (shape, dtype) = cp.deepcopy(info)
-                input_tensor = self.sess.graph.get_tensor_by_name(input_name +
-                                                                  ":0")
+                input_tensor = self.sess.graph.get_tensor_by_name(input_name + ":0")
                 if shape.count(-1) > 0:
                     shape[shape.index(-1)] = b
                 feed[input_tensor] = numpy.random.random_sample(shape)
             output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
-            results.append(self.sess.run([output_tensor], feed)[0].flatten())
+            if use_diff_inputs:
+                results.append(self.sess.run([output_tensor], feed)[0].flatten())
+            else:
+                return self.sess.run([output_tensor], feed)[0]

         compare01 = (results[0] == results[1])
         compare12 = (results[1] == results[2])
@@ -494,38 +488,4 @@ class TFDecoder(object):
             return results[0].tolist()
         else:
             raise Exception("Couldn't infer a stable shape shape tensor value")
-
-    def infer_tensor_shape(self, graph_node):
-        if hasattr(graph_node, "index"):
-            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
-        else:
-            tensor_name = graph_node.layer.name + ":0"
-        feed = dict()
-        batch_size = [2, 3, 5]
-        shapes = list()
-        for b in batch_size:
-            for input_name, info in self.input_info.items():
-                (shape, dtype) = cp.deepcopy(info)
-                input_tensor = self.sess.graph.get_tensor_by_name(input_name +
-                                                                  ":0")
-                if shape.count(-1) > 0:
-                    shape[shape.index(-1)] = b
-                feed[input_tensor] = numpy.random.random_sample(shape)
-            output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
-            shape = self.sess.run([output_tensor], feed)[0].shape
-            shapes.append(numpy.array(shape))
-        compare01 = (shapes[0] == shapes[1])
-        compare12 = (shapes[1] == shapes[2])
-        if compare01.all() and compare12.all():
-            return shape[0].tolist()
-        if (compare01 == compare12).all():
-            index = numpy.argwhere(compare01 == False).flatten()
-            if index.shape[0] != 1:
-                raise Exception("There's not only one unstable dimension")
-            if index[0] != 0:
-                raise Exception("Batch size not in the first dimension")
-            shapes[0][0] = -1
-            return shapes[0].tolist()
\ No newline at end of file
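The old infer_tensor / infer_shape_tensor / infer_tensor_shape trio is folded into a single infer_tensor: with use_diff_inputs=True it runs the graph with batch sizes 2, 3 and 5 and keeps only the output values that stay identical, marking batch-dependent positions as -1. A self-contained sketch of that comparison idea over plain NumPy arrays; the shape_fn callable stands in for a real session run and is not part of x2paddle.

    import numpy as np

    def infer_stable_values(shape_fn, batch_sizes=(2, 3, 5)):
        # Run the same computation with several batch sizes and keep only the
        # values that do not change; unstable positions are reported as -1.
        results = [np.asarray(shape_fn(b)).flatten() for b in batch_sizes]
        stable = np.logical_and(results[0] == results[1], results[1] == results[2])
        merged = results[0].copy()
        merged[~stable] = -1
        return merged.tolist()

    # A "shape tensor" whose first entry follows the batch size.
    print(infer_stable_values(lambda b: [b, 224, 224, 3]))  # [-1, 224, 224, 3]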
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 from x2paddle.op_mapper.dygraph.onnx2paddle.opset9 import OpSet9
 from x2paddle.core.op_mapper import OpMapper
 from x2paddle.decoder.onnx_decoder import ONNXGraphNode
@@ -25,34 +26,33 @@ class ONNXOpMapper(OpMapper):
         self.default_op_set = 9
         self.graph = decoder.graph
         self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="dygraph", source_type="onnx")
+        self.paddle_graph.outputs = self.graph.output_nodes
         self.opset = self.create_opset(decoder)
         if not self.op_checker():
-            raise Exception("Model are not supported yet.")
-        #mapping op
+            raise Exception("Model is not supported yet.")
         print("Total nodes: {}".format(
             sum([
                 isinstance(node, ONNXGraphNode)
                 for name, node in self.graph.node_map.items()
             ])))
         print("Nodes converting ...")
-        for node_name in self.graph.topo_sort:
+        for i, node_name in enumerate(self.graph.topo_sort):
+            sys.stderr.write("\rConverting node {} ... ".format(i + 1))
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if hasattr(self.opset, op):
                 func = getattr(self.opset, op)
                 func(node)
-            elif op in self.opset.default_op_mapping:
+            elif op in self.opset.directly_map_ops:
                 self.opset.directly_map(node)
             elif op in self.opset.elementwise_ops:
                 self.opset.elementwise_map(node)
-        print("Nodes converted.")
-        self.weights = self.opset.weights
-        self.inputs_info = self.opset.inputs_info
+        print("\nNodes converted.")
         self.paddle_graph.set_name(self.graph.graph_name)
-        self.paddle_graph.set_parameters(self.weights)
-        self.paddle_graph.set_inputs_info(self.inputs_info)
-        self.paddle_graph.outputs = self.graph.output_nodes
+        self.paddle_graph.set_parameters(self.opset.weights)
+        self.paddle_graph.set_inputs_info(self.opset.inputs_info)

     def op_checker(self):
         unsupported_ops = set()
@@ -60,16 +60,17 @@ class ONNXOpMapper(OpMapper):
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if not hasattr(self.opset, op) and \
-                op not in self.opset.default_op_mapping and \
+                op not in self.opset.directly_map_ops and \
                 op not in self.opset.elementwise_ops:
                 unsupported_ops.add(op)
         if len(unsupported_ops) == 0:
             return True
         else:
-            print("There are {} ops not supported yet, list as below".format(
-                len(unsupported_ops)))
+            if len(unsupported_ops) > 0:
+                print("\n========= {} OPs are not supported yet ===========".format(
+                    len(unsupported_ops)))
             for op in unsupported_ops:
-                print(op)
+                print("========== {} ============".format(op))
             return False

     def create_opset(self, decoder):
...
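The ONNX mapper's op_checker now collects every unsupported op before reporting and returns a boolean, leaving the raise to the caller. A generic sketch of that check-then-report pattern, assuming a plain set of supported op names rather than the real opset object.

    def op_checker(ops, supported_ops):
        # Collect every unsupported op type first, then report them all at
        # once and let the caller decide whether to raise.
        unsupported = sorted(set(ops) - set(supported_ops))
        if len(unsupported) == 0:
            return True
        print("========= {} OPs are not supported yet ===========".format(len(unsupported)))
        for op in unsupported:
            print("========== {} ============".format(op))
        return False

    print(op_checker(["Conv", "Gemm"], {"Conv", "Gemm"}))      # True
    print(op_checker(["Conv", "FancyOp"], {"Conv", "Gemm"}))   # prints the list, then False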
@@ -38,28 +38,34 @@ class PyTorchOpMapper(OpMapper):
         self.scope_name2id = dict()
         self.inputs_info = dict()
         # Convert
-        self.check_op(decoder.graph)
+        if not self.op_checker(decoder.graph):
+            raise Exception("Model is not supported yet.")
         self.paddle_graph, _ = self.traverse(decoder.graph)
         self.paddle_graph.set_inputs_info(self.inputs_info)

-    def check_op(self, script_graph):
+    def op_checker(self, script_graph):
         def _update_op_list(graph):
             for node in graph.nodes():
                 op_list.append(node.kind())
                 for block in node.blocks():
                     _update_op_list(block)

         op_list = list()
         _update_op_list(script_graph)
         op_list = list(set(op_list))
-        unsupported_op_list = []
+        unsupported_ops = []
         for op in op_list:
             func_name = op.replace('::', '_')
             if not (hasattr(prim, func_name) or hasattr(aten, func_name)):
-                unsupported_op_list.append(op)
-        if len(unsupported_op_list) > 0:
-            raise Exception("The kind {} in model is not supported yet.".format(
-                unsupported_op_list))
+                unsupported_ops.append(op)
+        if len(unsupported_ops) == 0:
+            return True
+        else:
+            if len(unsupported_ops) > 0:
+                print("\n========= {} OPs are not supported yet ===========".format(
+                    len(unsupported_ops)))
+                for op in unsupported_ops:
+                    print("========== {} ============".format(op))
+            return False

     def traverse(self, script_graph, parent_layer=None):
         # Used to get the inputs of the graph
...
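The PyTorch checker follows the same report-and-return pattern, but it first has to gather op kinds recursively, because TorchScript nodes (loops, ifs) carry nested blocks with their own nodes. A toy sketch of that recursive collection; the Node and Graph classes here are stand-ins for the torch._C graph objects.

    class Node:
        def __init__(self, kind, blocks=()):
            self._kind, self._blocks = kind, list(blocks)
        def kind(self):
            return self._kind
        def blocks(self):
            return self._blocks

    class Graph:
        def __init__(self, nodes):
            self._nodes = list(nodes)
        def nodes(self):
            return self._nodes

    def collect_op_kinds(graph):
        op_list = []
        def _update_op_list(g):
            # Visit every node, then recurse into its nested blocks.
            for node in g.nodes():
                op_list.append(node.kind())
                for block in node.blocks():
                    _update_op_list(block)
        _update_op_list(graph)
        return sorted(set(op_list))

    inner = Graph([Node("aten::add")])
    outer = Graph([Node("prim::Loop", blocks=[inner]), Node("aten::relu")])
    print(collect_op_kinds(outer))  # ['aten::add', 'aten::relu', 'prim::Loop']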
@@ -231,7 +231,7 @@ class CaffeOpMapper(OpMapper):
             self.weights[node.layer_name + '_bias'] = data[1]
         assert len(node.inputs
                    ) == 1, 'The count of Convolution node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         layer_attrs = {
             'filter_size': kernel,
             'num_filters': channel,
@@ -273,7 +273,7 @@ class CaffeOpMapper(OpMapper):
             self.weights[node.layer_name + '_bias'] = data[1]
         assert len(node.inputs
                    ) == 1, 'The count of Deconvolution node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         layer_attrs = {
             'output_size': None,
             'filter_size': kernel,
@@ -306,7 +306,7 @@ class CaffeOpMapper(OpMapper):
             pool_type = 'avg'
         assert len(
             node.inputs) == 1, 'The count of Pooling node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         layer_attrs = {
             'pool_size': kernel,
             'pool_stride': stride,
@@ -333,7 +333,7 @@ class CaffeOpMapper(OpMapper):
         # just scales by alpha (as does Krizhevsky's paper).
         # We'll account for that here.
         alpha = params.alpha / float(params.local_size)
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         layer_attrs = {
             'n': params.local_size,
             'k': params.k,
@@ -381,7 +381,7 @@ class CaffeOpMapper(OpMapper):
         #params = node.layer.inner_product_param
         assert params.axis == 1
         assert params.bias_term == True
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         layer_attrs = {
             'size': params.num_output,
             'name': string(node.layer_name),
@@ -399,7 +399,7 @@ class CaffeOpMapper(OpMapper):
     def Softmax(self, node):
         assert len(
             node.inputs) == 1, 'The count of Softmax node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         params = node.layer.softmax_param
         axis = params.axis
         shape = node.input_shape[0]
@@ -415,7 +415,7 @@ class CaffeOpMapper(OpMapper):
     def Slice(self, node):
         assert len(
             node.inputs) == 1, 'The count of Slice node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         top_len = len(node.layer.top)
         params = node.layer.slice_param
         axis = params.axis
@@ -445,7 +445,7 @@ class CaffeOpMapper(OpMapper):
             ) >= 1, 'The count of Concat node\'s input is not more than 1.'
         inputs_list = []
         for i in range(len(node.inputs)):
-            input = self.graph.get_bottom_node(node, idx=i, copy=True)
+            input = self.graph.get_input_node(node, idx=i, copy=True)
             inputs_list.append(self.get_input_name(input))
         params = node.layer.concat_param
         axis = params.axis
@@ -464,7 +464,7 @@ class CaffeOpMapper(OpMapper):
         """
         assert len(
             node.inputs) == 1, 'The count of ReLU node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         params = node.layer.relu_param
         if params.HasField('negative_slope') and params.negative_slope != 0:
@@ -483,7 +483,7 @@ class CaffeOpMapper(OpMapper):
     def PReLU(self, node):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         params = node.layer.prelu_param
         mode_bool = params.channel_shared
         if mode_bool:
@@ -511,10 +511,10 @@ class CaffeOpMapper(OpMapper):
         inputs_dict = dict()
         for i, shape in enumerate(node.input_shape):
             if shape[1] == 1:
-                input = self.graph.get_bottom_node(node, idx=i, copy=True)
+                input = self.graph.get_input_node(node, idx=i, copy=True)
                 inputs_dict["label"] = self.get_input_name(input)
             else:
-                input = self.graph.get_bottom_node(node, idx=i, copy=True)
+                input = self.graph.get_input_node(node, idx=i, copy=True)
                 inputs_dict["input"] = self.get_input_name(input)
         params = node.layer.accuracy_param
         top_k = params.top_k
@@ -534,9 +534,9 @@ class CaffeOpMapper(OpMapper):
         params = node.layer.eltwise_param
         mode = params.operation
         inputs = []
-        input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input0 = self.graph.get_input_node(node, idx=0, copy=True)
         inputs.append(input0)
-        input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
+        input1 = self.graph.get_input_node(node, idx=1, copy=True)
         inputs.append(input1)
         if mode == 0:
             inputs_dict = {}
@@ -606,7 +606,7 @@ class CaffeOpMapper(OpMapper):
     def BatchNorm(self, node):
         assert len(
             node.inputs) == 1, 'The count of BatchNorm node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         params = node.layer.batch_norm_param
         if hasattr(params, 'eps'):
             eps = params.eps
@@ -670,8 +670,8 @@ class CaffeOpMapper(OpMapper):
             # for two tensor, here resets axis to 1. Maybe there is a bug for unkown case.
             axis = 1
             bias_shape = node.input_shape[0][axis:axis + num_axes]
-            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-            input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
+            input0 = self.graph.get_input_node(node, idx=0, copy=True)
+            input1 = self.graph.get_input_node(node, idx=1, copy=True)
             inputs_dict = {}
             inputs_dict['x'] = self.get_input_name(input0)
             inputs_dict['y'] = self.get_input_name(input1)
@@ -682,7 +682,7 @@ class CaffeOpMapper(OpMapper):
                 axis=axis)
         else:
             bias_shape = node.input_shape[0][axis:axis + num_axes]
-            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
+            input0 = self.graph.get_input_node(node, idx=0, copy=True)
             input0_name = self.get_input_name(input0)
             self.paddle_graph.add_layer(
                 kernel="fluid.ParamAttr",
@@ -739,7 +739,7 @@ class CaffeOpMapper(OpMapper):
     def Reshape(self, node):
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         top_count = len(input.layer.top)
         is_inplace = False if top_count == 1 else True
         output_shape = node.output_shape[0]
@@ -759,7 +759,7 @@ class CaffeOpMapper(OpMapper):
         assert len(node.inputs) == 1 and len(
             node.outputs
         ) == 1, 'The count of ArgMax node\'s input and output is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         input_shape = node.input_shape[0]
         params = node.layer.argmax_param
         out_max_val = params.out_max_val if hasattr(params,
@@ -796,8 +796,8 @@ class CaffeOpMapper(OpMapper):
     def Crop(self, node):
         assert len(
             node.inputs) == 2, 'The count of Crop node\'s input is not 2.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        example = self.graph.get_bottom_node(node, idx=1, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
+        example = self.graph.get_input_node(node, idx=1, copy=True)
         params = node.layer.crop_param
         axis = params.axis
         input_shape = node.input_shape[0]
@@ -822,7 +822,7 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.
             inputs) == 1, 'The count of DetectionOutput node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         self.paddle_graph.add_layer(
             kernel="fluid.layers.reshape",
             inputs={"x": self.get_input_name(input)},
@@ -832,7 +832,7 @@ class CaffeOpMapper(OpMapper):
     def Power(self, node):
         assert len(
             node.inputs) == 1, 'The count of Permute node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         params = node.layer.power_param
         power = params.power
         scale = params.scale
@@ -857,7 +857,7 @@ class CaffeOpMapper(OpMapper):
     def Reduction(self, node):
         assert len(
             node.inputs) == 1, 'The count of Reduction node\'s input is not 1.'
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         params = node.layer.reduction_param
         operation = params.operation
         axis = params.axis
@@ -942,15 +942,15 @@ class CaffeOpMapper(OpMapper):
             self.weights[weights_name[i]] = data[i]
         inputs_list = []
         for i in range(len(node.inputs)):
-            input = self.graph.get_bottom_node(node, idx=i, copy=True)
+            input = self.graph.get_input_node(node, idx=i, copy=True)
             if i == 1 and op == 'DetectionOutput':
-                input = self.graph.get_bottom_node(node, idx=i, copy=True)
+                input = self.graph.get_input_node(node, idx=i, copy=True)
                 while input is not None \
                       and input.layer_type != 'Softmax' \
                       and input.layer_type != 'Sigmoid':
-                    input = self.graph.get_bottom_node(input, idx=0, copy=True)
+                    input = self.graph.get_input_node(input, idx=0, copy=True)
                 assert input is not None, 'This kind of DetectionOutput is not supported!'
-                input = self.graph.get_bottom_node(input, idx=0, copy=True)
+                input = self.graph.get_input_node(input, idx=0, copy=True)
             inputs_list.append(self.get_input_name(input))
         kwargs_tmp = copy.deepcopy(kwargs)
         for k, v in kwargs_tmp.items():
@@ -970,7 +970,7 @@ class CaffeOpMapper(OpMapper):
     def directly_map(self, node):
         assert node.layer_type in self.directly_map_ops
         op_info = self.directly_map_ops[node.layer_type]
-        input = self.graph.get_bottom_node(node, idx=0, copy=True)
+        input = self.graph.get_input_node(node, idx=0, copy=True)
         self.paddle_graph.add_layer(
             kernel=op_info,
             inputs={"x": self.get_input_name(input)},
...
@@ -359,7 +359,7 @@ class TFOpMapper(OpMapper):
             kernel_value = kernel.value
             kernel_weight_name = kernel.name.replace('/', '_')
         else:
-            kernel_value = self.decoder.infer_tensor(kernel)
+            kernel_value = self.decoder.infer_tensor(kernel, use_diff_inputs=False)
             if kernel.layer_type == 'Split':
                 kernel_weight_name = "{}_{}_kernel".format(node.name,
                                                            kernel.name)
@@ -781,15 +781,15 @@ class TFOpMapper(OpMapper):
         if strides.layer_type == "Const":
             strides = strides.value.tolist()
         else:
-            strides = self.decoder.infer_shape_tensor(strides)
+            strides = self.decoder.infer_tensor(strides)
         if begin.layer_type == "Const":
             begin = begin.value.tolist()
         else:
-            begin = self.decoder.infer_shape_tensor(begin)
+            begin = self.decoder.infer_tensor(begin)
         if end.layer_type == "Const":
             end = end.value.tolist()
         else:
-            end = self.decoder.infer_shape_tensor(end)
+            end = self.decoder.infer_tensor(end)
         assert len(set(strides)) == 1 and strides[
             0] == 1, "Only support strides be 1 in StridedSlice OP"
@@ -897,7 +897,7 @@ class TFOpMapper(OpMapper):
             #     outputs=[reshape_name],
             #     shape=shape)
             # inputs['offsets'] = reshape_name
-            begin = self.decoder.infer_tensor(begin).tolist()
+            begin = self.decoder.infer_tensor(begin, use_diff_inputs=False).tolist()
             attrs['offsets'] = begin
         if size.layer_type == "Const":
             size = size.value.tolist()
@@ -1066,15 +1066,15 @@ class TFOpMapper(OpMapper):
         if out_shape.layer_type == "Const":
             out_shape = out_shape.value.tolist()
         else:
-            out_shape = self.decoder.infer_shape_tensor(out_shape,
-                                                        node.out_shapes[0])
+            out_shape = self.decoder.infer_tensor(out_shape,
+                                                  out_shape=node.out_shapes[0])
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
-            in_shape = self.decoder.infer_tensor(input).shape
+            in_shape = self.decoder.infer_tensor(input, use_diff_inputs=False).shape
         k_size = kernel.out_shapes[0]
         if k_size.count(-1) > 2:
-            k_size = self.decoder.infer_tensor(kernel).shape
+            k_size = self.decoder.infer_tensor(input, use_diff_inputs=False).shape
         pad_mode = node.get_attr("padding").decode()
         strides = node.get_attr("strides")
...
+# -*- coding:UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...

+# -*- coding:UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...

+# -*- coding:UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...

+# -*- coding:UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...

+# -*- coding:UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...

+# -*- coding:UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
...