Commit 75810b21 authored by daiwk

Support BERT finetune

Parent 06e5521a
@@ -52,6 +52,24 @@ def export_paddle_param(param, param_name, dir):
 def run_net(param_dir="./"):
     import os
     inputs, outputs = x2paddle_net()
+    ops = fluid.default_main_program().global_block().ops
+    used_vars = list()
+    for op in ops:
+        used_vars += op.input_arg_names
+    tmp = list()
+    for input in inputs:
+        if isinstance(input, list):
+            for ipt in input:
+                if ipt.name not in used_vars:
+                    continue
+                tmp.append(ipt)
+        else:
+            if input.name not in used_vars:
+                continue
+            tmp.append(input)
+    inputs = tmp
     for i, out in enumerate(outputs):
         if isinstance(out, list):
             for out_part in out:
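Note on the block above: the added filtering walks every op recorded in the program's global block, collects the variable names those ops actually read, and keeps only the declared inputs that something consumes, presumably so that feed variables a finetuned BERT graph no longer uses are dropped. A minimal standalone sketch, assuming the Paddle 1.x fluid API this repository targets (the helper name `prune_unused_inputs` is illustrative):

    import paddle.fluid as fluid

    def prune_unused_inputs(inputs, program=None):
        # Hypothetical helper mirroring the filtering added to run_net();
        # `inputs` may mix Variables and lists of Variables, exactly as
        # x2paddle_net() returns them.
        program = program or fluid.default_main_program()
        used_vars = set()
        for op in program.global_block().ops:
            used_vars.update(op.input_arg_names)
        kept = []
        for inp in inputs:
            for var in (inp if isinstance(inp, list) else [inp]):
                if var.name in used_vars:
                    kept.append(var)
        return kept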
@@ -121,12 +139,30 @@ class OpMapper(object):
         import model
         try:
             inputs, outputs = model.x2paddle_net()
+            ops = fluid.default_main_program().global_block().ops
+            used_vars = list()
+            for op in ops:
+                used_vars += op.input_arg_names
             for i, out in enumerate(outputs):
                 if isinstance(out, list):
                     for out_part in out:
                         outputs.append(out_part)
                     del outputs[i]
-            input_names = [input.name for input in inputs]
+            input_names = list()
+            for input in inputs:
+                if isinstance(input, list):
+                    for ipt in input:
+                        if ipt.name not in used_vars:
+                            continue
+                        input_names.append(ipt.name)
+                else:
+                    if input.name not in used_vars:
+                        continue
+                    input_names.append(input.name)
             exe = fluid.Executor(fluid.CPUPlace())
             exe.run(fluid.default_startup_program())
...
@@ -48,7 +48,7 @@ class TFGraphNode(GraphNode):
     @property
     def out_shapes(self):
-        if self.layer_type == "OneShotIterator":
+        if self.layer_type == "OneShotIterator" or self.layer_type == "IteratorV2":
             values = self.layer.attr["output_shapes"].list.shape
         else:
             values = self.layer.attr["_output_shapes"].list.shape
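The one-line change above matters because iterator ops publish their shapes in the explicit `output_shapes` attribute, while ordinary ops carry TensorFlow's inferred shapes under `_output_shapes`. A minimal sketch of reading either attribute straight from a raw `NodeDef` (helper name illustrative; unknown dimensions come back as -1):

    def node_output_shapes(node_def):
        # Iterator ops store shapes in "output_shapes"; other ops have
        # inferred shapes recorded under "_output_shapes".
        key = ("output_shapes"
               if node_def.op in ("OneShotIterator", "IteratorV2")
               else "_output_shapes")
        return [[dim.size for dim in shape.dim]
                for shape in node_def.attr[key].list.shape]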
@@ -115,7 +115,7 @@ class TFGraph(Graph):
     def __init__(self, model, data_format="NHWC"):
         super(TFGraph, self).__init__(model)
         self.identity_map = dict()
-        self.multi_out_ops = ['Split', 'SplitV']
+        self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
         self.tf_data_format = data_format

     def build(self):
@@ -335,7 +335,7 @@ class TFDecoder(object):
         graph_def = cp.deepcopy(graph_def)
         input_map = dict()
         for layer in graph_def.node:
-            if layer.op != "Placeholder" and layer.op != "OneShotIterator":
+            if layer.op != "Placeholder" and layer.op != "OneShotIterator" and layer.op != "IteratorV2":
                 continue
             graph_node = TFGraphNode(layer)
             dtype = graph_node.layer.attr['dtype'].type
...
@@ -70,7 +70,7 @@ class TFOpMapperNHWC(OpMapper):
         not_placeholder = list()
         for name in self.graph.input_nodes:
-            if self.graph.get_node(name).layer_type != "Placeholder":
+            if self.graph.get_node(name).layer_type != "Placeholder" and self.graph.get_node(name).layer_type != "OneShotIterator" and self.graph.get_node(name).layer_type != "IteratorV2":
                 not_placeholder.append(name)
         for name in not_placeholder:
             idx = self.graph.input_nodes.index(name)
@@ -79,7 +79,7 @@ class TFOpMapperNHWC(OpMapper):
         unsupported_ops = set()
         sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
         for i, node_name in enumerate(self.graph.topo_sort):
-            sys.stderr.write("\rConverting node {} ... ".format(i + 1))
+            sys.stderr.write("\rConverting node {} name: {:50}... ".format(i + 1, node_name))
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if op in self.directly_map_ops:
@@ -94,11 +94,12 @@ class TFOpMapperNHWC(OpMapper):
                 if len(unsupported_ops) > 0:
                     continue
                 func = getattr(self, op)
-                try:
-                    func(node)
-                except Exception as e:
-                    print(str(e))
-                    unsupported_ops.add(op)
+                func(node)
+                # try:
+                #     func(node)
+                # except Exception as e:
+                #     print(str(e))
+                #     unsupported_ops.add(op)
             else:
                 unsupported_ops.add(op)
         if len(unsupported_ops) > 0:
@@ -710,6 +711,10 @@ class TFOpMapperNHWC(OpMapper):
     def BatchMatMul(self, node):
         return self.MatMul(node)

+    def BatchMatMulV2(self, node):
+        return self.MatMul(node)
+
     def ArgMax(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         axis = self.graph.get_node(node.layer.input[1], copy=True)
@@ -810,13 +815,14 @@ class TFOpMapperNHWC(OpMapper):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         begin = self.graph.get_node(node.layer.input[1], copy=True)
         size = self.graph.get_node(node.layer.input[2], copy=True)
         self.add_omit_nodes(begin.layer_name, node.layer_name)
         self.add_omit_nodes(size.layer_name, node.layer_name)
         if begin.layer_type == "Const":
             begin = begin.value.tolist()
         else:
             begin = self.decoder.infer_tensor(begin).tolist()
-        if size.layer_type == "const":
+        if size.layer_type == "Const":
             size = size.value.tolist()
         else:
             size = self.decoder.infer_tensor(size).tolist()
@@ -1173,3 +1179,52 @@ class TFOpMapperNHWC(OpMapper):
                                   inputs=inputs,
                                   output=node,
                                   param_attr=attr)
+
+    def OneShotIterator(self, node):
+        return self.Placeholder(node)
+
+    def IteratorV2(self, node):
+        dtype_map = {
+            1: "float32",
+            3: "int32",
+            4: "uint8",
+            9: "int64",
+            10: "bool"
+        }
+        shapes = node.out_shapes
+        dtypes = node.layer.attr['output_types'].list.type
+        node.fluid_code.add_note("{} = [0] * {}".format(node.layer_name, len(shapes)))
+        for i, shape in enumerate(shapes):
+            attr = {
+                'dtype': string(dtype_map[dtypes[i]]),
+                'shape': shape,
+                'name': string("{}_{}".format(node.layer_name, i)),
+                'append_batch_size': False
+            }
+            output = "{}[{}]".format(node.layer_name, i)
+            node.fluid_code.add_layer("data",
+                                      inputs=None,
+                                      output=output,
+                                      param_attr=attr)
+
+    def Fill(self, node):
+        dims = self.graph.get_node(node.layer.input[0], copy=True)
+        value = self.graph.get_node(node.layer.input[1], copy=True)
+        assert dims.layer_type == 'Const', "Only support Const parameter in Fill OP"
+        assert value.layer_type == 'Const', "Only support Const parameter in Fill OP"
+        self.add_omit_nodes(dims.layer_name, node.layer_name)
+        self.add_omit_nodes(value.layer_name, node.layer_name)
+        dims = dims.value.tolist()
+        value = value.value
+        initializer = "Constant({})".format(value)
+        attr = {
+            'dtype': string(node.dtype),
+            'shape': dims,
+            'name': string(node.layer_name),
+            'default_initializer': initializer
+        }
+        node.fluid_code.add_layer("create_parameter",
+                                  inputs=None,
+                                  output=node,
+                                  param_attr=attr)
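For reference, `IteratorV2` above first emits a note line that pre-allocates a plain list, then declares one `fluid.layers.data` per iterator output so downstream ops can address them as `name[i]`; the dtype_map keys are TensorFlow's DataType enum values (DT_FLOAT=1, DT_INT32=3, and so on). For a hypothetical iterator node named `iterator` with two int64 outputs of shape [-1, 128], the generated code would look roughly like:

    import paddle.fluid as fluid

    # Illustrative only: names, shapes, and dtypes depend on the TF graph.
    iterator = [0] * 2
    iterator[0] = fluid.layers.data(dtype='int64', shape=[-1, 128],
                                    name='iterator_0', append_batch_size=False)
    iterator[1] = fluid.layers.data(dtype='int64', shape=[-1, 128],
                                    name='iterator_1', append_batch_size=False)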
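Similarly, `Fill` materializes `tf.fill(dims, value)` as a parameter with a `Constant` default initializer, which is why both inputs must be `Const`: they are folded in at conversion time rather than fed at runtime. A rough sketch of the emitted fluid code for a hypothetical `fill` node producing a 2x3 float32 tensor of 1.5:

    import paddle.fluid as fluid
    from paddle.fluid.initializer import Constant

    # Illustrative only: shape, dtype, and value are taken from the
    # node's Const inputs during conversion.
    fill = fluid.layers.create_parameter(dtype='float32', shape=[2, 3],
                                         name='fill',
                                         default_initializer=Constant(1.5))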