“abc167338a99f1b644e1a5d4bb324a566d2fd87f”上不存在“develop/api_doc/fluid/executor.html”
未验证 提交 f9da3dcb 编写于 作者: W WJJ1995 提交者: GitHub

Add sign op and fixed some bugs in TF search model (#633)

* add TF sign op

* fix TF squeeze bug

* fix TF expandDims bug

* support squeeze_dims is empty

* delete TF stridedslice op input type is bool

* support TF infer_tensor func input type is int

* fixed stridedslice bug
上级 a20d95a7
...@@ -58,7 +58,7 @@ class TFGraphNode(GraphNode): ...@@ -58,7 +58,7 @@ class TFGraphNode(GraphNode):
@property @property
def dtype(self): def dtype(self):
keys = ['dtype', 'T', 'DstT'] keys = ['dtype', 'T', 'DstT', 'Tparams']
for k in keys: for k in keys:
dtype = self.layer.attr[k].type dtype = self.layer.attr[k].type
if dtype > 0: if dtype > 0:
...@@ -109,6 +109,8 @@ class TFGraphNode(GraphNode): ...@@ -109,6 +109,8 @@ class TFGraphNode(GraphNode):
attr = self.layer.attr[name] attr = self.layer.attr[name]
field = attr.WhichOneof('value') field = attr.WhichOneof('value')
value = getattr(attr, field) if field else None value = getattr(attr, field) if field else None
if name == "squeeze_dims" and not value.ListFields():
return None
if isinstance(value, attr_value_pb2.AttrValue.ListValue): if isinstance(value, attr_value_pb2.AttrValue.ListValue):
result = list(value.ListFields()[0][1]) result = list(value.ListFields()[0][1])
...@@ -466,6 +468,9 @@ class TFDecoder(object): ...@@ -466,6 +468,9 @@ class TFDecoder(object):
":0") ":0")
if shape.count(-1) > 0: if shape.count(-1) > 0:
shape[shape.index(-1)] = b shape[shape.index(-1)] = b
if dtype == 3:
feed[input_tensor] = numpy.random.randint(1, 10, size=shape)
else:
feed[input_tensor] = numpy.random.random_sample(shape) feed[input_tensor] = numpy.random.random_sample(shape)
output_tensor = self.sess.graph.get_tensor_by_name(tensor_name) output_tensor = self.sess.graph.get_tensor_by_name(tensor_name)
if use_diff_inputs: if use_diff_inputs:
......
...@@ -680,11 +680,14 @@ class TFOpMapper(): ...@@ -680,11 +680,14 @@ class TFOpMapper():
def Squeeze(self, node): def Squeeze(self, node):
input = self.graph.get_input_node(node, 0) input = self.graph.get_input_node(node, 0)
squeeze_dims = node.get_attr('squeeze_dims') squeeze_dims = node.get_attr('squeeze_dims')
axis = node.get_attr('axis')
if squeeze_dims != None and axis == None:
axis = squeeze_dims
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.squeeze", kernel="paddle.squeeze",
inputs={"x": input.name}, inputs={"x": input.name},
outputs=[node.name], outputs=[node.name],
axis=squeeze_dims) axis=axis)
def Shape(self, node): def Shape(self, node):
input = self.graph.get_input_node(node, 0) input = self.graph.get_input_node(node, 0)
...@@ -1519,6 +1522,16 @@ class TFOpMapper(): ...@@ -1519,6 +1522,16 @@ class TFOpMapper():
attr['axis'] = dim attr['axis'] = dim
else: else:
inputs['axis'] = y.name inputs['axis'] = y.name
if len(x.out_shapes[0]) == 0:
value = self.decoder.infer_tensor(x, use_diff_inputs=False).tolist()
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
outputs=[node.name],
dtype=string(x.dtype),
shape=[1],
fill_value=value)
else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.unsqueeze", inputs=inputs, outputs=[node.name], **attr) "paddle.unsqueeze", inputs=inputs, outputs=[node.name], **attr)
...@@ -1646,3 +1659,26 @@ class TFOpMapper(): ...@@ -1646,3 +1659,26 @@ class TFOpMapper():
inputs={"x": transpose_name}, inputs={"x": transpose_name},
outputs=[node.name], outputs=[node.name],
shape=shape) shape=shape)
def Sign(self, node):
x = self.graph.get_input_node(node, 0)
support_list = ["float16", "float32", "float64"]
if x.dtype not in support_list:
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": x.name},
outputs=[node.name],
dtype=string("float32"))
self.paddle_graph.add_layer(
kernel="paddle.sign",
inputs={"x": node.name},
outputs=[node.name])
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string(x.dtype))
else:
self.paddle_graph.add_layer(
kernel="paddle.sign", inputs={"x": x.name},
outputs=[node.name])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册