Commit f1f98f81 authored by SunAhong1993

fix the create_param

Parent: a2bac287
No related merge requests
@@ -225,20 +225,14 @@ class TFOpMapper(OpMapper):
         self.params[node.name] = node.value
 
         if 0 not in shape:
-            if dtype != "float32":
-                self.params[node.name] = node.value.astype("float32")
             self.paddle_graph.add_layer(
                 "self.create_parameter",
                 inputs={},
                 outputs=[node.name],
                 shape=shape,
-                attr=string(node.name))
-            if dtype != "float32":
-                self.paddle_graph.add_layer(
-                    kernel="paddle.cast",
-                    inputs={"x": node.name},
-                    outputs=[node.name],
-                    dtype=string(dtype))
+                attr=string(node.name),
+                dtype=string(dtype),
+                default_initializer="paddle.nn.initializer.Constant(value=0.0)")
 
     def Transpose(self, node):
         input = self.graph.get_node(node.layer.input[0])
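The hunk above folds the dtype handling into the create_parameter layer itself: instead of always creating a float32 parameter (and casting the stored value to float32) and then emitting a follow-up paddle.cast, the generated code now creates the parameter directly with its target dtype and a constant placeholder initializer. A minimal sketch of the resulting Paddle call, assuming Paddle 2.x; the shape and dtype values below are hypothetical and not taken from this commit:

    # Minimal sketch of the call the converter is expected to emit after this fix.
    import paddle

    shape = [64, 64, 3, 3]   # hypothetical weight shape
    dtype = "float64"        # hypothetical non-float32 dtype from the TF graph

    # The parameter is created with its target dtype, so no separate paddle.cast
    # layer is needed afterwards; the constant initializer is only a placeholder,
    # the real values are restored later from the exported params dict.
    weight = paddle.create_parameter(
        shape=shape,
        dtype=dtype,
        default_initializer=paddle.nn.initializer.Constant(value=0.0))

    print(weight.shape, weight.dtype)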
@@ -33,9 +33,9 @@ class GraphOptimizer(object):
             self.passes = ["static_bn_scale_fuse_pass"]
         elif source_frame == "tf":
             self.passes = [
-                "dygraph_conv2d_add_fuse_pass",
-                "dygraph_tf_batchnorm_fuse_pass",
-                "transpose_eliminate_pass"
+                # "dygraph_conv2d_add_fuse_pass",
+                # "dygraph_tf_batchnorm_fuse_pass",
+                # "transpose_eliminate_pass"
             ]
         else:
             # TODO
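The second hunk temporarily disables all three TF-specific fusion passes, leaving the tf pass list empty. For illustration only, with hypothetical helper names that are not x2paddle's actual GraphOptimizer internals, a name-based pass list like this is typically consumed by looking each name up in a registry and applying the passes in order, so an empty list makes optimization a no-op:

    # Illustrative sketch only: hypothetical registry and runner, not x2paddle code.
    from typing import Callable, Dict, List


    class Graph:
        """Stand-in for the converted program graph."""


    # Hypothetical registry mapping pass names to graph-transforming callables.
    PASS_REGISTRY: Dict[str, Callable[[Graph], Graph]] = {}


    def register_pass(name: str):
        def decorator(fn: Callable[[Graph], Graph]) -> Callable[[Graph], Graph]:
            PASS_REGISTRY[name] = fn
            return fn
        return decorator


    @register_pass("transpose_eliminate_pass")
    def transpose_eliminate_pass(graph: Graph) -> Graph:
        # A real pass would rewrite the graph; this stub just returns it unchanged.
        return graph


    def run_passes(graph: Graph, passes: List[str]) -> Graph:
        for name in passes:
            graph = PASS_REGISTRY[name](graph)  # unknown names raise KeyError
        return graph


    # With the tf pass list emptied by this commit, optimization does nothing:
    optimized = run_passes(Graph(), [])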