提交 169b989e 编写于 作者: M mamingjie-China

support for bert

上级 2c831a8d
......@@ -550,6 +550,12 @@ class TFOpMapperNHWC(OpMapper):
transpose_x=transpose_a,
transpose_y=transpose_b)
def BatchMatMul(self, node):
return self.MatMul(node)
def BatchMatMulV2(self, node):
return self.MatMul(node)
def DepthwiseConv2dNative(self, node):
input = self.graph.get_node(node.layer.input[0])
kernel = self.graph.get_node(node.layer.input[1])
......@@ -1043,3 +1049,93 @@ class TFOpMapperNHWC(OpMapper):
inputs = {"x": node.name, "y": node.name}
program.add_layer(
"fluid.layers.elementwise_mul", inputs=inputs, outputs=[node.name])
def OneHot(self, node):
input = self.graph.get_node(node.layer.input[0])
depth = self.graph.get_node(node.layer.input[1])
on_value = self.graph.get_node(node.layer.input[2])
off_value = self.graph.get_node(node.layer.input[3])
assert depth.layer_type == 'Const', 'Parameter depth should be Const in OneHot'
assert on_value.layer_type == 'Const', 'Parameter on_value should be Const in OneHot'
assert off_value.layer_type == 'Const', 'Parameter off_value should be Const in OneHot'
attr = {'depth': depth.value}
on_value = on_value.value
off_value = off_value.value
assert math.fabs(on_value -
1.0) < 1e-06, "on_value should be 1 in OneHot"
assert math.fabs(off_value -
0.0) < 1e-06, "off_value should be 0 in OneHot"
program.add_layer(
"fluid.one_hot",
inputs={"input": input.name},
outputs=[node.name],
depth=depth.value)
def Pow(self, node):
x = self.graph.get_node(node.layer.input[0])
factor = self.graph.get_node(node.layer.input[1])
inputs = {"x": x.name}
attr = dict()
if factor.layer_type == 'Const':
attr["factor"] = factor.value.tolist()
else:
inputs["factor"] = factor.name
program.add_layer(
"fluid.layers.pow", inputs=inputs, outputs=[node.name], **attr)
def All(self, node):
input = self.graph.get_node(node.layer.input[0])
reduce_idx = self.graph.get_node(node.layer.input[1])
assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
attr = dict()
attr["dim"] = reduce_idx.value.tolist()
attr["keep_dim"] = node.get_attr("keep_dims")
program.add_layer(
"fluid.layers.reduce_all",
inputs={"input": input.name},
outputs=[node.name],
**attr)
def GatherV2(self, node):
embeddings = self.graph.get_node(node.layer.input[0])
index = self.graph.get_node(node.layer.input[1])
axis = self.graph.get_node(node.layer.input[2])
assert axis.layer_type == 'Const', "Only support Const parameter[axis]"
axis = axis.value.tolist()
assert axis == 0, "Only support axis=0 in GatherV2 OP"
index_name = index.name
if len(index.out_shapes[0]) != 1:
reshape_name = gen_name("gather", "reshape")
index_name = reshape_name
program.add_layer(
"fluid.layers.reshape",
inputs={"x": index.name},
outputs=[reshape_name],
shape=[-1])
inputs = {'input': embeddings.name, 'index': index_name}
program.add_layer(
"fluid.layers.gather",
inputs=inputs,
outputs=[node.name],
overwrite=False)
def ExpandDims(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"input": x.name}
attr = dict()
if y.layer_type == 'Const':
dim = y.value.tolist()
if not isinstance(dim, list):
dim = [dim]
attr['axes'] = dim
else:
inputs['axes'] = y.name
program.add_layer(
"fluid.layers.unsqueeze",
inputs=inputs,
outputs=[node.name],
**attr)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册