Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
55d41fd6
M
Models
项目概览
曾经的那一瞬间
/
Models
大约 1 年 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
55d41fd6
编写于
9月 10, 2020
作者:
H
Hongkun Yu
提交者:
A. Unique TensorFlower
9月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use distribution utils in XLNET
PiperOrigin-RevId: 331015243
上级
a8518117
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
15 addition
and
46 deletion
+15
-46
official/nlp/xlnet/run_classifier.py
official/nlp/xlnet/run_classifier.py
+4
-14
official/nlp/xlnet/run_pretrain.py
official/nlp/xlnet/run_pretrain.py
+7
-18
official/nlp/xlnet/run_squad.py
official/nlp/xlnet/run_squad.py
+4
-14
未找到文件。
official/nlp/xlnet/run_classifier.py
浏览文件 @
55d41fd6
...
...
@@ -14,11 +14,6 @@
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
# Import libraries
from
absl
import
app
...
...
@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from
official.nlp.xlnet
import
training_utils
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_modeling
as
modeling
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
distribution_utils
flags
.
DEFINE_integer
(
"n_class"
,
default
=
2
,
help
=
"Number of classes."
)
flags
.
DEFINE_string
(
...
...
@@ -135,14 +130,9 @@ def get_metric_fn():
def
main
(
unused_argv
):
del
unused_argv
if
FLAGS
.
strategy_type
==
"mirror"
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
elif
FLAGS
.
strategy_type
==
"tpu"
:
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
FLAGS
.
tpu
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
else
:
raise
ValueError
(
"The distribution strategy type is not supported: %s"
%
FLAGS
.
strategy_type
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
strategy_type
,
tpu_address
=
FLAGS
.
tpu
)
if
strategy
:
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
...
...
official/nlp/xlnet/run_pretrain.py
浏览文件 @
55d41fd6
...
...
@@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
"""XLNet pretraining runner in tf2.0."""
import
functools
import
os
...
...
@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from
official.nlp.xlnet
import
training_utils
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_modeling
as
modeling
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
distribution_utils
flags
.
DEFINE_integer
(
"num_predict"
,
...
...
@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config):
def
main
(
unused_argv
):
del
unused_argv
num_hosts
=
1
if
FLAGS
.
strategy_type
==
"mirror"
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
elif
FLAGS
.
strategy_type
==
"tpu"
:
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
FLAGS
.
tpu
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
topology
=
FLAGS
.
tpu_topology
.
split
(
"x"
)
total_num_core
=
2
*
int
(
topology
[
0
])
*
int
(
topology
[
1
])
num_hosts
=
total_num_core
//
FLAGS
.
num_core_per_host
else
:
raise
ValueError
(
"The distribution strategy type is not supported: %s"
%
FLAGS
.
strategy_type
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
strategy_type
,
tpu_address
=
FLAGS
.
tpu
)
if
FLAGS
.
strategy_type
==
"tpu"
:
num_hosts
=
strategy
.
extended
.
num_hosts
if
strategy
:
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
...
...
official/nlp/xlnet/run_squad.py
浏览文件 @
55d41fd6
...
...
@@ -14,11 +14,6 @@
# ==============================================================================
"""XLNet SQUAD finetuning runner in tf2.0."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
import
json
import
os
...
...
@@ -39,7 +34,7 @@ from official.nlp.xlnet import squad_utils
from
official.nlp.xlnet
import
training_utils
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_modeling
as
modeling
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
distribution_utils
flags
.
DEFINE_string
(
"test_feature_path"
,
default
=
None
,
help
=
"Path to feature of test set."
)
...
...
@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def
main
(
unused_argv
):
del
unused_argv
if
FLAGS
.
strategy_type
==
"mirror"
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
elif
FLAGS
.
strategy_type
==
"tpu"
:
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
FLAGS
.
tpu
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
else
:
raise
ValueError
(
"The distribution strategy type is not supported: %s"
%
FLAGS
.
strategy_type
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
strategy_type
,
tpu_address
=
FLAGS
.
tpu
)
if
strategy
:
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录