提交 4e29cc70 编写于 作者: W wjj19950828

rm useless code

上级 10969ba5
...@@ -1150,9 +1150,7 @@ class OpSet9(): ...@@ -1150,9 +1150,7 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
starts, ends, axes, steps = None, None, None, None starts, ends, axes, steps = None, None, None, None
layer_attrs = {} layer_attrs = {}
if val_x.dtype not in [ if val_x.dtype == 'uint8':
"float16", "float32", "float64", "int32", "int64"
]:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.cast', 'paddle.cast',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
...@@ -1878,6 +1876,7 @@ class OpSet9(): ...@@ -1878,6 +1876,7 @@ class OpSet9():
if axes is None: if axes is None:
axes_node = self.graph.get_input_node(node, idx=1, copy=True) axes_node = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes_node, necessary=True) axes = _const_weight_or_none(axes_node, necessary=True)
# deal with scalar(0D) tensor
if len(val_x.out_shapes[0]) <= 1 and len(axes) == 1 and axes[0] == 0: if len(val_x.out_shapes[0]) <= 1 and len(axes) == 1 and axes[0] == 0:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.cast", "paddle.cast",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册