Repository: 曾经的那一瞬间 / Models (GitCode mirror)
Commit b29fe6b7

[Clean up]: remove is_v2() check inside transformer.

Authored on May 24, 2020 by Hongkun Yu; committed by A. Unique TensorFlower on May 24, 2020.
PiperOrigin-RevId: 312988874
Parent: 09d3c74a

3 changed files with 7 additions and 38 deletions (+7 −38):

  official/nlp/transformer/beam_search.py     +5  −11
  official/nlp/transformer/data_pipeline.py   +2  −7
  official/nlp/transformer/misc.py            +0  −20
official/nlp/transformer/beam_search.py

@@ -17,7 +17,6 @@
 import tensorflow as tf
 
 from official.nlp.transformer import beam_search_v1 as v1
-from official.nlp.transformer import misc
 
 _StateKeys = v1._StateKeys  # pylint: disable=protected-access

@@ -52,8 +51,8 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
 
     # Account for corner case where there are no finished sequences for a
     # particular batch item. In that case, return alive sequences for that batch
     # item.
-    finished_seq = tf.compat.v2.where(seq_cond, finished_seq, alive_seq)
-    finished_scores = tf.compat.v2.where(score_cond, finished_scores, alive_log_probs)
+    finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
+    finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
     return finished_seq, finished_scores

@@ -102,14 +101,9 @@ def sequence_beam_search(symbols_to_logits_fn,
   batch_size = (
       initial_ids.shape.as_list()[0] if padded_decode else
       tf.shape(initial_ids)[0])
-  if misc.is_v2():
-    sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
-                               beam_size, alpha, max_decode_length, eos_id,
-                               padded_decode, dtype)
-  else:
-    sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                                beam_size, alpha, max_decode_length, eos_id,
-                                padded_decode, dtype)
+  sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
+                             beam_size, alpha, max_decode_length, eos_id,
+                             padded_decode, dtype)
   return sbs.search(initial_ids, initial_cache)
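The substance of these hunks: `tf.compat.v2.where` becomes plain `tf.where` (the same op once the codebase is TF2-only), and the `misc.is_v2()` dispatch collapses so `SequenceBeamSearchV2` is used unconditionally. As a minimal pure-Python sketch (not TensorFlow; illustrative names only), the three-argument form both spellings implement is an elementwise select:

```python
def where(condition, x, y):
    """Elementwise select: take x[i] where condition[i] is true, else y[i].

    Pure-Python sketch of the semantics of tf.where(condition, x, y);
    the real TensorFlow op additionally broadcasts across tensor ranks.
    """
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]

# Mirrors the corner case handled in the hunk above: wherever a batch
# item has no finished sequence, fall back to its alive sequence.
seq_cond = [True, False]
finished_seq = where(seq_cond, ["finished_0", "finished_1"],
                     ["alive_0", "alive_1"])
```

Because the TF2 `tf.where` already accepts this three-argument form, deleting the `compat.v2` spelling changes nothing at runtime under TF2, which is what makes the cleanup safe.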
official/nlp/transformer/data_pipeline.py

@@ -56,7 +56,6 @@ import os
 from absl import logging
 import tensorflow as tf
 
-from official.nlp.transformer import misc
 from official.utils.misc import model_helpers
 
 # Buffer size for reading records from a TFRecord file. Each training file is

@@ -313,9 +312,5 @@ def eval_input_fn(params, ctx=None):
 def map_data_for_transformer_fn(x, y):
   """Maps data for training, and handles weried behaviors for different vers."""
   # Will transform input x and targets y into tuple(x, y) as new model inputs.
-  if misc.is_v2():
-    # For TF v2, the 2nd parameter is omitted to make Keras training work.
-    return ((x, y),)
-  else:
-    # For TF v1, Keras requires a dummy placeholder as the 2nd parameter.
-    return ((x, y), tf.constant(0.0))
+  # For TF v2, the 2nd parameter is omitted to make Keras training work.
+  return ((x, y),)
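The new body of `map_data_for_transformer_fn` is simply the former TF2 branch. A self-contained sketch of the before/after behavior (pure Python; a plain `0.0` stands in for the `tf.constant(0.0)` dummy the real v1 branch returned):

```python
def map_data_for_transformer_fn(x, y):
    # Post-commit behavior: pack inputs and targets into a one-element
    # tuple; Keras under TF2 needs no dummy second element.
    return ((x, y),)

def map_data_for_transformer_fn_pre(x, y, v2_enabled):
    # Pre-commit behavior (sketch): TF1 Keras required a dummy scalar
    # as a second element; 0.0 stands in for tf.constant(0.0).
    if v2_enabled:
        return ((x, y),)
    return ((x, y), 0.0)
```

Under TF2 the two functions agree, so dropping the version probe (and with it the `misc` import) leaves training behavior unchanged.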
official/nlp/transformer/misc.py

@@ -22,10 +22,6 @@ from __future__ import print_function
 from absl import flags
 import tensorflow as tf
 
-# TODO(tianlin) Import internal library. Remove this when some functions for
-# different TF versions are fixed.
-from tensorflow.python import tf2 as tf2_internal
-
 from official.nlp.transformer import model_params
 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils

@@ -39,11 +35,6 @@ PARAMS_MAP = {
 }
 
 
-def is_v2():
-  """Returns whether it is v2."""
-  return tf2_internal.enabled()
-
-
 def get_model_params(param_set, num_gpus):
   """Gets predefined model params."""
   if num_gpus > 1:

@@ -78,17 +69,6 @@ def define_transformer_flags():
       fp16_implementation=True)
 
-  # Additional performance flags
-  # TODO(b/76028325): Remove when generic layout optimizer is ready.
-  flags.DEFINE_boolean(
-      name='enable_grappler_layout_optimizer',
-      default=True,
-      help='Enable Grappler layout optimizer. Currently Grappler can '
-      'de-optimize fp16 graphs by forcing NCHW layout for all '
-      'convolutions and batch normalizations, and this flag allows to '
-      'disable it.')
-
 
   flags_core.define_benchmark()
   flags_core.define_device(tpu=True)