Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
44ba6942
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
44ba6942
编写于
9月 20, 2018
作者:
L
luotao1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
put clone(for_test=True) before optimization phase
上级
abf019f6
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
8 addition
and
9 deletion
+8
-9
python/paddle/fluid/tests/unittests/dist_transformer.py
python/paddle/fluid/tests/unittests/dist_transformer.py
+8
-9
未找到文件。
python/paddle/fluid/tests/unittests/dist_transformer.py
浏览文件 @
44ba6942
...
@@ -437,11 +437,8 @@ def split_data(data, num_part):
...
@@ -437,11 +437,8 @@ def split_data(data, num_part):
]
]
def
test_context
(
t
rain_prog
m
,
avg_cost
,
train_exe
,
dev_count
,
data_input_names
,
def
test_context
(
t
est_progra
m
,
avg_cost
,
train_exe
,
dev_count
,
data_input_names
,
sum_cost
,
token_num
):
sum_cost
,
token_num
):
# Context to do validation.
test_program
=
train_progm
.
clone
(
for_test
=
True
)
val_data
=
DataReader
(
val_data
=
DataReader
(
src_vocab_fpath
=
TrainTaskConfig
.
src_vocab_fpath
,
src_vocab_fpath
=
TrainTaskConfig
.
src_vocab_fpath
,
trg_vocab_fpath
=
TrainTaskConfig
.
trg_vocab_fpath
,
trg_vocab_fpath
=
TrainTaskConfig
.
trg_vocab_fpath
,
...
@@ -503,7 +500,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
...
@@ -503,7 +500,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
def
train_loop
(
exe
,
train_progm
,
dev_count
,
sum_cost
,
avg_cost
,
lr_scheduler
,
def
train_loop
(
exe
,
train_progm
,
dev_count
,
sum_cost
,
avg_cost
,
lr_scheduler
,
token_num
,
predict
):
token_num
,
predict
,
test_program
):
# Initialize the parameters.
# Initialize the parameters.
if
TrainTaskConfig
.
ckpt_path
:
if
TrainTaskConfig
.
ckpt_path
:
lr_scheduler
.
current_steps
=
TrainTaskConfig
.
start_step
lr_scheduler
.
current_steps
=
TrainTaskConfig
.
start_step
...
@@ -552,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
...
@@ -552,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
-
1
]
+
label_data_input_fields
-
1
]
+
label_data_input_fields
if
TrainTaskConfig
.
val_file_pattern
is
not
None
:
if
TrainTaskConfig
.
val_file_pattern
is
not
None
:
test
=
test_context
(
t
rain_prog
m
,
avg_cost
,
train_exe
,
dev_count
,
test
=
test_context
(
t
est_progra
m
,
avg_cost
,
train_exe
,
dev_count
,
data_input_names
,
sum_cost
,
token_num
)
data_input_names
,
sum_cost
,
token_num
)
# the best cross-entropy value with label smoothing
# the best cross-entropy value with label smoothing
...
@@ -1645,6 +1642,8 @@ def get_model(is_dist, is_async):
...
@@ -1645,6 +1642,8 @@ def get_model(is_dist, is_async):
local_lr_scheduler
=
LearningRateScheduler
(
ModelHyperParams
.
d_model
,
local_lr_scheduler
=
LearningRateScheduler
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
,
TrainTaskConfig
.
warmup_steps
,
TrainTaskConfig
.
learning_rate
)
TrainTaskConfig
.
learning_rate
)
# Context to do validation.
test_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
if
not
is_dist
:
if
not
is_dist
:
optimizer
=
fluid
.
optimizer
.
Adam
(
optimizer
=
fluid
.
optimizer
.
Adam
(
...
@@ -1669,7 +1668,7 @@ def get_model(is_dist, is_async):
...
@@ -1669,7 +1668,7 @@ def get_model(is_dist, is_async):
epsilon
=
TrainTaskConfig
.
eps
)
epsilon
=
TrainTaskConfig
.
eps
)
optimizer
.
minimize
(
sum_cost
)
optimizer
.
minimize
(
sum_cost
)
return
sum_cost
,
avg_cost
,
predict
,
token_num
,
local_lr_scheduler
return
sum_cost
,
avg_cost
,
predict
,
token_num
,
local_lr_scheduler
,
test_program
def
update_args
():
def
update_args
():
...
@@ -1703,7 +1702,7 @@ class DistTransformer2x2(TestDistRunnerBase):
...
@@ -1703,7 +1702,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def
run_trainer
(
self
,
use_cuda
,
args
):
def
run_trainer
(
self
,
use_cuda
,
args
):
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
TrainTaskConfig
.
use_gpu
=
use_cuda
TrainTaskConfig
.
use_gpu
=
use_cuda
sum_cost
,
avg_cost
,
predict
,
token_num
,
local_lr_scheduler
=
get_model
(
sum_cost
,
avg_cost
,
predict
,
token_num
,
local_lr_scheduler
,
test_program
=
get_model
(
args
.
is_dist
,
not
args
.
sync_mode
)
args
.
is_dist
,
not
args
.
sync_mode
)
if
args
.
is_dist
:
if
args
.
is_dist
:
...
@@ -1724,7 +1723,7 @@ class DistTransformer2x2(TestDistRunnerBase):
...
@@ -1724,7 +1723,7 @@ class DistTransformer2x2(TestDistRunnerBase):
TrainTaskConfig
.
local
=
not
args
.
is_dist
TrainTaskConfig
.
local
=
not
args
.
is_dist
train_loop
(
startup_exe
,
trainer_prog
,
1
,
sum_cost
,
avg_cost
,
train_loop
(
startup_exe
,
trainer_prog
,
1
,
sum_cost
,
avg_cost
,
local_lr_scheduler
,
token_num
,
predict
)
local_lr_scheduler
,
token_num
,
predict
,
test_program
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录