Unverified commit c9a9342a, authored by leaves-zwx, committed by GitHub

Merge pull request #198 from Oneflow-Inc/dev_convert_py_model_to_oneflow

Convert PyTorch model to OneFlow
# GPT model conversion

### Converting the PyTorch model to a OneFlow model

- `meta.proto` defines the `meta` file written into each variable directory of the converted model. Run `protoc --python_out=. meta.proto` to generate `meta_pb2.py`, which can then be imported with `import meta_pb2 as meta_pb` (a usage sketch follows the snippet below):
```
syntax = "proto2";
package gpt;
message Shape {
repeated int32 dim = 1;
}
enum DataType {
kInvalidDataType = 0;
kChar = 1;
kFloat = 2;
kDouble = 3;
kInt8 = 4;
kInt32 = 5;
kInt64 = 6;
kUInt8 = 7;
kOFRecord = 8;
kFloat16 = 9;
kTensorBuffer = 10;
}
message Meta {
required Shape shape = 1;
required DataType data_type = 2 [default = kFloat16];
}
```
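After generating `meta_pb2.py`, the module can be exercised on its own. The following is a minimal sketch (not part of the repository) that writes a `meta` file the same way the converter does and parses it back; the blob shape used here is only an illustrative value.

```
import numpy as np
from google.protobuf import text_format

import meta_pb2 as meta_pb

# Illustrative blob; the converter writes real model weights instead.
blob = np.zeros((1024, 768), dtype=np.float32)

# Build the Meta message the same way convert_pt_to_of_gpt.py does.
meta_info = meta_pb.Meta()
meta_info.shape.dim[:] = blob.shape
meta_info.data_type = meta_pb.kFloat

# The converter writes str(meta_info), i.e. the text-format protobuf.
with open("meta", "w") as f:
    f.write(str(meta_info))

# Read it back to confirm the round trip.
with open("meta", "r") as f:
    parsed = text_format.Parse(f.read(), meta_pb.Meta())
print(list(parsed.shape.dim), parsed.data_type)  # [1024, 768] 2
```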
- The conversion script is `convert_pt_to_of_gpt.py`. Running `python3 convert_pt_to_of_gpt.py --py_model_dir /path/to/iter_0500000/mp_rank_00/model_optim_rng.pt` writes the OneFlow model to `./convert_pt_to_of_gpt_release` in the current directory (the default `--of_dump_path`); a read-back sketch follows below.
  - `--py_model_dir`: path to the PyTorch checkpoint
  - `--of_dump_path`: directory where the converted OneFlow model is saved
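The following read-back sketch (not part of the repository) shows how one converted blob can be checked against its `meta` file. `model-wte` and the default dump path are illustrative values, and the blob is assumed to have been stored as float32, since the converter calls `.float()` before saving.

```
import os

import numpy as np
from google.protobuf import text_format

import meta_pb2 as meta_pb

dump_dir = "./convert_pt_to_of_gpt_release"  # default --of_dump_path
op_name = "model-wte"                        # word embedding written by the converter

folder = os.path.join(dump_dir, op_name)

# `meta` is a text-format Meta protobuf holding the shape and dtype.
with open(os.path.join(folder, "meta"), "r") as f:
    meta_info = text_format.Parse(f.read(), meta_pb.Meta())

# `out` holds the raw tensor bytes; the converter saves float32 data.
shape = tuple(meta_info.shape.dim)
blob = np.fromfile(os.path.join(folder, "out"), dtype=np.float32).reshape(shape)
print(op_name, blob.shape)
```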
# convert_pt_to_of_gpt.py
import argparse
import os

import numpy as np
import torch

import meta_pb2 as meta_pb
def get_args():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--py_model_dir",
        type=str,
        default="/path/to/iter_0500000/mp_rank_00/model_optim_rng.pt",
        help="Path to the PyTorch checkpoint file.",
    )
    parser.add_argument(
        "--of_dump_path",
        type=str,
        default="./convert_pt_to_of_gpt_release",
        help="Path to the output OneFlow model.",
    )
    return parser.parse_args()
def _SaveWeightBlob2File(blob, op_name, save_path, var="out", meta="meta"):
    folder = os.path.join(save_path, op_name)
    if not os.path.exists(folder):
        os.makedirs(folder)

    # Raw tensor bytes go into `<folder>/out`.
    filename = os.path.join(folder, var)
    with open(filename, "wb") as f:
        f.write(blob.tobytes())

    # A text-format Meta protobuf describing shape and dtype goes into `<folder>/meta`.
    meta_info = meta_pb.Meta()
    meta_info.shape.dim[:] = blob.shape
    meta_info.data_type = meta_pb.kFloat
    filename = os.path.join(folder, meta)
    with open(filename, "w") as f:
        f.write(str(meta_info))

    # Also keep a numpy copy alongside (saved as `<folder>/meta.npy`).
    np.save(filename, blob)
def _SaveWeightBlob2FileExtend(blob, op_name, save_path, var="out", meta="meta"):
    # Save the weight itself plus "-v" (all ones) and "-m" (all zeros) companion
    # blobs, which act as optimizer-state placeholders for the variable.
    _SaveWeightBlob2File(blob.numpy(), op_name, save_path, var=var, meta=meta)
    _SaveWeightBlob2File(
        np.ones_like(blob.numpy()), op_name + "-v", save_path, var=var, meta=meta
    )
    _SaveWeightBlob2File(
        np.zeros_like(blob.numpy()), op_name + "-m", save_path, var=var, meta=meta
    )
def convert(args):
    path = args.py_model_dir
    state_dict = torch.load(path, map_location="cpu")

    # Transformer layers: map the PyTorch (Megatron-style) parameter names to
    # OneFlow variable names.
    for model_key, model_value in state_dict["model"]["language_model"][
        "transformer"
    ].items():
        if len(model_value.shape) > 1:
            # Multi-dimensional weights are transposed for the OneFlow layout.
            model_value = torch.transpose(model_value, 0, 1)
        model_value = model_value.float()

        op_name_list = model_key.split(".")
        if "layers." in model_key:
            # e.g. "layers.0.attention.dense.weight" -> "model-h0-attn-c_proj-weight"
            op_name = model_key.replace("layers.", "model-")
            op_name = op_name.replace(
                "-%s." % (op_name_list[1]), "-h%s-" % (op_name_list[1])
            )
        else:
            op_name = model_key.replace("final_layernorm.", "model-layernorm_f-")
        op_name = op_name.replace("input_layernorm.", "layernorm_1-")
        op_name = op_name.replace("post_attention_layernorm.", "layernorm_2-")
        op_name = op_name.replace("attention.", "attn-")
        op_name = op_name.replace("query_key_value.", "c_attn-")
        op_name = op_name.replace("dense.", "c_proj-")
        op_name = op_name.replace("mlp.dense_h_to_4h.", "mlp-c_fc-")
        op_name = op_name.replace("mlp.dense_4h_to_h.", "mlp-c_proj-")
        if (
            "layernorm_1" in op_name
            or "layernorm_2" in op_name
            or "layernorm_f" in op_name
        ):
            # LayerNorm parameters are named gamma/beta instead of weight/bias.
            op_name = op_name.replace("-weight", "-gamma")
            op_name = op_name.replace("-bias", "-beta")
        print(model_key, "-" * 8, op_name)
        _SaveWeightBlob2FileExtend(model_value, op_name, args.of_dump_path)

    # Position and word embeddings.
    _SaveWeightBlob2FileExtend(
        state_dict["model"]["language_model"]["embedding"]["position_embeddings"][
            "weight"
        ].float(),
        "model-wpe",
        args.of_dump_path,
    )
    _SaveWeightBlob2FileExtend(
        state_dict["model"]["language_model"]["embedding"]["word_embeddings"][
            "weight"
        ].float(),
        "model-wte",
        args.of_dump_path,
    )


if __name__ == "__main__":
    args = get_args()
    convert(args)
syntax = "proto2";
message Shape {
repeated int32 dim = 1;
}
enum DataType {
kInvalidDataType = 0;
kChar = 1;
kFloat = 2;
kDouble = 3;
kInt8 = 4;
kInt32 = 5;
kInt64 = 6;
kUInt8 = 7;
kOFRecord = 8;
kFloat16 = 9;
kTensorBuffer = 10;
}
message Meta {
required Shape shape = 1;
required DataType data_type = 2 [default = kFloat16];
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: meta.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='meta.proto',
package='',
syntax='proto2',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\nmeta.proto\"\x14\n\x05Shape\x12\x0b\n\x03\x64im\x18\x01 \x03(\x05\"E\n\x04Meta\x12\x15\n\x05shape\x18\x01 \x02(\x0b\x32\x06.Shape\x12&\n\tdata_type\x18\x02 \x02(\x0e\x32\t.DataType:\x08kFloat16*\xa3\x01\n\x08\x44\x61taType\x12\x14\n\x10kInvalidDataType\x10\x00\x12\t\n\x05kChar\x10\x01\x12\n\n\x06kFloat\x10\x02\x12\x0b\n\x07kDouble\x10\x03\x12\t\n\x05kInt8\x10\x04\x12\n\n\x06kInt32\x10\x05\x12\n\n\x06kInt64\x10\x06\x12\n\n\x06kUInt8\x10\x07\x12\r\n\tkOFRecord\x10\x08\x12\x0c\n\x08kFloat16\x10\t\x12\x11\n\rkTensorBuffer\x10\n'
)
_DATATYPE = _descriptor.EnumDescriptor(
name='DataType',
full_name='DataType',
filename=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key,
values=[
_descriptor.EnumValueDescriptor(
name='kInvalidDataType', index=0, number=0,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kChar', index=1, number=1,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kFloat', index=2, number=2,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kDouble', index=3, number=3,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kInt8', index=4, number=4,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kInt32', index=5, number=5,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kInt64', index=6, number=6,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kUInt8', index=7, number=7,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kOFRecord', index=8, number=8,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kFloat16', index=9, number=9,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='kTensorBuffer', index=10, number=10,
serialized_options=None,
type=None,
create_key=_descriptor._internal_create_key),
],
containing_type=None,
serialized_options=None,
serialized_start=108,
serialized_end=271,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)
DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
kInvalidDataType = 0
kChar = 1
kFloat = 2
kDouble = 3
kInt8 = 4
kInt32 = 5
kInt64 = 6
kUInt8 = 7
kOFRecord = 8
kFloat16 = 9
kTensorBuffer = 10
_SHAPE = _descriptor.Descriptor(
name='Shape',
full_name='Shape',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='dim', full_name='Shape.dim', index=0,
number=1, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=14,
serialized_end=34,
)
_META = _descriptor.Descriptor(
name='Meta',
full_name='Meta',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='shape', full_name='Meta.shape', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='data_type', full_name='Meta.data_type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=True, default_value=9,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=36,
serialized_end=105,
)
_META.fields_by_name['shape'].message_type = _SHAPE
_META.fields_by_name['data_type'].enum_type = _DATATYPE
DESCRIPTOR.message_types_by_name['Shape'] = _SHAPE
DESCRIPTOR.message_types_by_name['Meta'] = _META
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Shape = _reflection.GeneratedProtocolMessageType('Shape', (_message.Message,), {
'DESCRIPTOR' : _SHAPE,
'__module__' : 'meta_pb2'
# @@protoc_insertion_point(class_scope:Shape)
})
_sym_db.RegisterMessage(Shape)
Meta = _reflection.GeneratedProtocolMessageType('Meta', (_message.Message,), {
'DESCRIPTOR' : _META,
'__module__' : 'meta_pb2'
# @@protoc_insertion_point(class_scope:Meta)
})
_sym_db.RegisterMessage(Meta)
# @@protoc_insertion_point(module_scope)