Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
c54887dd
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c54887dd
编写于
9月 21, 2020
作者:
K
kinghuin
提交者:
GitHub
9月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the windows bug in ernie_gen (#905)
上级
a6ceff1c
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
25 addition
and
21 deletion
+25
-21
hub_module/modules/text/text_generation/ernie_gen/README.md
hub_module/modules/text/text_generation/ernie_gen/README.md
+4
-0
hub_module/modules/text/text_generation/ernie_gen/module.py
hub_module/modules/text/text_generation/ernie_gen/module.py
+4
-3
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/data/example.proto
..._generation/ernie_gen/propeller/paddle/data/example.proto
+2
-2
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/data/feature.proto
..._generation/ernie_gen/propeller/paddle/data/feature.proto
+1
-1
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/data/feature_column.py
...eration/ernie_gen/propeller/paddle/data/feature_column.py
+1
-1
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/train/hooks.py
...text_generation/ernie_gen/propeller/paddle/train/hooks.py
+4
-4
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/train/trainer.py
...xt_generation/ernie_gen/propeller/paddle/train/trainer.py
+7
-7
hub_module/modules/text/text_generation/ernie_gen/propeller/service/server.py
...ext/text_generation/ernie_gen/propeller/service/server.py
+2
-2
hub_module/modules/text/text_generation/ernie_gen/propeller/tools/ckpt_inspector.py
...xt_generation/ernie_gen/propeller/tools/ckpt_inspector.py
+0
-1
hub_module/modules/text/text_generation/gpt2/__init__.py
hub_module/modules/text/text_generation/gpt2/__init__.py
+0
-0
未找到文件。
hub_module/modules/text/text_generation/ernie_gen/README.md
浏览文件 @
c54887dd
...
...
@@ -184,3 +184,7 @@ paddlehub >= 1.7.0
*
1.0.1
修复模型导出bug
*
1.0.2
修复windows运行中的bug
hub_module/modules/text/text_generation/ernie_gen/module.py
浏览文件 @
c54887dd
...
...
@@ -39,7 +39,7 @@ import ernie_gen.propeller.paddle as propeller
@
moduleinfo
(
name
=
"ernie_gen"
,
version
=
"1.0.
1
"
,
version
=
"1.0.
2
"
,
summary
=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning."
,
author
=
"baidu"
,
...
...
@@ -371,10 +371,11 @@ class ErnieGen(hub.Module):
src_ids
=
src_ids
[:
self
.
max_encode_len
]
tgt_ids
=
tgt_ids
[:
self
.
max_decode_len
]
src_ids
,
src_sids
=
self
.
tokenizer
.
build_for_ernie
(
src_ids
)
src_pids
=
np
.
arange
(
len
(
src_ids
))
src_pids
=
np
.
arange
(
len
(
src_ids
)
,
dtype
=
np
.
int64
)
tgt_ids
,
tgt_sids
=
self
.
tokenizer
.
build_for_ernie
(
tgt_ids
)
tgt_pids
=
np
.
arange
(
len
(
tgt_ids
))
+
len
(
src_ids
)
# continues position
tgt_pids
=
np
.
arange
(
len
(
tgt_ids
),
dtype
=
np
.
int64
)
+
len
(
src_ids
)
# continues position
tgt_sids
=
np
.
ones_like
(
tgt_sids
)
attn_ids
=
np
.
ones_like
(
tgt_ids
)
*
self
.
tokenizer
.
vocab
[
'[MASK]'
]
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/data/example.proto
浏览文件 @
c54887dd
...
...
@@ -16,8 +16,8 @@
// model training or inference.
syntax
=
"proto3"
;
import
"propeller/paddle/data/feature.proto"
;
package
propeller
;
import
"
ernie_gen.
propeller/paddle/data/feature.proto"
;
package
ernie_gen
.
propeller
;
message
Example
{
Features
features
=
1
;
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/data/feature.proto
浏览文件 @
c54887dd
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
syntax
=
"proto3"
;
package
propeller
;
package
ernie_gen
.
propeller
;
message
BytesList
{
repeated
bytes
value
=
1
;
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/data/feature_column.py
浏览文件 @
c54887dd
...
...
@@ -125,7 +125,7 @@ class LabelColumn(Column):
ids
=
int
(
raw
)
else
:
ids
=
self
.
vocab
[
raw
]
return
ids
return
np
.
array
(
ids
,
dtype
=
np
.
int64
)
class
TextColumn
(
Column
):
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/train/hooks.py
浏览文件 @
c54887dd
...
...
@@ -73,7 +73,7 @@ class TqdmProgressBarHook(RunHook):
"""doc"""
self
.
tqdm
=
None
import
tqdm
from
propeller
import
log
as
main_log
from
ernie_gen.
propeller
import
log
as
main_log
hdl
=
main_log
.
handlers
[
0
]
class
_TqdmLogginHandler
(
logging
.
Handler
):
...
...
@@ -110,7 +110,7 @@ class TqdmNotebookProgressBarHook(RunHook):
"""doc"""
self
.
tqdm
=
None
import
tqdm
from
propeller
import
log
as
main_log
from
ernie_gen.
propeller
import
log
as
main_log
hdl
=
main_log
.
handlers
[
0
]
class
_TqdmLogginHandler
(
logging
.
Handler
):
...
...
@@ -144,7 +144,7 @@ class TqdmNotebookProgressBarHook(RunHook):
class
LoggingHook
(
RunHook
):
"""log tensor in to screan and
tensorboard
"""
"""log tensor in to screan and
VisualDL
"""
def
__init__
(
self
,
loss
,
...
...
@@ -205,7 +205,7 @@ class LoggingHook(RunHook):
speed
=
-
1.
self
.
last_state
=
state
# log to
tensorboard
# log to
VisualDL
if
self
.
writer
is
not
None
:
self
.
writer
.
add_scalar
(
'loss'
,
loss
,
state
.
gstep
)
for
name
,
t
in
zip
(
self
.
s_name
,
s_np
):
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/paddle/train/trainer.py
浏览文件 @
c54887dd
...
...
@@ -48,11 +48,11 @@ __all__ = ['train_and_eval', 'Learner']
def
_get_summary_writer
(
path
):
summary_writer
=
None
try
:
from
tensorboardX
import
Summary
Writer
from
visualdl
import
Log
Writer
if
distribution
.
status
.
is_master
:
summary_writer
=
Summary
Writer
(
os
.
path
.
join
(
path
))
summary_writer
=
Log
Writer
(
os
.
path
.
join
(
path
))
except
ImportError
:
log
.
warning
(
'
tensorboardX not installed, will not log to tensorboard
'
)
log
.
warning
(
'
VisualDL not installed, will not log to VisualDL
'
)
return
summary_writer
...
...
@@ -69,7 +69,7 @@ def _log_eval_result(name, eval_result, swriter, state):
printable
.
append
(
'{}
\t
{}'
.
format
(
n
,
val
))
if
swriter
is
not
None
:
swriter
.
add_scalar
(
n
,
val
,
state
.
gstep
)
log
.
debug
(
'write to
tensorboard
%s'
%
swriter
.
logdir
)
log
.
debug
(
'write to
VisualDL
%s'
%
swriter
.
logdir
)
if
len
(
printable
):
log
.
info
(
'*** eval res: %10s ***'
%
name
)
...
...
@@ -134,10 +134,10 @@ class Learner(object):
if
run_config
.
model_dir
is
None
:
raise
ValueError
(
'model_dir should specified in run_config'
)
if
issubclass
(
model_class_or_model_fn
,
Model
):
_model_fn
=
_build_model_fn
(
model_class_or_model_fn
)
elif
inspect
.
isfunction
(
model_class_or_model_fn
):
if
inspect
.
isfunction
(
model_class_or_model_fn
):
_model_fn
=
model_class_or_model_fn
elif
issubclass
(
model_class_or_model_fn
,
Model
):
_model_fn
=
_build_model_fn
(
model_class_or_model_fn
)
else
:
raise
ValueError
(
'unknown model %s'
%
model_class_or_model_fn
)
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/service/server.py
浏览文件 @
c54887dd
...
...
@@ -71,8 +71,8 @@ def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
"CUDA_VISIBLE_DEVICES"
).
split
(
","
)[
device_idx
]
log
.
debug
(
'cuda_env %s'
%
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
])
import
paddle.fluid
as
F
from
propeller.service
import
interface_pb2
import
propeller.service.utils
as
serv_utils
from
ernie_gen.
propeller.service
import
interface_pb2
import
ernie_gen.
propeller.service.utils
as
serv_utils
context
=
zmq
.
Context
()
socket
=
context
.
socket
(
zmq
.
REP
)
socket
.
connect
(
endpoint
)
...
...
hub_module/modules/text/text_generation/ernie_gen/propeller/tools/ckpt_inspector.py
浏览文件 @
c54887dd
...
...
@@ -26,7 +26,6 @@ import collections
from
distutils
import
dir_util
import
pickle
#from utils import print_arguments
import
paddle.fluid
as
F
from
paddle.fluid.proto
import
framework_pb2
...
...
hub_module/modules/text/text_generation/gpt2/__init__.py
0 → 100644
浏览文件 @
c54887dd
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录