Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
68104ce3
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
68104ce3
编写于
12月 10, 2019
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Internal change
PiperOrigin-RevId: 284805626
上级
f7fd59b8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
33 addition
and
18 deletion
+33
-18
official/transformer/v2/transformer_main.py
official/transformer/v2/transformer_main.py
+33
-18
未找到文件。
official/transformer/v2/transformer_main.py
浏览文件 @
68104ce3
...
...
@@ -44,7 +44,6 @@ from official.utils.logs import logger
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
distribution_utils
INF
=
int
(
1e9
)
BLEU_DIR
=
"bleu"
_SINGLE_SAMPLE
=
1
...
...
@@ -158,6 +157,7 @@ class TransformerTask(object):
params
[
"batch_size"
]
=
flags_obj
.
batch_size
or
params
[
"default_batch_size"
]
params
[
"repeat_dataset"
]
=
None
params
[
"dtype"
]
=
flags_core
.
get_tf_dtype
(
flags_obj
)
params
[
"enable_tensorboard"
]
=
flags_obj
.
enable_tensorboard
params
[
"enable_metrics_in_training"
]
=
flags_obj
.
enable_metrics_in_training
params
[
"steps_between_evals"
]
=
flags_obj
.
steps_between_evals
...
...
@@ -183,8 +183,8 @@ class TransformerTask(object):
# like this. What if multiple instances of TransformerTask are created?
# We should have a better way in the tf.keras.mixed_precision API of doing
# this.
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
"dynamic"
)
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
"dynamic"
)
policy
=
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"mixed_float16"
,
loss_scale
=
loss_scale
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
...
...
@@ -206,8 +206,7 @@ class TransformerTask(object):
params
=
self
.
params
flags_obj
=
self
.
flags_obj
# Sets config options.
keras_utils
.
set_session_config
(
enable_xla
=
flags_obj
.
enable_xla
)
keras_utils
.
set_session_config
(
enable_xla
=
flags_obj
.
enable_xla
)
_ensure_dir
(
flags_obj
.
model_dir
)
with
distribution_utils
.
get_strategy_scope
(
self
.
distribution_strategy
):
...
...
@@ -225,6 +224,14 @@ class TransformerTask(object):
if
params
[
"use_ctl"
]:
train_loss_metric
=
tf
.
keras
.
metrics
.
Mean
(
"training_loss"
,
dtype
=
tf
.
float32
)
if
params
[
"enable_tensorboard"
]:
summary_writer
=
tf
.
compat
.
v2
.
summary
.
create_file_writer
(
flags_obj
.
model_dir
)
else
:
summary_writer
=
tf
.
compat
.
v2
.
summary
.
create_noop_writer
()
train_metrics
=
[
train_loss_metric
]
if
params
[
"enable_metrics_in_training"
]:
train_metrics
=
train_metrics
+
model
.
metrics
else
:
model
.
compile
(
opt
)
...
...
@@ -303,17 +310,23 @@ class TransformerTask(object):
raise
NotImplementedError
(
"Custom training loop on GPUs is not implemented."
)
# Runs training steps.
train_steps
(
train_ds_iterator
,
tf
.
convert_to_tensor
(
train_steps_per_eval
,
dtype
=
tf
.
int32
))
current_step
+=
train_steps_per_eval
train_loss
=
train_loss_metric
.
result
().
numpy
().
astype
(
float
)
logging
.
info
(
"Train Step: %d/%d / loss = %s"
,
current_step
,
flags_obj
.
train_steps
,
train_loss
)
with
summary_writer
.
as_default
():
train_steps
(
train_ds_iterator
,
tf
.
convert_to_tensor
(
train_steps_per_eval
,
dtype
=
tf
.
int32
))
current_step
+=
train_steps_per_eval
train_loss
=
train_loss_metric
.
result
().
numpy
().
astype
(
float
)
logging
.
info
(
"Train Step: %d/%d / loss = %s"
,
current_step
,
flags_obj
.
train_steps
,
train_loss
)
if
params
[
"enable_tensorboard"
]:
for
metric_obj
in
train_metrics
:
tf
.
compat
.
v2
.
summary
.
scalar
(
metric_obj
.
name
,
metric_obj
.
result
(),
current_step
)
checkpoint_name
=
checkpoint
.
save
(
os
.
path
.
join
(
flags_obj
.
model_dir
,
"ctl_step_{}.ckpt"
.
format
(
current_step
)))
os
.
path
.
join
(
flags_obj
.
model_dir
,
"ctl_step_{}.ckpt"
.
format
(
current_step
)))
logging
.
info
(
"Saved checkpoint to %s"
,
checkpoint_name
)
else
:
if
self
.
use_tpu
:
...
...
@@ -391,8 +404,9 @@ class TransformerTask(object):
callbacks
=
misc
.
get_callbacks
(
params
[
"steps_between_evals"
])
callbacks
.
append
(
scheduler_callback
)
ckpt_full_path
=
os
.
path
.
join
(
cur_log_dir
,
"cp-{epoch:04d}.ckpt"
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
))
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
))
return
callbacks
def
_load_weights_if_possible
(
self
,
model
,
init_weight_path
=
None
):
...
...
@@ -426,8 +440,9 @@ class TransformerTask(object):
if
params
[
"dtype"
]
==
tf
.
float16
:
opt
=
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
(
opt
,
loss_scale
=
flags_core
.
get_loss_scale
(
self
.
flags_obj
,
default_for_fp16
=
"dynamic"
))
opt
,
loss_scale
=
flags_core
.
get_loss_scale
(
self
.
flags_obj
,
default_for_fp16
=
"dynamic"
))
if
self
.
flags_obj
.
fp16_implementation
==
"graph_rewrite"
:
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录