Commit 75810b21 authored by: D daiwk

support bert finetune

Parent 06e5521a
......@@ -52,6 +52,24 @@ def export_paddle_param(param, param_name, dir):
 def run_net(param_dir="./"):
     import os
     inputs, outputs = x2paddle_net()
+
+    ops = fluid.default_main_program().global_block().ops
+    used_vars = list()
+    for op in ops:
+        used_vars += op.input_arg_names
+
+    tmp = list()
+    for input in inputs:
+        if isinstance(input, list):
+            for ipt in input:
+                if ipt.name not in used_vars:
+                    continue
+                tmp.append(ipt)
+        else:
+            if input.name not in used_vars:
+                continue
+            tmp.append(input)
+    inputs = tmp
     for i, out in enumerate(outputs):
         if isinstance(out, list):
             for out_part in out:
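Note: the block added above is a dead-input pruning pass. After x2paddle_net() builds the program, any declared input that no op consumes is dropped before feeding. A minimal standalone sketch of the same idea (function name hypothetical, not part of this commit):

    def prune_unused_inputs(inputs, ops):
        """Keep only the inputs that at least one op actually reads."""
        used_vars = set()
        for op in ops:
            used_vars.update(op.input_arg_names)
        kept = []
        for inp in inputs:
            # An entry may itself be a list of variables; flatten one level.
            candidates = inp if isinstance(inp, list) else [inp]
            kept.extend(v for v in candidates if v.name in used_vars)
        return kept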
......@@ -121,12 +139,30 @@ class OpMapper(object):
         import model
         try:
             inputs, outputs = model.x2paddle_net()
+
+            ops = fluid.default_main_program().global_block().ops
+            used_vars = list()
+            for op in ops:
+                used_vars += op.input_arg_names
+
             for i, out in enumerate(outputs):
                 if isinstance(out, list):
                     for out_part in out:
                         outputs.append(out_part)
                     del outputs[i]
-            input_names = [input.name for input in inputs]
+
+            input_names = list()
+            for input in inputs:
+                if isinstance(input, list):
+                    for ipt in input:
+                        if ipt.name not in used_vars:
+                            continue
+                        input_names.append(ipt.name)
+                else:
+                    if input.name not in used_vars:
+                        continue
+                    input_names.append(input.name)
+
             exe = fluid.Executor(fluid.CPUPlace())
             exe.run(fluid.default_startup_program())
......
......@@ -48,7 +48,7 @@ class TFGraphNode(GraphNode):
@property
def out_shapes(self):
if self.layer_type == "OneShotIterator":
if self.layer_type == "OneShotIterator" or self.layer_type == "IteratorV2":
values = self.layer.attr["output_shapes"].list.shape
else:
values = self.layer.attr["_output_shapes"].list.shape
......@@ -115,7 +115,7 @@ class TFGraph(Graph):
     def __init__(self, model, data_format="NHWC"):
         super(TFGraph, self).__init__(model)
         self.identity_map = dict()
-        self.multi_out_ops = ['Split', 'SplitV']
+        self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
         self.tf_data_format = data_format

     def build(self):
......@@ -335,7 +335,7 @@ class TFDecoder(object):
         graph_def = cp.deepcopy(graph_def)
         input_map = dict()
         for layer in graph_def.node:
-            if layer.op != "Placeholder" and layer.op != "OneShotIterator":
+            if layer.op != "Placeholder" and layer.op != "OneShotIterator" and layer.op != "IteratorV2":
                 continue
             graph_node = TFGraphNode(layer)
             dtype = graph_node.layer.attr['dtype'].type
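Note: with the extra IteratorV2 case, the decoder treats dataset iterators as graph inputs on the same footing as Placeholders. A sketch of that input scan over a plain graph_def (names hypothetical):

    INPUT_LIKE_OPS = ("Placeholder", "OneShotIterator", "IteratorV2")

    def find_input_nodes(graph_def):
        # Nodes that feed the graph from outside once the TF input
        # pipeline is stripped away.
        return [node for node in graph_def.node if node.op in INPUT_LIKE_OPS]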
......
......@@ -70,7 +70,7 @@ class TFOpMapperNHWC(OpMapper):
         not_placeholder = list()
         for name in self.graph.input_nodes:
-            if self.graph.get_node(name).layer_type != "Placeholder":
+            if self.graph.get_node(name).layer_type != "Placeholder" and self.graph.get_node(name).layer_type != "OneShotIterator" and self.graph.get_node(name).layer_type != "IteratorV2":
                 not_placeholder.append(name)
         for name in not_placeholder:
             idx = self.graph.input_nodes.index(name)
......@@ -79,7 +79,7 @@ class TFOpMapperNHWC(OpMapper):
         unsupported_ops = set()
         sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
         for i, node_name in enumerate(self.graph.topo_sort):
-            sys.stderr.write("\rConverting node {} ... ".format(i + 1))
+            sys.stderr.write("\rConverting node {} name: {:50}... ".format(i + 1, node_name))
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if op in self.directly_map_ops:
......@@ -94,11 +94,12 @@ class TFOpMapperNHWC(OpMapper):
if len(unsupported_ops) > 0:
continue
func = getattr(self, op)
try:
func(node)
except Exception as e:
print(str(e))
unsupported_ops.add(op)
# try:
# func(node)
# except Exception as e:
# print(str(e))
# unsupported_ops.add(op)
else:
unsupported_ops.add(op)
if len(unsupported_ops) > 0:
......@@ -710,6 +711,10 @@ class TFOpMapperNHWC(OpMapper):
def BatchMatMul(self, node):
return self.MatMul(node)
def BatchMatMulV2(self, node):
return self.MatMul(node)
def ArgMax(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
axis = self.graph.get_node(node.layer.input[1], copy=True)
......@@ -810,13 +815,14 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         begin = self.graph.get_node(node.layer.input[1], copy=True)
         size = self.graph.get_node(node.layer.input[2], copy=True)
+        self.add_omit_nodes(begin.layer_name, node.layer_name)
+        self.add_omit_nodes(size.layer_name, node.layer_name)
         if begin.layer_type == "Const":
             begin = begin.value.tolist()
         else:
             begin = self.decoder.infer_tensor(begin).tolist()
-        if size.layer_type == "const":
+        if size.layer_type == "Const":
             size = size.value.tolist()
         else:
             size = self.decoder.infer_tensor(size).tolist()
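Note: the lowercase "const" on the removed line never matched a real layer type, so size always fell through to runtime inference even for compile-time constants; the fix makes both operands follow the same Const-vs-inferred pattern. Extracted as a sketch (helper name hypothetical):

    def resolve_to_list(tensor_node, decoder):
        # Compile-time constants can be read directly; anything else has
        # to be evaluated through the decoder.
        if tensor_node.layer_type == "Const":
            return tensor_node.value.tolist()
        return decoder.infer_tensor(tensor_node).tolist()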
......@@ -1173,3 +1179,52 @@ class TFOpMapperNHWC(OpMapper):
                                   inputs=inputs,
                                   output=node,
                                   param_attr=attr)
+
+    def OneShotIterator(self, node):
+        return self.Placeholder(node)
+
+    def IteratorV2(self, node):
+        dtype_map = {
+            1: "float32",
+            3: "int32",
+            4: "uint8",
+            9: "int64",
+            10: "bool"
+        }
+        shapes = node.out_shapes
+        dtypes = node.layer.attr['output_types'].list.type
+        node.fluid_code.add_note("{} = [0] * {}".format(node.layer_name,
+                                                        len(shapes)))
+        for i, shape in enumerate(shapes):
+            attr = {
+                'dtype': string(dtype_map[dtypes[i]]),
+                'shape': shape,
+                'name': string("{}_{}".format(node.layer_name, i)),
+                'append_batch_size': False
+            }
+            output = "{}[{}]".format(node.layer_name, i)
+            node.fluid_code.add_layer("data",
+                                      inputs=None,
+                                      output=output,
+                                      param_attr=attr)
+
+    def Fill(self, node):
+        dims = self.graph.get_node(node.layer.input[0], copy=True)
+        value = self.graph.get_node(node.layer.input[1], copy=True)
+        assert dims.layer_type == 'Const', "Only support Const parameter in Fill OP"
+        assert value.layer_type == 'Const', "Only support Const parameter in Fill OP"
+        self.add_omit_nodes(dims.layer_name, node.layer_name)
+        self.add_omit_nodes(value.layer_name, node.layer_name)
+        dims = dims.value.tolist()
+        value = value.value
+        initializer = "Constant({})".format(value)
+        attr = {
+            'dtype': string(node.dtype),
+            'shape': dims,
+            'name': string(node.layer_name),
+            'default_initializer': initializer
+        }
+        node.fluid_code.add_layer("create_parameter",
+                                  inputs=None,
+                                  output=node,
+                                  param_attr=attr)
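Note: the IteratorV2 mapper turns each iterator output into one fluid.layers.data feed, collected in a list named after the node. A sketch of the kind of code it would generate for a hypothetical two-output int64 iterator of shape [-1, 128] (illustrative values, not taken from this commit; the emitted argument order may differ):

    import paddle.fluid as fluid

    iterator = [0] * 2
    iterator[0] = fluid.layers.data(dtype='int64', shape=[-1, 128],
                                    name='iterator_0', append_batch_size=False)
    iterator[1] = fluid.layers.data(dtype='int64', shape=[-1, 128],
                                    name='iterator_1', append_batch_size=False)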