提交 f1f98f81 编写于 作者: S SunAhong1993

fix the create_param

上级 a2bac287
...@@ -225,20 +225,14 @@ class TFOpMapper(OpMapper): ...@@ -225,20 +225,14 @@ class TFOpMapper(OpMapper):
self.params[node.name] = node.value self.params[node.name] = node.value
if 0 not in shape: if 0 not in shape:
if dtype != "float32":
self.params[node.name] = node.value.astype("float32")
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[node.name], outputs=[node.name],
shape=shape, shape=shape,
attr=string(node.name)) attr=string(node.name),
if dtype != "float32": dtype=string(dtype),
self.paddle_graph.add_layer( default_initializer="paddle.nn.initializer.Constant(value=0.0)")
kernel="paddle.cast",
inputs={"x": node.name},
outputs=[node.name],
dtype=string(dtype))
def Transpose(self, node): def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
......
...@@ -33,9 +33,9 @@ class GraphOptimizer(object): ...@@ -33,9 +33,9 @@ class GraphOptimizer(object):
self.passes = ["static_bn_scale_fuse_pass"] self.passes = ["static_bn_scale_fuse_pass"]
elif source_frame == "tf": elif source_frame == "tf":
self.passes = [ self.passes = [
"dygraph_conv2d_add_fuse_pass", # "dygraph_conv2d_add_fuse_pass",
"dygraph_tf_batchnorm_fuse_pass", # "dygraph_tf_batchnorm_fuse_pass",
"transpose_eliminate_pass" # "transpose_eliminate_pass"
] ]
else: else:
# TODO # TODO
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册