From 169b989e65820bbfcff66c8745f584d8d9f854b2 Mon Sep 17 00:00:00 2001 From: mamingjie-China Date: Tue, 4 Aug 2020 20:48:21 +0800 Subject: [PATCH] support for bert --- x2paddle/op_mapper/tf_op_mapper_nhwc.py | 96 +++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py index c747611..6b00c5c 100644 --- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -550,6 +550,12 @@ class TFOpMapperNHWC(OpMapper): transpose_x=transpose_a, transpose_y=transpose_b) + def BatchMatMul(self, node): + return self.MatMul(node) + + def BatchMatMulV2(self, node): + return self.MatMul(node) + def DepthwiseConv2dNative(self, node): input = self.graph.get_node(node.layer.input[0]) kernel = self.graph.get_node(node.layer.input[1]) @@ -1043,3 +1049,93 @@ class TFOpMapperNHWC(OpMapper): inputs = {"x": node.name, "y": node.name} program.add_layer( "fluid.layers.elementwise_mul", inputs=inputs, outputs=[node.name]) + + def OneHot(self, node): + input = self.graph.get_node(node.layer.input[0]) + depth = self.graph.get_node(node.layer.input[1]) + on_value = self.graph.get_node(node.layer.input[2]) + off_value = self.graph.get_node(node.layer.input[3]) + assert depth.layer_type == 'Const', 'Parameter depth should be Const in OneHot' + assert on_value.layer_type == 'Const', 'Parameter on_value should be Const in OneHot' + assert off_value.layer_type == 'Const', 'Parameter off_value should be Const in OneHot' + + attr = {'depth': depth.value} + on_value = on_value.value + off_value = off_value.value + assert math.fabs(on_value - + 1.0) < 1e-06, "on_value should be 1 in OneHot" + assert math.fabs(off_value - + 0.0) < 1e-06, "off_value should be 0 in OneHot" + + program.add_layer( + "fluid.one_hot", + inputs={"input": input.name}, + outputs=[node.name], + depth=depth.value) + + def Pow(self, node): + x = self.graph.get_node(node.layer.input[0]) + factor = self.graph.get_node(node.layer.input[1]) + inputs = {"x": x.name} + attr = dict() + if factor.layer_type == 'Const': + attr["factor"] = factor.value.tolist() + else: + inputs["factor"] = factor.name + program.add_layer( + "fluid.layers.pow", inputs=inputs, outputs=[node.name], **attr) + + def All(self, node): + input = self.graph.get_node(node.layer.input[0]) + reduce_idx = self.graph.get_node(node.layer.input[1]) + assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]" + attr = dict() + attr["dim"] = reduce_idx.value.tolist() + attr["keep_dim"] = node.get_attr("keep_dims") + + program.add_layer( + "fluid.layers.reduce_all", + inputs={"input": input.name}, + outputs=[node.name], + **attr) + + def GatherV2(self, node): + embeddings = self.graph.get_node(node.layer.input[0]) + index = self.graph.get_node(node.layer.input[1]) + axis = self.graph.get_node(node.layer.input[2]) + assert axis.layer_type == 'Const', "Only support Const parameter[axis]" + axis = axis.value.tolist() + assert axis == 0, "Only support axis=0 in GatherV2 OP" + index_name = index.name + if len(index.out_shapes[0]) != 1: + reshape_name = gen_name("gather", "reshape") + index_name = reshape_name + program.add_layer( + "fluid.layers.reshape", + inputs={"x": index.name}, + outputs=[reshape_name], + shape=[-1]) + inputs = {'input': embeddings.name, 'index': index_name} + program.add_layer( + "fluid.layers.gather", + inputs=inputs, + outputs=[node.name], + overwrite=False) + + def ExpandDims(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {"input": x.name} + attr = dict() + if y.layer_type == 'Const': + dim = y.value.tolist() + if not isinstance(dim, list): + dim = [dim] + attr['axes'] = dim + else: + inputs['axes'] = y.name + program.add_layer( + "fluid.layers.unsqueeze", + inputs=inputs, + outputs=[node.name], + **attr) -- GitLab