Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
OneFlow-Benchmark
提交
c9a9342a
O
OneFlow-Benchmark
项目概览
Oneflow-Inc
/
OneFlow-Benchmark
上一次同步 2 年多
通知
1
Star
92
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
OneFlow-Benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
c9a9342a
编写于
6月 04, 2021
作者:
L
leaves-zwx
提交者:
GitHub
6月 04, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #198 from Oneflow-Inc/dev_convert_py_model_to_oneflow
convert pytorch model to oneflow
上级
9d15f79d
0f973def
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
373 addition
and
0 deletion
+373
-0
LanguageModeling/GPT/tools/README.md
LanguageModeling/GPT/tools/README.md
+36
-0
LanguageModeling/GPT/tools/convert_py_model_to_of.py
LanguageModeling/GPT/tools/convert_py_model_to_of.py
+110
-0
LanguageModeling/GPT/tools/meta.proto
LanguageModeling/GPT/tools/meta.proto
+24
-0
LanguageModeling/GPT/tools/meta_pb2.py
LanguageModeling/GPT/tools/meta_pb2.py
+203
-0
未找到文件。
LanguageModeling/GPT/tools/README.md
0 → 100644
浏览文件 @
c9a9342a
# GPT模型转换
### PyTorch模型转OneFlow模型
-
`meta.proto`
,是为生成模型目录下的
`meta`
文件,需要执行
`protoc --python_out=. meta.proto`
后生成
`meta_pb2.py`
,即可
`import meta_pb2 as meta_pb`
```
syntax
=
"proto2"
;
package
gpt
;
message
Shape
{
repeated
int32
dim
=
1
;
}
enum
DataType
{
kInvalidDataType
=
0
;
kChar
=
1
;
kFloat
=
2
;
kDouble
=
3
;
kInt8
=
4
;
kInt32
=
5
;
kInt64
=
6
;
kUInt8
=
7
;
kOFRecord
=
8
;
kFloat16
=
9
;
kTensorBuffer
=
10
;
}
message
Meta
{
required
Shape
shape
=
1
;
required
DataType
data_type
=
2
[
default
=
kFloat16
];
}
```
-
转换脚本
`convert_pt_to_of_gpt.py`
,执行
`python3 convert_pt_to_of_gpt.py --py_model_dir /path/to/iter_0500000/mp_rank_00/model_optim_rng.pt`
即可在当前目录下的
`convert_pt_to_of_gpt`
生成OneFlow模型
-
`--py_model_dir`
,pytorch模型地址
-
`--of_dump_path`
,保存转换后的模型路径
\ No newline at end of file
LanguageModeling/GPT/tools/convert_py_model_to_of.py
0 → 100644
浏览文件 @
c9a9342a
import
argparse
import
os
import
numpy
as
np
import
torch
import
meta_pb2
as
meta_pb
def
get_args
():
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--py_model_dir"
,
type
=
str
,
default
=
"/path/to/iter_0500000/mp_rank_00/model_optim_rng.pt"
,
help
=
"Path the PyTorch checkpoint file path."
,
)
parser
.
add_argument
(
"--of_dump_path"
,
type
=
str
,
default
=
"./convert_pt_to_of_gpt_release"
,
help
=
"Path to the output OneFlow model."
,
)
return
parser
.
parse_args
()
def
_SaveWeightBlob2File
(
blob
,
op_name
,
save_path
,
var
=
"out"
,
meta
=
"meta"
):
folder
=
os
.
path
.
join
(
save_path
,
op_name
)
if
not
os
.
path
.
exists
(
folder
):
os
.
makedirs
(
folder
)
filename
=
os
.
path
.
join
(
folder
,
var
)
f
=
open
(
filename
,
"wb"
)
f
.
write
(
blob
.
tobytes
())
meta_info
=
meta_pb
.
Meta
()
meta_info
.
shape
.
dim
[:]
=
blob
.
shape
meta_info
.
data_type
=
meta_pb
.
kFloat
filename
=
os
.
path
.
join
(
folder
,
meta
)
f
=
open
(
filename
,
"w"
)
f
.
write
(
str
(
meta_info
))
f
.
close
()
np
.
save
(
filename
,
blob
)
def
_SaveWeightBlob2FileExtend
(
blob
,
op_name
,
save_path
,
var
=
"out"
,
meta
=
"meta"
):
_SaveWeightBlob2File
(
blob
.
numpy
(),
op_name
,
save_path
,
var
=
var
,
meta
=
meta
)
_SaveWeightBlob2File
(
np
.
ones_like
(
blob
),
op_name
+
"-v"
,
save_path
,
var
=
var
,
meta
=
meta
)
_SaveWeightBlob2File
(
np
.
zeros_like
(
blob
),
op_name
+
"-m"
,
save_path
,
var
=
var
,
meta
=
meta
)
def
convert
(
args
):
path
=
args
.
py_model_dir
state_dict
=
torch
.
load
(
path
,
map_location
=
"cpu"
)
for
model_key
,
model_value
in
state_dict
[
"model"
][
"language_model"
][
"transformer"
].
items
():
if
len
(
model_value
.
shape
)
>
1
:
model_value
=
torch
.
transpose
(
model_value
,
0
,
1
)
model_value
=
model_value
.
float
()
op_name_list
=
model_key
.
split
(
"."
)
if
"layers."
in
model_key
:
op_name
=
model_key
.
replace
(
"layers."
,
"model-"
)
op_name
=
op_name
.
replace
(
"-%s."
%
(
op_name_list
[
1
]),
"-h%s-"
%
(
op_name_list
[
1
])
)
else
:
op_name
=
model_key
.
replace
(
"final_layernorm."
,
"model-layernorm_f-"
)
op_name
=
op_name
.
replace
(
"input_layernorm."
,
"layernorm_1-"
)
op_name
=
op_name
.
replace
(
"post_attention_layernorm."
,
"layernorm_2-"
)
op_name
=
op_name
.
replace
(
"attention."
,
"attn-"
)
op_name
=
op_name
.
replace
(
"query_key_value."
,
"c_attn-"
)
op_name
=
op_name
.
replace
(
"dense."
,
"c_proj-"
)
op_name
=
op_name
.
replace
(
"mlp.dense_h_to_4h."
,
"mlp-c_fc-"
)
op_name
=
op_name
.
replace
(
"mlp.dense_4h_to_h."
,
"mlp-c_proj-"
)
if
(
"layernorm_1"
in
op_name
or
"layernorm_2"
in
op_name
or
"layernorm_f"
in
op_name
):
op_name
=
op_name
.
replace
(
"-weight"
,
"-gamma"
)
op_name
=
op_name
.
replace
(
"-bias"
,
"-beta"
)
print
(
model_key
,
"-"
*
8
,
op_name
)
_SaveWeightBlob2FileExtend
(
model_value
,
op_name
,
args
.
of_dump_path
)
_SaveWeightBlob2FileExtend
(
state_dict
[
"model"
][
"language_model"
][
"embedding"
][
"position_embeddings"
][
"weight"
].
float
(),
"model-wpe"
,
args
.
of_dump_path
,
)
_SaveWeightBlob2FileExtend
(
state_dict
[
"model"
][
"language_model"
][
"embedding"
][
"word_embeddings"
][
"weight"
].
float
(),
"model-wte"
,
args
.
of_dump_path
,
)
if
__name__
==
"__main__"
:
args
=
get_args
()
convert
(
args
)
LanguageModeling/GPT/tools/meta.proto
0 → 100644
浏览文件 @
c9a9342a
syntax
=
"proto2"
;
message
Shape
{
repeated
int32
dim
=
1
;
}
enum
DataType
{
kInvalidDataType
=
0
;
kChar
=
1
;
kFloat
=
2
;
kDouble
=
3
;
kInt8
=
4
;
kInt32
=
5
;
kInt64
=
6
;
kUInt8
=
7
;
kOFRecord
=
8
;
kFloat16
=
9
;
kTensorBuffer
=
10
;
}
message
Meta
{
required
Shape
shape
=
1
;
required
DataType
data_type
=
2
[
default
=
kFloat16
];
}
LanguageModeling/GPT/tools/meta_pb2.py
0 → 100644
浏览文件 @
c9a9342a
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: meta.proto
"""Generated protocol buffer code."""
from
google.protobuf.internal
import
enum_type_wrapper
from
google.protobuf
import
descriptor
as
_descriptor
from
google.protobuf
import
message
as
_message
from
google.protobuf
import
reflection
as
_reflection
from
google.protobuf
import
symbol_database
as
_symbol_database
# @@protoc_insertion_point(imports)
_sym_db
=
_symbol_database
.
Default
()
DESCRIPTOR
=
_descriptor
.
FileDescriptor
(
name
=
'meta.proto'
,
package
=
''
,
syntax
=
'proto2'
,
serialized_options
=
None
,
create_key
=
_descriptor
.
_internal_create_key
,
serialized_pb
=
b
'
\n\n
meta.proto
\"\x14\n\x05
Shape
\x12\x0b\n\x03\x64
im
\x18\x01
\x03
(
\x05\"
E
\n\x04
Meta
\x12\x15\n\x05
shape
\x18\x01
\x02
(
\x0b\x32\x06
.Shape
\x12
&
\n\t
data_type
\x18\x02
\x02
(
\x0e\x32\t
.DataType:
\x08
kFloat16*
\xa3\x01\n\x08\x44\x61
taType
\x12\x14\n\x10
kInvalidDataType
\x10\x00\x12\t\n\x05
kChar
\x10\x01\x12\n\n\x06
kFloat
\x10\x02\x12\x0b\n\x07
kDouble
\x10\x03\x12\t\n\x05
kInt8
\x10\x04\x12\n\n\x06
kInt32
\x10\x05\x12\n\n\x06
kInt64
\x10\x06\x12\n\n\x06
kUInt8
\x10\x07\x12\r\n\t
kOFRecord
\x10\x08\x12\x0c\n\x08
kFloat16
\x10\t\x12\x11\n\r
kTensorBuffer
\x10\n
'
)
_DATATYPE
=
_descriptor
.
EnumDescriptor
(
name
=
'DataType'
,
full_name
=
'DataType'
,
filename
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
,
values
=
[
_descriptor
.
EnumValueDescriptor
(
name
=
'kInvalidDataType'
,
index
=
0
,
number
=
0
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kChar'
,
index
=
1
,
number
=
1
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kFloat'
,
index
=
2
,
number
=
2
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kDouble'
,
index
=
3
,
number
=
3
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kInt8'
,
index
=
4
,
number
=
4
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kInt32'
,
index
=
5
,
number
=
5
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kInt64'
,
index
=
6
,
number
=
6
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kUInt8'
,
index
=
7
,
number
=
7
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kOFRecord'
,
index
=
8
,
number
=
8
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kFloat16'
,
index
=
9
,
number
=
9
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
EnumValueDescriptor
(
name
=
'kTensorBuffer'
,
index
=
10
,
number
=
10
,
serialized_options
=
None
,
type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
),
],
containing_type
=
None
,
serialized_options
=
None
,
serialized_start
=
108
,
serialized_end
=
271
,
)
_sym_db
.
RegisterEnumDescriptor
(
_DATATYPE
)
DataType
=
enum_type_wrapper
.
EnumTypeWrapper
(
_DATATYPE
)
kInvalidDataType
=
0
kChar
=
1
kFloat
=
2
kDouble
=
3
kInt8
=
4
kInt32
=
5
kInt64
=
6
kUInt8
=
7
kOFRecord
=
8
kFloat16
=
9
kTensorBuffer
=
10
_SHAPE
=
_descriptor
.
Descriptor
(
name
=
'Shape'
,
full_name
=
'Shape'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'dim'
,
full_name
=
'Shape.dim'
,
index
=
0
,
number
=
1
,
type
=
5
,
cpp_type
=
1
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
serialized_options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
14
,
serialized_end
=
34
,
)
_META
=
_descriptor
.
Descriptor
(
name
=
'Meta'
,
full_name
=
'Meta'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
create_key
=
_descriptor
.
_internal_create_key
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'shape'
,
full_name
=
'Meta.shape'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
_descriptor
.
FieldDescriptor
(
name
=
'data_type'
,
full_name
=
'Meta.data_type'
,
index
=
1
,
number
=
2
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
True
,
default_value
=
9
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
serialized_options
=
None
,
file
=
DESCRIPTOR
,
create_key
=
_descriptor
.
_internal_create_key
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
serialized_options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
36
,
serialized_end
=
105
,
)
_META
.
fields_by_name
[
'shape'
].
message_type
=
_SHAPE
_META
.
fields_by_name
[
'data_type'
].
enum_type
=
_DATATYPE
DESCRIPTOR
.
message_types_by_name
[
'Shape'
]
=
_SHAPE
DESCRIPTOR
.
message_types_by_name
[
'Meta'
]
=
_META
DESCRIPTOR
.
enum_types_by_name
[
'DataType'
]
=
_DATATYPE
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
Shape
=
_reflection
.
GeneratedProtocolMessageType
(
'Shape'
,
(
_message
.
Message
,),
{
'DESCRIPTOR'
:
_SHAPE
,
'__module__'
:
'meta_pb2'
# @@protoc_insertion_point(class_scope:Shape)
})
_sym_db
.
RegisterMessage
(
Shape
)
Meta
=
_reflection
.
GeneratedProtocolMessageType
(
'Meta'
,
(
_message
.
Message
,),
{
'DESCRIPTOR'
:
_META
,
'__module__'
:
'meta_pb2'
# @@protoc_insertion_point(class_scope:Meta)
})
_sym_db
.
RegisterMessage
(
Meta
)
# @@protoc_insertion_point(module_scope)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录