diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py index fd19cd6f06b43da6ab14c82d2428c62da8d9e369..cc9c61efb110c6ba68363180680fda1947c3998b 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim.py @@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node): tensor_value = value value = "{}".format(value) if "tensor" in value: - if isinstance(tensor_value, list) or isinstance(tensor_value, tuple): + if isinstance(tensor_value, list) or isinstance(tensor_value, + tuple): name_dict = dict() for i, tv in enumerate(tensor_value): - output_name_i = "{}_p{}".format(output_name,i) + output_name_i = "{}_p{}".format(output_name, i) key_i = "input{}".format(i) - mapper.paddle_params[output_name_i] = tv.cpu().detach().numpy() + mapper.paddle_params[output_name_i] = tv.cpu().detach( + ).numpy() graph.add_layer( "self.create_parameter", inputs={}, outputs=[output_name_i], scope_name=scope_name, - dtype=string(str(mapper.paddle_params[output_name_i].dtype)), - shape = mapper.paddle_params[output_name_i].shape, - default_initializer="paddle.nn.initializer.Constant(value=0.0)") + dtype=string( + str(mapper.paddle_params[output_name_i].dtype)), + shape=mapper.paddle_params[output_name_i].shape, + default_initializer="paddle.nn.initializer.Constant(value=0.0)" + ) name_dict[key_i] = output_name_i graph.add_layer( "prim.list", @@ -59,16 +63,18 @@ def prim_Constant(mapper, graph, node): scope_name=scope_name) return [], [output_name] else: -# mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy() - mapper.paddle_params[output_name] = tensor_value.cpu().detach().numpy() + # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy() + mapper.paddle_params[output_name] = tensor_value.cpu().detach( + ).numpy() graph.add_layer( - "self.create_parameter", - inputs={}, - outputs=[output_name], - scope_name=scope_name, - dtype=string(str(mapper.paddle_params[output_name].dtype)), - shape = mapper.paddle_params[output_name].shape, - default_initializer="paddle.nn.initializer.Constant(value=0.0)") + "self.create_parameter", + inputs={}, + outputs=[output_name], + scope_name=scope_name, + dtype=string(str(mapper.paddle_params[output_name].dtype)), + shape=mapper.paddle_params[output_name].shape, + default_initializer="paddle.nn.initializer.Constant(value=0.0)" + ) return [], [output_name] if "inf" in str(value): t = str(type(value)).split("'")[1] @@ -81,7 +87,11 @@ def prim_Constant(mapper, graph, node): value = int(math.pow(2, 31) - 1) mapper.attrs[output_name] = value graph.add_layer( - "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value) + "prim.constant", + inputs={}, + outputs=[output_name], + scope_name=scope_name, + value=value) return [], [output_name] @@ -105,18 +115,23 @@ def prim_data(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%4336 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + graph.add_layer( + "prim.equal", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs def prim_DictConstruct(mapper, graph, node): """ 构建dict。 - + TorchScript示例: %32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29) 参数含义: @@ -136,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node): current_outputs = [output_name] # 处理每个输入 for i, input_name in enumerate(inputs_name): - if i%2 == 0: - layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name] + if i % 2 == 0: + layer_attrs["key{}".format(int(i / 2))] = mapper.attrs[input_name] else: - layer_inputs["value{}".format(int(i/2))] = input_name + layer_inputs["value{}".format(int(i / 2))] = input_name # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - graph.add_layer("prim.dict_construct", - inputs=layer_inputs, - outputs=layer_outputs, - scope_name=scope_name, - **layer_attrs) + graph.add_layer( + "prim.dict_construct", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name, + **layer_attrs) return current_inputs, current_outputs - def prim_GetAttr(mapper, graph, node): """ 获取attribute信息。 @@ -212,8 +227,13 @@ def prim_If(mapper, graph, node): input_node = list(node.inputs())[0].node() script_input_unique_id = list(node.inputs())[0].unique() input_node_name = mapper.outputs_info[script_input_unique_id] - mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name) - graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name) + mapper._check_input(graph, input_node, input_node_name, current_outputs, + scope_name) + graph.add_layer( + "prim.if", + inputs={'input': input_node_name}, + outputs=node_outputs, + scope_name=scope_name) current_layer = list(graph.layers.values())[-1] block0 = list(node.blocks())[0] block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer) @@ -249,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node): current_outputs = [output_name] # 处理每个输入 for i, input_name in enumerate(inputs_name): - mapper._check_input(graph, inputs_node[i], input_name, current_outputs, scope_name) + mapper._check_input(graph, inputs_node[i], input_name, current_outputs, + scope_name) layer_inputs["input{}".format(i)] = input_name # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - layer_id = graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + layer_id = graph.add_layer( + "prim.list", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) mapper.output2id[output_name] = layer_id return current_inputs, current_outputs @@ -277,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node): # 获取当前节点输出的list current_outputs = layer_outputs.copy() # 处理输入0,即%4354 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) graph.add_layer( - "prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + "prim.list_unpack", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs) return current_inputs, current_outputs @@ -342,7 +371,11 @@ def prim_Loop(mapper, graph, node): scope_name=scope_name) node_outputs.append(block_input_node_name) - graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name) + graph.add_layer( + "prim.loop", + inputs=loop_inputs, + outputs=loop_outputs, + scope_name=scope_name) current_layer = list(graph.layers.values())[-1] block_graph, graph_inputs = mapper.traverse(block, current_layer) for i, input_name in enumerate(graph_inputs): @@ -370,12 +403,17 @@ def prim_min(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%86 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + graph.add_layer( + "prim.min", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs @@ -397,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%86 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) - inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0]) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) + inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name( + inputs_node[0]) if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1: layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) graph.add_layer( - "prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + "prim_equal", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) else: layer_inputs["fill_value"] = inputs_name[0] # 获取当前节点输入的list @@ -437,13 +480,17 @@ def prim_RaiseException(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%76 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) graph.add_layer( - "prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + "prim.exception", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs @@ -464,13 +511,17 @@ def prim_requires_grad(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%86 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) graph.add_layer( - "prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + "prim.requires_grad", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs @@ -527,13 +578,17 @@ def prim_shape(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%input.8 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) graph.add_layer( - "paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + "paddle.shape", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs @@ -560,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node): # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + graph.add_layer( + "prim.tuple", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs @@ -590,7 +649,11 @@ def prim_TupleUnpack(mapper, graph, node): current_inputs = list(layer_inputs.values()) graph.add_layer( - "prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) + "prim.tuple_unpack", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name, + **layer_attrs) return current_inputs, current_outputs @@ -614,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node): # 获取当前节点输出的list current_outputs = [output_name] # 处理输入0,即%size.63 - mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) layer_inputs["input"] = inputs_name[0] # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + graph.add_layer( + "prim.equal", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name) return current_inputs, current_outputs @@ -636,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node): output = list(node.outputs())[0] mapper.attrs[output_name] = None graph.add_layer( - "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None) + "prim.constant", + inputs={}, + outputs=[output_name], + scope_name=scope_name, + value=None) return [], [output_name] diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py index 0ca02fc87e156e2bb55ab10516fea27d1164abe0..621ad26c72e5ecbe39905e249d632b1095b53b48 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py @@ -14,7 +14,8 @@ # limitations under the License. NO_OUTPUT_COUNT = 0 - + + def gen_codes(code_list, indent=0): indent_blank = " " * indent codes = [] @@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None): return str(layer.attrs[key]) -def prim_add(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_add(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {} + {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + get_value(layer, "x", different_attrs), + get_value(layer, "y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_add_(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {} + {} * {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), layer.attrs["alpha"], @@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif forward_func.extend(gen_codes([line], indent=indent)) -def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_and(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} and {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + get_value(layer, "x", different_attrs), + get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_append(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_append(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{}.append({})".format( - get_value(layer, "list", layer_id, different_attrs), + get_value(layer, "list", layer_id, different_attrs), get_value(layer, "element", layer_id, different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_assert(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): if layer.attrs["type"] == "eq": values = get_value(layer, "key") if "value" in layer.attrs: @@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d s += "{} == {} or ".format(get_value(layer, "key"), v) if len(s) > 0: s = s[:-4] - lc=locals() + lc = locals() exec("assert_result = {}".format(s)) assert_result = lc['assert_result'] line = "assert {}, \'The {} must be {}!\'".format( s, get_value(layer, "key"), get_value(layer, "value")) else: - s = "{} == {}".format(get_value(layer, "key"), - get_value(layer, "value")) - lc=locals() + s = "{} == {}".format( + get_value(layer, "key"), get_value(layer, "value")) + lc = locals() exec("assert_result = {}".format(s)) assert_result = lc['assert_result'] line = "assert {}, \'The {} must be {}!\'".format( @@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d forward_func.extend(gen_codes([line], indent=indent)) -def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_check_dim(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): lines = [] dim = get_value(layer, "dim", different_attrs) lines.append("if {} < 0:".format(dim)) @@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None forward_func.extend(gen_codes(lines, indent=indent)) -def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_constant(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {}".format(layer.outputs[0], layer.attrs["value"]) forward_func.extend(gen_codes([line], indent=indent)) -def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_contain(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} in {}".format(layer.outputs[0], get_value(layer, "element", different_attrs), get_value(layer, "input", different_attrs)) @@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, forward_func.extend(gen_codes([line], indent=indent)) -def prim_dict(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_dict(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = dict()".format(layer.outputs[0]) forward_func.extend(gen_codes([line], indent=indent)) - - -def prim_dict_construct(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): + + +def prim_dict_construct(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): lines = list() line = "{} = dict()".format(layer.outputs[0]) lines.append(line) for i in range(len(layer.inputs)): - line = "{}[{}] = {}".format(layer.outputs[0], - get_value(layer, "key{}".format(i), different_attrs), - get_value(layer, "value{}".format(i), different_attrs)) + line = "{}[{}] = {}".format( + layer.outputs[0], + get_value(layer, "key{}".format(i), different_attrs), + get_value(layer, "value{}".format(i), different_attrs)) lines.append(line) forward_func.extend(gen_codes(lines, indent=indent)) - - -def prim_dict2values(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = list({}.values())".format(layer.outputs[0], - get_value(layer, "x", different_attrs)) + + +def prim_dict2values(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = list({}.values())".format( + layer.outputs[0], get_value(layer, "x", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_div(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {} / {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None,is_return_line=False): +def prim_eq(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} == {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_equal(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "raise Exception({})".format(get_value(layer, "input", different_attrs)) +def prim_exception(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "raise Exception({})".format( + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_float(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = float({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_float(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = float({})".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_floor(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_floor(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = math.floor({})".format(layer.outputs[0], get_value(layer, "x", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_floordiv(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_floordiv(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {} // {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_getitem(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {}[{}]".format(layer.outputs[0], get_value(layer, "list", different_attrs), get_value(layer, "index", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_gt(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} > {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_if(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): try: exec_s = None for line in forward_func: s = line.replace(" ", "") - if s.startswith("{} = ".format(get_value(layer, "input", different_attrs))): + if s.startswith("{} = ".format( + get_value(layer, "input", different_attrs))): exec_s = s.split(" = ")[1] - lc=locals() + lc = locals() if exec_s is not None: exec("if_result = {}".format(exec_s)) else: - exec("if_result = {}".format(get_value(layer, "input", different_attrs))) + exec("if_result = {}".format( + get_value(layer, "input", different_attrs))) if_result = lc['if_result'] if if_result: block = layer.blocks[0] else: block = layer.blocks[1] if len(block.layers) > 0: - b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent) + b_init_lines, b_forward_lines = block.gen_dygraph_code( + indent=indent) init_func.extend(b_init_lines) forward_func.extend(b_forward_lines) except: @@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe line = "pass" forward_func.extend(gen_codes([line], indent=indent + 1)) else: - b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1) + b_init_lines, b_forward_lines = block.gen_dygraph_code( + indent=indent + 1) init_func.extend(b_init_lines) forward_func.extend(b_forward_lines) block = layer.blocks[1] @@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe forward_func.extend(b_forward_lines) -def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = int({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_int(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = int({})".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_is(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} is {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): - line = "{} = isinstance({}, {})".format(layer.outputs[0], - get_value(layer, "input", different_attrs), - layer.attrs["cls"]) +def prim_isinstance(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): + line = "{} = isinstance({}, {})".format( + layer.outputs[0], + get_value(layer, "input", different_attrs), layer.attrs["cls"]) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_isnot(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} is not {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) @@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di forward_func.extend(gen_codes([line], indent=indent)) -def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_le(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} <= {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_len(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_len(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = len({})".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_len2list(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): lines = [] lines.append("{} = []".format(layer.outputs[0])) - lines.append("for i in range({}):".format(get_value(layer, "len", different_attrs))) + lines.append("for i in range({}):".format( + get_value(layer, "len", different_attrs))) lines.append(" {}.append(i)".format(layer.outputs[0])) forward_func.extend(gen_codes(lines, indent=indent)) -def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_lt(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} < {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_list(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): input_len = len(layer.inputs) + len(layer.attrs) inputs_list = list() for i in range(input_len): - inputs_list.append(get_value(layer, "input{}".format(i), different_attrs)) + inputs_list.append( + get_value(layer, "input{}".format(i), different_attrs)) inputs_str = ', '.join(inputs_list) line = "{} = [{}]".format(layer.outputs[0], inputs_str) forward_func.extend(gen_codes([line], indent=indent)) -def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs)) +def prim_list_unpack(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}".format(", ".join(layer.outputs), + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_loop(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): loop_range = get_value(layer, "input", different_attrs) line = "for {} in range({}):".format(layer.outputs[1], loop_range) forward_func.extend(gen_codes([line], indent=indent)) @@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif forward_func.extend(b_forward_lines) -def prim_min(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_min(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = min({})".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_mul(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{} = {} * {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_ne(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} != {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_neg(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = -{}".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): - line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_not(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): + line = "{} = not {}".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): +def prim_or(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): line = "{} = {} or {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_replaceitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_replaceitem(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{}[{}] = {}".format( get_value(layer, "list", layer_id, different_attrs), - get_value(layer, "index", layer_id, different_attrs), + get_value(layer, "index", layer_id, different_attrs), get_value(layer, "item", layer_id, different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = not {}.stop_gradient".format(layer.outputs[0], - get_value(layer, "input", different_attrs)) +def prim_requires_grad(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = not {}.stop_gradient".format( + layer.outputs[0], get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_rsub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {} - {} * {}".format(layer.outputs[0], - get_value(layer, "y", different_attrs), - get_value(layer, "x", different_attrs), - get_value(layer, "alpha", different_attrs)) +def prim_rsub(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {} - {} * {}".format( + layer.outputs[0], + get_value(layer, "y", different_attrs), + get_value(layer, "x", different_attrs), + get_value(layer, "alpha", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_select(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_select(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}[".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) for dim in range(layer.attrs["dim"]): line += ":, " line += (get_value(layer, "index", different_attrs) + "]") forward_func.extend(gen_codes([line], indent=indent)) -def prim_set_attr(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_set_attr(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_set_item(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_set_item(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): line = "{}[{}] = {}".format( get_value(layer, "dict", different_attrs), - get_value(layer, "key", different_attrs), get_value(layer, "value", different_attrs)) + get_value(layer, "key", different_attrs), + get_value(layer, "value", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) - - -def prim_shape(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}.shape".format(layer.outputs[0], - get_value(layer, "input", different_attrs)) - forward_func.extend(gen_codes([line], indent=indent)) - -def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}.shape[{}]".format(layer.outputs[0], - get_value(layer, "input", different_attrs), - get_value(layer, "dim", different_attrs)) +def prim_shape(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}.shape".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}[{}: {}: {}]".format(layer.outputs[0], - get_value(layer, "input", different_attrs), - get_value(layer, "start", different_attrs), - get_value(layer, "end", different_attrs), - get_value(layer, "step", different_attrs)) - forward_func.extend(gen_codes([line], indent=indent)) - - -def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): - line = "{} = {}.startswith({})".format(layer.outputs[0], - get_value(layer, "input", different_attrs), - get_value(layer, "start_str", different_attrs)) +def prim_shape_dim(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}.shape[{}]".format( + layer.outputs[0], + get_value(layer, "input", different_attrs), + get_value(layer, "dim", different_attrs)) + forward_func.extend(gen_codes([line], indent=indent)) + + +def prim_slice(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}[{}: {}: {}]".format( + layer.outputs[0], + get_value(layer, "input", different_attrs), + get_value(layer, "start", different_attrs), + get_value(layer, "end", different_attrs), + get_value(layer, "step", different_attrs)) + forward_func.extend(gen_codes([line], indent=indent)) + + +def prim_startswith(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None, + is_return_line=False): + line = "{} = {}.startswith({})".format( + layer.outputs[0], + get_value(layer, "input", different_attrs), + get_value(layer, "start_str", different_attrs)) if is_return_line: return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = str({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_str(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = str({})".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_sub(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): if int(float(get_value(layer, "alpha", different_attrs))) == 1: line = "{} = {} - {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), + get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) else: - line = "{} = {} - {} * {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), - get_value(layer, "alpha", different_attrs), - get_value(layer, "y", different_attrs)) + line = "{} = {} - {} * {}".format( + layer.outputs[0], + get_value(layer, "x", different_attrs), + get_value(layer, "alpha", different_attrs), + get_value(layer, "y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_tuple(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_tuple(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): input_len = len(layer.inputs) + len(layer.attrs) inputs_list = list() for i in range(input_len): - inputs_list.append(get_value(layer, "input{}".format(i), different_attrs)) + inputs_list.append( + get_value(layer, "input{}".format(i), different_attrs)) inputs_str = ', '.join(inputs_list) line = "{} = ({})".format(layer.outputs[0], inputs_str) forward_func.extend(gen_codes([line], indent=indent)) -def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_tuple_unpack(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): outputs_str = ', '.join(layer.outputs) - line = "{} = {}".format(outputs_str, get_value(layer, "input", different_attrs)) + line = "{} = {}".format(outputs_str, + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_type(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}.dtype".format(layer.outputs[0], get_value(layer, "input", different_attrs)) +def prim_type(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}.dtype".format(layer.outputs[0], + get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_var2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {}.numpy().tolist()".format(layer.outputs[0], - get_value(layer, "input", different_attrs)) +def prim_var2list(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): + line = "{} = {}.numpy().tolist()".format( + layer.outputs[0], get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) -def prim_warnings(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_warnings(layer, + indent=1, + init_func=[], + forward_func=[], + layer_id=None, + different_attrs=None): lines = ["import warnings"] line = "warnings.warn({}, stacklevel={})".format( get_value(layer, "input", different_attrs), layer.attrs["stacklevel"]) lines.append(line) forward_func.extend(gen_codes(lines, indent=indent)) - diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index 04b9557059ddd75faa89d6f9b013d2f4efb59ef4..d33e84a8294f99866a0cfac2d2fe863718d50a6c 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper): self.scope_name_list = list() self.scope_name2id = dict() self.inputs_info = dict() - self.output2id = dict() # output名字和layer_id的关系,用于lstm去除前面的node + self.output2id = dict() # output名字和layer_id的关系,用于lstm去除前面的node # 转换 if not self.op_checker(decoder.graph): raise Exception("Model is not supported yet.") @@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper): op_list.append(node.kind()) for block in node.blocks(): _update_op_list(block) + op_list = list() _update_op_list(script_graph) op_list = list(set(op_list)) @@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper): return True else: if len(unsupported_ops) > 0: - print("\n========= {} OPs are not supported yet ===========".format( - len(unsupported_ops))) + print("\n========= {} OPs are not supported yet ===========". + format(len(unsupported_ops))) for op in unsupported_ops: print("========== {} ============".format(op)) - return False + return False def traverse(self, script_graph, parent_layer=None): # 用于获取graph的输入 @@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper): current_node_outputs.extend(outputs) # 初始化 - graph = PaddleGraph(source_type="pytorch", parent_layer=parent_layer, graph_type="dygraph") + graph = PaddleGraph( + source_type="pytorch", + parent_layer=parent_layer, + graph_type="dygraph") if "TopLevelTracedModule" in str(type(self.script)): graph.set_script(self.script) current_node_outputs = [] graph_inputs = [] # 转换输入节点 if isinstance(script_graph, torch._C.Graph): - input_ct = 0 + input_ct = 0 for i, ivalue in enumerate(script_graph.inputs()): node = ivalue.node() if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]: graph.set_name(str(ivalue.type()).split(".")[-1]) continue - inputs, outputs = self.data(graph, node, ivalue.unique(), input_ct) + inputs, outputs = self.data(graph, node, + ivalue.unique(), input_ct) input_ct += 1 # 转换中间节点 for node in script_graph.nodes(): @@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper): outputs=[output_name], scope_name=scope_name, dtype=string(str(param.dtype)), - shape = param.shape, - default_initializer="paddle.nn.initializer.Constant(value=0.0)") + shape=param.shape, + default_initializer="paddle.nn.initializer.Constant(value=0.0)" + ) self.output2id[output_name] = layer_id else: if isinstance(param, dict) and "Tensor" in param and \ @@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper): outputs=[output_name], scope_name=scope_name, dtype=string(str(param.dtype)), - shape = param.shape, - default_initializer="paddle.nn.initializer.Constant(value=0.0)") + shape=param.shape, + default_initializer="paddle.nn.initializer.Constant(value=0.0)" + ) node_outputs.append(output_name) self.output2id[output_name] = layer_id return @@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper): value=string(param) if isinstance(param, str) else param) node_outputs.append(output_name) - elif node.kind() == "prim::Constant" and output_name in self.pytorch_params: + elif node.kind( + ) == "prim::Constant" and output_name in self.pytorch_params: param = self.pytorch_params[output_name] self.paddle_params[output_name] = param layer_id = graph.add_layer( @@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper): outputs=[output_name], scope_name=scope_name, dtype=string(str(param.dtype)), - shape = param.shape, - default_initializer="paddle.nn.initializer.Constant(value=0.0)") + shape=param.shape, + default_initializer="paddle.nn.initializer.Constant(value=0.0)") self.output2id[output_name] = layer_id - def _get_inputs_name(self, node): inputs_name = [] inputs_node = [] @@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper): inputs_node.append(script_input_node) inputs_name.append(input_name) return inputs_name, inputs_node - def data(self, graph, node, uid, input_ct): scope_name = self.normalize_scope_name(node) @@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper): data=output_name) if self.input_examples is not None: input_np = self.input_examples[input_ct].detach().numpy() - self.inputs_info[output_name] = [list(input_np.shape), str(input_np.dtype)] + self.inputs_info[ + output_name] = [list(input_np.shape), str(input_np.dtype)] return [], [output_name] def equal(self, graph, node, uid=None, parent_layer=None, index=None): @@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper): control_output_id = index - 1 output_node_name = parent_layer.outputs[control_output_id] current_outputs = [output_node_name] - self._check_input(graph, node, input_node_name, current_outputs, scope_name) + self._check_input(graph, node, input_node_name, current_outputs, + scope_name) graph.add_layer( "prim.equal", inputs={'input': input_node_name}, @@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper): self.scope_name2id[i][ns] = 0 real_scope_name = "/".join(name_segments[1:]) real_father_scope_name = "/".join(name_segments[1:-1]) - + for i, ns in enumerate(name_segments): if i == 0: continue if self.scope_name2id[i][ns] != 0: name_segments[i] = name_segments[i] + \ "__{}".format(self.scope_name2id[i][ns]) - prefix_scope_name = "/".join(name_segments[1 :i + 1]) + prefix_scope_name = "/".join(name_segments[1:i + 1]) is_found = False for j in range(len(self.scope_name_list)): - last_scope_name = self.scope_name_list[-1-j] + last_scope_name = self.scope_name_list[-1 - j] if last_scope_name.startswith(prefix_scope_name + "/") \ or last_scope_name == prefix_scope_name: - if j != 0: # and i != len(name_segments) - 1: + if j != 0: # and i != len(name_segments) - 1: is_found = True origin_name_segment_i = name_segments[i].split("__")[0] self.scope_name2id[i][origin_name_segment_i] += 1 @@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper): real_scope_name = "/".join(name_segments[1:]) self.scope_name_list.append(real_scope_name) return real_scope_name - \ No newline at end of file