提交 06e5521a 编写于 作者: J jiangjiajun

add support for bert

上级 8ec74138
......@@ -25,6 +25,7 @@ class Layer(object):
self.inputs = dict()
self.output = None
self.is_custom_layer = False
self.use_fluid = False
def get_code(self):
layer_code = ""
......@@ -38,6 +39,8 @@ class Layer(object):
layer_code = layer_code + self.op + "("
elif self.op == "=":
layer_code = layer_code
elif self.use_fluid:
layer_code = layer_code + "fluid." + self.op + "("
else:
layer_code = layer_code + "fluid.layers." + self.op + "("
......@@ -105,10 +108,12 @@ class FluidCode(object):
inputs,
output,
param_attr=None,
is_custom_layer=False):
is_custom_layer=False,
use_fluid=False):
layer = Layer()
layer.op = op
layer.is_custom_layer = is_custom_layer
layer.use_fluid = use_fluid
if inputs is not None:
layer.inputs = inputs
layer.output = output
......
......@@ -68,7 +68,8 @@ class TFGraphNode(GraphNode):
if dtype == 0:
dtype = self.layer.attr['output_types'].list.type[0]
if dtype not in self.dtype_map:
raise Exception("Dtype[{}] not in dtype_map".format(dtype))
raise Exception("Dtype[{}] of node({}) not in dtype_map".format(
dtype, self.layer.name))
return self.dtype_map[dtype]
@property
......@@ -119,10 +120,14 @@ class TFGraph(Graph):
def build(self):
for layer in self.model.node:
if layer.op == 'Assert':
continue
self.node_map[layer.name.replace('/', '_').replace(
'-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)
for layer_name, node in self.node_map.items():
if node.layer_type == 'Const':
continue
for in_node in node.layer.input:
in_node = in_node.replace('/',
'_').replace('-',
......@@ -139,6 +144,14 @@ class TFGraph(Graph):
super(TFGraph, self).build()
for layer in self.model.node:
if layer.op == 'Assert':
for ipt in layer.input:
ipt_name = ipt.replace('-', '_').replace('/', '_')
if ipt_name in self.output_nodes:
idx = self.output_nodes.index(ipt_name)
del self.output_nodes[idx]
# tensorflow graph optimize
self._remove_isolated_node()
self._optimize_dialiation_conv()
......
......@@ -15,6 +15,7 @@
from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import math
import inspect
import numpy
import sys
......@@ -53,6 +54,7 @@ class TFOpMapperNHWC(OpMapper):
'RealDiv': 'elementwise_div',
'Sub': 'elementwise_sub',
'Maximum': 'elementwise_max',
'LessEqual': 'less_equal',
'Mul': 'elementwise_mul',
'FloorDiv': 'elementwise_floordiv'
}
......@@ -94,7 +96,8 @@ class TFOpMapperNHWC(OpMapper):
func = getattr(self, op)
try:
func(node)
except:
except Exception as e:
print(str(e))
unsupported_ops.add(op)
else:
unsupported_ops.add(op)
......@@ -154,96 +157,9 @@ class TFOpMapperNHWC(OpMapper):
op_type = self.elementwise_ops[node.layer_type]
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
x_shape = x.out_shapes[0]
y_shape = y.out_shapes[0]
if len(x_shape) == 0:
x_shape = [1]
if len(y_shape) == 0:
y_shape = [1]
# incomplement broadcasting support for paddle
x_input = x
y_input = y
if len(x_shape) < len(y_shape):
unrevertable_ops = [
"elementwise_sub", "elementwise_div", "elementwise_floordiv",
"elementwise_mod", "elementwise_pow"
]
if op_type not in unrevertable_ops:
x_input = y
y_input = x
x_shape = y.out_shapes[0]
if len(x_shape) == 0:
x_shape = [1]
y_shape = x.out_shapes[0]
if len(y_shape) == 0:
y_shape = [1]
else:
raise Exception("Unexpected situation happend")
if len(x_shape) == 4 and len(y_shape) == 1:
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type, inputs=inputs, output=node)
return
is_sub_seq = True
for i in range(len(y_shape)):
index = -1 * i - 1
if y_shape[index] != x_shape[index]:
is_sub_seq = False
if not is_sub_seq:
x_expand_times = [1] * len(x_shape)
y_expand_times = [1] * len(y_shape)
x_need_expand = False
y_need_expand = False
for i in range(len(y_shape)):
index = -1 * i - 1
if y_shape[index] != x_shape[index]:
if y_shape[index] == 1:
y_expand_times[index] = x_shape[index]
y_need_expand = True
elif x_shape[index] == 1:
x_expand_times[index] = y_shape[index]
x_need_expand = True
else:
raise Exception("Unexpected situation happend")
if x_need_expand:
attr = {"expand_times": x_expand_times}
node.fluid_code.add_layer("expand",
inputs=x_input,
output="x_tmp",
param_attr=attr)
x_input = "x_tmp"
if y_need_expand:
attr = {"expand_times": y_expand_times}
node.fluid_code.add_layer("expand",
inputs=y_input,
output="y_tmp",
param_attr=attr)
y_input = "y_tmp"
if len(x_shape) == 4 and len(y_shape) == 4:
node.fluid_code.add_layer("transpose",
inputs=x_input,
output=x_input,
param_attr={'perm': [0, 3, 1, 2]})
node.fluid_code.add_layer("transpose",
inputs=y_input,
output=y_input,
param_attr={'perm': [0, 3, 1, 2]})
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
node.fluid_code.add_layer("transpose",
inputs=node,
output=node,
param_attr={'perm': [0, 2, 3, 1]})
else:
inputs = {"x": x_input, "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer(op_type, inputs=inputs, output=node)
def Placeholder(self, node):
shape = node.out_shapes[0]
......@@ -781,12 +697,19 @@ class TFOpMapperNHWC(OpMapper):
inputs=x,
output=x,
param_attr=attr)
if transpose_a is None:
transpose_a = node.get_attr('adj_x')
if transpose_b is None:
transpose_b = node.get_attr('adj_y')
attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
node.fluid_code.add_layer("matmul",
inputs=inputs,
output=node,
param_attr=attr)
def BatchMatMul(self, node):
return self.MatMul(node)
def ArgMax(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
axis = self.graph.get_node(node.layer.input[1], copy=True)
......@@ -1156,7 +1079,9 @@ class TFOpMapperNHWC(OpMapper):
else:
dim = self.decoder.infer_tensor(y)
self.add_omit_nodes(y.layer_name, node.layer_name)
attr = {'axes': [dim]}
if not isinstance(dim, list):
dim = [dim]
attr = {'axes': dim}
node.fluid_code.add_layer("unsqueeze",
inputs=x,
output=node,
......@@ -1183,3 +1108,68 @@ class TFOpMapperNHWC(OpMapper):
param_attr=None)
else:
raise Exception("SpaceToBatchND is not supported")
def OneHot(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
depth = self.graph.get_node(node.layer.input[1], copy=True)
on_value = self.graph.get_node(node.layer.input[2], copy=True)
off_value = self.graph.get_node(node.layer.input[3], copy=True)
assert depth.layer_type == 'Const', 'Parameter depth should be Const in OneHot'
assert on_value.layer_type == 'Const', 'Parameter on_value should be Const in OneHot'
assert off_value.layer_type == 'Const', 'Parameter off_value should be Const in OneHot'
self.add_omit_nodes(depth.layer_name, node.layer_name)
self.add_omit_nodes(on_value.layer_name, node.layer_name)
self.add_omit_nodes(off_value.layer_name, node.layer_name)
depth = depth.value
on_value = on_value.value
off_value = off_value.value
assert math.fabs(on_value -
1.0) < 1e-06, "on_value should be 1 in OneHot"
assert math.fabs(off_value -
0.0) < 1e-06, "off_value should be 0 in OneHot"
attr = {'depth': depth}
node.fluid_code.add_layer("one_hot",
inputs=input,
output=node,
param_attr=attr,
use_fluid=True)
def Pow(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
factor = self.graph.get_node(node.layer.input[1], copy=True)
self.add_omit_nodes(factor.layer_name, node.layer_name)
if factor.layer_type == 'Const':
factor = factor.value.tolist()
else:
factor = self.decoder.infer_tensor(factor)
attr = {'factor': factor}
node.fluid_code.add_layer("pow", inputs=x, output=node, param_attr=attr)
def All(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
self.add_omit_nodes(reduce_idx.layer_name, node.layer_name)
assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
dims = reduce_idx.value.tolist()
keep_dims = node.get_attr("keep_dims")
attr = {"dim": dims, "keep_dim": keep_dims}
node.fluid_code.add_layer("reduce_all",
inputs=input,
output=node,
param_attr=attr)
def GatherV2(self, node):
embeddings = self.graph.get_node(node.layer.input[0], copy=True)
index = self.graph.get_node(node.layer.input[1], copy=True)
axis = self.graph.get_node(node.layer.input[2], copy=True)
self.add_omit_nodes(axis.layer_name, node.layer_name)
assert axis.layer_type == 'Const', "Only support Const parameter[axis]"
axis = axis.value.tolist()
assert axis == 0, "Only support axis=0 in GatherV2 OP"
attr = {'overwrite': False}
inputs = {'input': embeddings, 'index': index}
node.fluid_code.add_layer("gather",
inputs=inputs,
output=node,
param_attr=attr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册