Commit 97df7f87 authored by jiangjiajun

fix for keras bert

Parent 13d5bda8
@@ -60,7 +60,7 @@ class TFGraphNode(GraphNode):
     @property
     def dtype(self):
-        keys = ['dtype', 'Tidx', 'T', 'DstT']
+        keys = ['dtype', 'T', 'DstT']
         for k in keys:
             dtype = self.layer.attr[k].type
             if dtype > 0:
......
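Note on the hunk above: in a TensorFlow NodeDef, the `Tidx` attribute describes the dtype of an op's axis/index input (normally int32), not the dtype of the tensor the node produces, so keeping it ahead of `T` in the lookup list could misreport a float node as int32. Below is a minimal standalone sketch of the corrected lookup, using a plain dict in place of `self.layer.attr`; the names are illustrative, not X2Paddle's real API.

# Standalone sketch of the corrected dtype lookup (illustrative only; the real
# code reads self.layer.attr[k].type from a TensorFlow NodeDef).
def resolve_dtype(attrs):
    # 'Tidx' is deliberately absent: it names the axis/index dtype (usually
    # int32), not the dtype of the tensor this node produces.
    keys = ['dtype', 'T', 'DstT']
    for k in keys:
        dtype = attrs.get(k, 0)  # 0 stands in for "attribute not set"
        if dtype > 0:
            return dtype
    raise ValueError("dtype attribute not found")

# A ConcatV2-like node: T is the tensor dtype, Tidx the axis dtype.
# With 'Tidx' in the key list, the old lookup order returned the axis dtype
# whenever the 'dtype' attribute was unset.
node_attrs = {'T': 1, 'Tidx': 3}  # 1 ~ DT_FLOAT, 3 ~ DT_INT32 in TF's enum
print(resolve_dtype(node_attrs))  # -> 1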
@@ -744,7 +744,7 @@ class TFOpMapper(OpMapper):
         input_names = [i.name for i in inputs]
         for i, ipt in enumerate(inputs):
-            if node.dtype == 'bool':
+            if ipt.dtype == 'bool':
                 cast_name = gen_name('concat', 'cast')
                 program.add_layer(
                     kernel="fluid.layers.cast",
@@ -1213,9 +1213,17 @@ class TFOpMapper(OpMapper):
         attr["dim"] = reduce_idx.value.tolist()
         attr["keep_dim"] = node.get_attr("keep_dims")
+        input_name = input.name
+        if input.dtype != "bool":
+            input_name = gen_name("all", "cast")
+            program.add_layer(
+                "fluid.layers.cast",
+                inputs={"x": input.name},
+                outputs=[input_name],
+                dtype=string("bool"))
         program.add_layer(
             "fluid.layers.reduce_all",
-            inputs={"input": input.name},
+            inputs={"input": input_name},
             outputs=[node.name],
             **attr)
......
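Note on the hunk above: `fluid.layers.reduce_all` only accepts a bool tensor, so the `All` mapper now emits a cast layer first whenever the input dtype is not bool and feeds the cast's output name into `reduce_all`. A self-contained sketch of that emission pattern, with the program represented as a plain list (names here are illustrative, not X2Paddle's real API):

# Sketch of the cast-then-reduce_all emission from the All hunk; the program
# is mimicked by a plain list of (kernel, inputs, outputs, attrs) tuples.
def map_all_op(input_name, input_dtype, output_name, dim, keep_dim):
    program = []
    reduce_input = input_name
    if input_dtype != "bool":
        # reduce_all only accepts bool tensors, so cast first and feed the
        # cast output into the reduction.
        reduce_input = input_name + "_cast"  # stand-in for gen_name("all", "cast")
        program.append(("fluid.layers.cast",
                        {"x": input_name}, [reduce_input], {"dtype": "bool"}))
    program.append(("fluid.layers.reduce_all",
                    {"input": reduce_input}, [output_name],
                    {"dim": dim, "keep_dim": keep_dim}))
    return program

for layer in map_all_op("mask", "float32", "all_out", dim=[1], keep_dim=False):
    print(layer)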
@@ -8,6 +8,7 @@ class BatchNormOpt:
         pass

     def run(self, graph):
+        print("Optimize: BatchNormOpt...")
         layers = copy.deepcopy(graph.layers)
         for layer_id, layer in layers.items():
             if layer.kernel != "fluid.layers.elementwise_add":
......
@@ -13,6 +13,7 @@ class BiasOpt:
         ]

     def run(self, graph):
+        print("Optimize: BiasOpt...")
         layers = copy.deepcopy(graph.layers)
         for layer_id, layer in layers.items():
             if layer.kernel in self.conv_layers or layer.kernel == "fluid.layers.transpose":
......
@@ -36,6 +36,7 @@ class TransposeOpt:
         return count

     def run(self, graph):
+        print("Optimize: TransposeOpt...")
         total_layer_num = len(graph.layers)
         scanned_layers = set()
         optimized_transpose_layers = list()
......
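Note on the three optimizer hunks: each pass only gains a progress print, but the surrounding context shows the shared idiom of iterating over a deep copy of `graph.layers` so the pass can delete or rewrite layers in the live graph while scanning. A hedged sketch of that iterate-over-a-copy pattern, with dict-based stand-ins for the real graph objects:

import copy

# Illustrative optimizer pass (not X2Paddle's real classes): iterate a deep
# copy of the layer dict so layers can be removed from the live graph without
# invalidating the loop.
class ExamplePass:
    def run(self, graph):
        print("Optimize: ExamplePass...")
        layers = copy.deepcopy(graph["layers"])
        for layer_id, layer in layers.items():
            if layer["kernel"] == "noop":
                del graph["layers"][layer_id]  # safe: we loop over the copy

graph = {"layers": {"0": {"kernel": "noop"}, "1": {"kernel": "fluid.layers.relu"}}}
ExamplePass().run(graph)
print(list(graph["layers"]))  # -> ['1']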