提交 97df7f87 编写于 作者: J jiangjiajun

fix for keras bert

上级 13d5bda8
...@@ -60,7 +60,7 @@ class TFGraphNode(GraphNode): ...@@ -60,7 +60,7 @@ class TFGraphNode(GraphNode):
@property @property
def dtype(self): def dtype(self):
keys = ['dtype', 'Tidx', 'T', 'DstT'] keys = ['dtype', 'T', 'DstT']
for k in keys: for k in keys:
dtype = self.layer.attr[k].type dtype = self.layer.attr[k].type
if dtype > 0: if dtype > 0:
......
...@@ -744,7 +744,7 @@ class TFOpMapper(OpMapper): ...@@ -744,7 +744,7 @@ class TFOpMapper(OpMapper):
input_names = [i.name for i in inputs] input_names = [i.name for i in inputs]
for i, ipt in enumerate(inputs): for i, ipt in enumerate(inputs):
if node.dtype == 'bool': if ipt.dtype == 'bool':
cast_name = gen_name('concat', 'cast') cast_name = gen_name('concat', 'cast')
program.add_layer( program.add_layer(
kernel="fluid.layers.cast", kernel="fluid.layers.cast",
...@@ -1213,9 +1213,17 @@ class TFOpMapper(OpMapper): ...@@ -1213,9 +1213,17 @@ class TFOpMapper(OpMapper):
attr["dim"] = reduce_idx.value.tolist() attr["dim"] = reduce_idx.value.tolist()
attr["keep_dim"] = node.get_attr("keep_dims") attr["keep_dim"] = node.get_attr("keep_dims")
input_name = input.name
if input.dtype != "bool":
input_name = gen_name("all", "cast")
program.add_layer(
"fluid.layers.cast",
inputs={"x": input.name},
outputs=[input_name],
dtype=string("bool"))
program.add_layer( program.add_layer(
"fluid.layers.reduce_all", "fluid.layers.reduce_all",
inputs={"input": input.name}, inputs={"input": input_name},
outputs=[node.name], outputs=[node.name],
**attr) **attr)
......
...@@ -8,6 +8,7 @@ class BatchNormOpt: ...@@ -8,6 +8,7 @@ class BatchNormOpt:
pass pass
def run(self, graph): def run(self, graph):
print("Optimize: BatchNormOpt...")
layers = copy.deepcopy(graph.layers) layers = copy.deepcopy(graph.layers)
for layer_id, layer in layers.items(): for layer_id, layer in layers.items():
if layer.kernel != "fluid.layers.elementwise_add": if layer.kernel != "fluid.layers.elementwise_add":
......
...@@ -13,6 +13,7 @@ class BiasOpt: ...@@ -13,6 +13,7 @@ class BiasOpt:
] ]
def run(self, graph): def run(self, graph):
print("Optimize: BiasOpt...")
layers = copy.deepcopy(graph.layers) layers = copy.deepcopy(graph.layers)
for layer_id, layer in layers.items(): for layer_id, layer in layers.items():
if layer.kernel in self.conv_layers or layer.kernel == "fluid.layers.transpose": if layer.kernel in self.conv_layers or layer.kernel == "fluid.layers.transpose":
......
...@@ -36,6 +36,7 @@ class TransposeOpt: ...@@ -36,6 +36,7 @@ class TransposeOpt:
return count return count
def run(self, graph): def run(self, graph):
print("Optimize: TransposeOpt...")
total_layer_num = len(graph.layers) total_layer_num = len(graph.layers)
scanned_layers = set() scanned_layers = set()
optimized_transpose_layers = list() optimized_transpose_layers = list()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册