提交 9663d20f 编写于 作者: W wjj19950828

fixed topk

上级 0bd1269e
...@@ -2577,27 +2577,44 @@ class OpSet9(): ...@@ -2577,27 +2577,44 @@ class OpSet9():
def TopK(self, node): def TopK(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_k = self.graph.get_input_node(node, idx=1, copy=True) val_k = self.graph.get_input_node(node, idx=1, copy=True)
if val_k.dtype != "int32":
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": val_k.name},
outputs=[val_k.name],
dtype=string('int32'))
layer_attrs = dict() layer_attrs = dict()
layer_attrs["axis"] = node.get_attr('axis', -1) layer_attrs["axis"] = node.get_attr('axis', -1)
layer_attrs["largest"] = True if node.get_attr('largest', layer_attrs["largest"] = True if node.get_attr('largest',
1) == 1 else False 1) == 1 else False
layer_attrs["sorted"] = True if node.get_attr('sorted', layer_attrs["sorted"] = True if node.get_attr('sorted',
1) == 1 else False 1) == 1 else False
self.paddle_graph.add_layer( is_k_attr = False
"paddle.topk", k = _const_weight_or_none(val_k)
inputs={"x": val_x.name, if isinstance(k, (list, tuple, np.ndarray)):
"k": val_k.name}, k = k[0]
outputs=[ if k is not None:
"{}_p{}".format(node.layer_name, 0), is_k_attr = True
"{}_p{}".format(node.layer_name, 1) layer_attrs["k"] = k
], if is_k_attr:
**layer_attrs) self.paddle_graph.add_layer(
"paddle.topk",
inputs={"x": val_x.name},
outputs=[
"{}_p{}".format(node.layer_name, 0),
"{}_p{}".format(node.layer_name, 1)
],
**layer_attrs)
else:
if val_k.dtype != "int32":
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": val_k.name},
outputs=[val_k.name],
dtype=string('int32'))
self.paddle_graph.add_layer(
"paddle.topk",
inputs={"x": val_x.name,
"k": val_k.name},
outputs=[
"{}_p{}".format(node.layer_name, 0),
"{}_p{}".format(node.layer_name, 1)
],
**layer_attrs)
@print_mapping_info @print_mapping_info
def LRN(self, node): def LRN(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册