Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
c115444f
M
Models
项目概览
曾经的那一瞬间
/
Models
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c115444f
编写于
12月 05, 2019
作者:
J
Jaehong Kim
提交者:
A. Unique TensorFlower
12月 05, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Internal change
PiperOrigin-RevId: 283962490
上级
ef8aed79
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
197 addition
and
44 deletion
+197
-44
official/vision/image_classification/common.py
official/vision/image_classification/common.py
+43
-1
official/vision/image_classification/imagenet_preprocessing.py
...ial/vision/image_classification/imagenet_preprocessing.py
+18
-0
official/vision/image_classification/resnet_imagenet_main.py
official/vision/image_classification/resnet_imagenet_main.py
+72
-19
official/vision/image_classification/resnet_imagenet_test.py
official/vision/image_classification/resnet_imagenet_test.py
+64
-24
未找到文件。
official/vision/image_classification/common.py
浏览文件 @
c115444f
...
...
@@ -24,6 +24,7 @@ import numpy as np
import
tensorflow
as
tf
from
tensorflow.python.keras.optimizer_v2
import
gradient_descent
as
gradient_descent_v2
import
tensorflow_model_optimization
as
tfmot
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
keras_utils
...
...
@@ -180,7 +181,12 @@ def get_optimizer(learning_rate=0.1):
# TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
def
get_callbacks
(
steps_per_epoch
,
learning_rate_schedule_fn
=
None
):
def
get_callbacks
(
steps_per_epoch
,
learning_rate_schedule_fn
=
None
,
pruning_method
=
''
,
enable_checkpoint_and_export
=
False
,
model_dir
=
''
):
"""Returns common callbacks."""
time_callback
=
keras_utils
.
TimeHistory
(
FLAGS
.
batch_size
,
FLAGS
.
log_steps
)
callbacks
=
[
time_callback
]
...
...
@@ -205,6 +211,17 @@ def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
steps_per_epoch
)
callbacks
.
append
(
profiler_callback
)
if
model_dir
:
if
pruning_method
==
'polynomial_decay'
:
callbacks
.
append
(
tfmot
.
sparsity
.
keras
.
PruningSummaries
(
log_dir
=
model_dir
,
profile_batch
=
0
))
callbacks
.
append
(
tfmot
.
sparsity
.
keras
.
UpdatePruningStep
())
if
enable_checkpoint_and_export
:
ckpt_full_path
=
os
.
path
.
join
(
model_dir
,
'model.ckpt-{epoch:04d}'
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
))
return
callbacks
...
...
@@ -358,6 +375,31 @@ def get_synth_data(height, width, num_channels, num_classes, dtype):
return
inputs
,
labels
def
define_pruning_flags
():
"""Define flags for pruning methods."""
flags
.
DEFINE_string
(
'pruning_method'
,
''
,
'Pruning method.'
'Empty string (no pruning) or polynomial_decay.'
)
flags
.
DEFINE_float
(
'pruning_initial_sparsity'
,
0.0
,
'Initial sparsity for pruning.'
)
flags
.
DEFINE_float
(
'pruning_final_sparsity'
,
0.5
,
'Final sparsity for pruning.'
)
flags
.
DEFINE_integer
(
'pruning_begin_step'
,
0
,
'Begin step for pruning.'
)
flags
.
DEFINE_integer
(
'pruning_end_step'
,
100000
,
'End step for pruning.'
)
flags
.
DEFINE_integer
(
'pruning_frequency'
,
100
,
'Frequency for pruning.'
)
flags
.
DEFINE_string
(
'model'
,
'resnet50_v1.5'
,
'Name of model preset. (mobilenet, resnet50_v1.5)'
)
flags
.
DEFINE_string
(
'optimizer'
,
'resnet50_default'
,
'Name of optimizer preset. '
'(mobilenet_default, resnet50_default)'
)
flags
.
DEFINE_string
(
'pretrained_filepath'
,
''
,
'Pretrained file path.'
)
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
,
dtype
=
tf
.
float32
,
drop_remainder
=
True
):
"""Returns an input function that returns a dataset with random data.
...
...
official/vision/image_classification/imagenet_preprocessing.py
浏览文件 @
c115444f
...
...
@@ -246,6 +246,24 @@ def parse_record(raw_record, is_training, dtype):
return
image
,
label
def
get_parse_record_fn
(
use_keras_image_data_format
=
False
):
"""Get function to use for parsing the records.
Args:
use_keras_image_data_format: A boolean denoting whether data format is keras
backend image data format.
Returns:
Function to use for parsing the records.
"""
def
parse_record_fn
(
raw_record
,
is_training
,
dtype
):
image
,
label
=
parse_record
(
raw_record
,
is_training
,
dtype
)
if
use_keras_image_data_format
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_first'
:
image
=
tf
.
transpose
(
image
,
perm
=
[
2
,
0
,
1
])
return
image
,
label
return
parse_record_fn
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
...
...
official/vision/image_classification/resnet_imagenet_main.py
浏览文件 @
c115444f
...
...
@@ -25,6 +25,8 @@ from absl import flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.benchmark.models
import
trivial_model
from
official.utils.flags
import
core
as
flags_core
from
official.utils.logs
import
logger
...
...
@@ -44,6 +46,7 @@ def run(flags_obj):
Raises:
ValueError: If fp16 is passed as it is not currently supported.
NotImplementedError: If some features are not currently supported.
Returns:
Dictionary of training and eval stats.
...
...
@@ -120,12 +123,20 @@ def run(flags_obj):
# in the dataset, as XLA-GPU doesn't support dynamic shapes.
drop_remainder
=
flags_obj
.
enable_xla
# Current resnet_model.resnet50 input format is always channel-last.
# We use keras_application mobilenet model which input format is depends on
# the keras beckend image data format.
# This use_keras_image_data_format flags indicates whether image preprocessor
# output format should be same as the keras backend image data format or just
# channel-last format.
use_keras_image_data_format
=
(
flags_obj
.
model
==
'mobilenet'
)
train_input_dataset
=
input_fn
(
is_training
=
True
,
data_dir
=
flags_obj
.
data_dir
,
batch_size
=
flags_obj
.
batch_size
,
num_epochs
=
flags_obj
.
train_epochs
,
parse_record_fn
=
imagenet_preprocessing
.
parse_record
,
parse_record_fn
=
imagenet_preprocessing
.
get_parse_record_fn
(
use_keras_image_data_format
=
use_keras_image_data_format
),
datasets_num_private_threads
=
flags_obj
.
datasets_num_private_threads
,
dtype
=
dtype
,
drop_remainder
=
drop_remainder
,
...
...
@@ -140,7 +151,8 @@ def run(flags_obj):
data_dir
=
flags_obj
.
data_dir
,
batch_size
=
flags_obj
.
batch_size
,
num_epochs
=
flags_obj
.
train_epochs
,
parse_record_fn
=
imagenet_preprocessing
.
parse_record
,
parse_record_fn
=
imagenet_preprocessing
.
get_parse_record_fn
(
use_keras_image_data_format
=
use_keras_image_data_format
),
dtype
=
dtype
,
drop_remainder
=
drop_remainder
)
...
...
@@ -153,9 +165,27 @@ def run(flags_obj):
boundaries
=
list
(
p
[
1
]
for
p
in
common
.
LR_SCHEDULE
[
1
:]),
multipliers
=
list
(
p
[
0
]
for
p
in
common
.
LR_SCHEDULE
),
compute_lr_on_cpu
=
True
)
steps_per_epoch
=
(
imagenet_preprocessing
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
)
learning_rate_schedule_fn
=
None
with
strategy_scope
:
if
flags_obj
.
optimizer
==
'resnet50_default'
:
optimizer
=
common
.
get_optimizer
(
lr_schedule
)
learning_rate_schedule_fn
=
common
.
learning_rate_schedule
elif
flags_obj
.
optimizer
==
'mobilenet_default'
:
lr_decay_factor
=
0.94
num_epochs_per_decay
=
2.5
initial_learning_rate_per_sample
=
0.000007
initial_learning_rate
=
\
initial_learning_rate_per_sample
*
flags_obj
.
batch_size
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
tf
.
keras
.
optimizers
.
schedules
.
ExponentialDecay
(
initial_learning_rate
,
decay_steps
=
steps_per_epoch
*
num_epochs_per_decay
,
decay_rate
=
lr_decay_factor
,
staircase
=
True
),
momentum
=
0.9
)
if
flags_obj
.
fp16_implementation
==
'graph_rewrite'
:
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
...
...
@@ -169,9 +199,30 @@ def run(flags_obj):
if
flags_obj
.
use_trivial_model
:
model
=
trivial_model
.
trivial_model
(
imagenet_preprocessing
.
NUM_CLASSES
)
el
se
:
el
if
flags_obj
.
model
==
'resnet50_v1.5'
:
model
=
resnet_model
.
resnet50
(
num_classes
=
imagenet_preprocessing
.
NUM_CLASSES
)
elif
flags_obj
.
model
==
'mobilenet'
:
model
=
tf
.
keras
.
applications
.
mobilenet
.
MobileNet
(
weights
=
None
,
classes
=
imagenet_preprocessing
.
NUM_CLASSES
)
if
flags_obj
.
pretrained_filepath
:
model
.
load_weights
(
flags_obj
.
pretrained_filepath
)
if
flags_obj
.
pruning_method
==
'polynomial_decay'
:
if
dtype
!=
tf
.
float32
:
raise
NotImplementedError
(
'Pruning is currently only supported on dtype=tf.float32.'
)
pruning_params
=
{
'pruning_schedule'
:
tfmot
.
sparsity
.
keras
.
PolynomialDecay
(
initial_sparsity
=
flags_obj
.
pruning_initial_sparsity
,
final_sparsity
=
flags_obj
.
pruning_final_sparsity
,
begin_step
=
flags_obj
.
pruning_begin_step
,
end_step
=
flags_obj
.
pruning_end_step
,
frequency
=
flags_obj
.
pruning_frequency
),
}
model
=
tfmot
.
sparsity
.
keras
.
prune_low_magnitude
(
model
,
**
pruning_params
)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is on longer
# a valid arg for this model. Also remove as a valid flag.
...
...
@@ -191,16 +242,14 @@ def run(flags_obj):
if
flags_obj
.
report_accuracy_metrics
else
None
),
run_eagerly
=
flags_obj
.
run_eagerly
)
steps_per_epoch
=
(
imagenet_preprocessing
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
)
train_epochs
=
flags_obj
.
train_epochs
callbacks
=
common
.
get_callbacks
(
steps_per_epoch
,
common
.
learning_rate_schedule
)
if
flags_obj
.
enable_checkpoint_and_export
:
ckpt_full_path
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'model.ckpt-{epoch:04d}'
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
)
)
callbacks
=
common
.
get_callbacks
(
steps_per_epoch
=
steps_per_epoch
,
learning_rate_schedule_fn
=
learning_rate_schedule_fn
,
pruning_method
=
flags_obj
.
pruning_method
,
enable_checkpoint_and_export
=
flags_obj
.
enable_checkpoint_and_export
,
model_dir
=
flags_obj
.
model_dir
)
# if mutliple epochs, ignore the train_steps flag.
if
train_epochs
<=
1
and
flags_obj
.
train_steps
:
...
...
@@ -236,13 +285,6 @@ def run(flags_obj):
validation_data
=
validation_data
,
validation_freq
=
flags_obj
.
epochs_between_evals
,
verbose
=
2
)
if
flags_obj
.
enable_checkpoint_and_export
:
if
dtype
==
tf
.
bfloat16
:
logging
.
warning
(
"Keras model.save does not support bfloat16 dtype."
)
else
:
# Keras model.save assumes a float32 input designature.
export_path
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'saved_model'
)
model
.
save
(
export_path
,
include_optimizer
=
False
)
eval_output
=
None
if
not
flags_obj
.
skip_eval
:
...
...
@@ -250,6 +292,16 @@ def run(flags_obj):
steps
=
num_eval_steps
,
verbose
=
2
)
if
flags_obj
.
pruning_method
==
'polynomial_decay'
:
model
=
tfmot
.
sparsity
.
keras
.
strip_pruning
(
model
)
if
flags_obj
.
enable_checkpoint_and_export
:
if
dtype
==
tf
.
bfloat16
:
logging
.
warning
(
'Keras model.save does not support bfloat16 dtype.'
)
else
:
# Keras model.save assumes a float32 input designature.
export_path
=
os
.
path
.
join
(
flags_obj
.
model_dir
,
'saved_model'
)
model
.
save
(
export_path
,
include_optimizer
=
False
)
if
not
strategy
and
flags_obj
.
explicit_gpu_placement
:
no_dist_strat_device
.
__exit__
()
...
...
@@ -259,6 +311,7 @@ def run(flags_obj):
def
define_imagenet_keras_flags
():
common
.
define_keras_flags
()
common
.
define_pruning_flags
()
flags_core
.
set_defaults
()
flags
.
adopt_module_key_flags
(
common
)
...
...
official/vision/image_classification/resnet_imagenet_test.py
浏览文件 @
c115444f
...
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.eager
import
context
...
...
@@ -27,14 +28,45 @@ from official.vision.image_classification import imagenet_preprocessing
from
official.vision.image_classification
import
resnet_imagenet_main
@
parameterized
.
parameters
(
"resnet"
,
"resnet_polynomial_decay"
,
"mobilenet"
,
"mobilenet_polynomial_decay"
)
class
KerasImagenetTest
(
tf
.
test
.
TestCase
):
"""Unit tests for Keras
ResNet
with ImageNet."""
_extra_flags
=
[
"""Unit tests for Keras
Models
with ImageNet."""
_extra_flags_dict
=
{
"resnet"
:
[
"-batch_size"
,
"4"
,
"-train_steps"
,
"1"
,
"-use_synthetic_data"
,
"true"
]
"-model"
,
"resnet50_v1.5"
,
"-optimizer"
,
"resnet50_default"
,
],
"resnet_polynomial_decay"
:
[
"-batch_size"
,
"4"
,
"-train_steps"
,
"1"
,
"-use_synthetic_data"
,
"true"
,
"-model"
,
"resnet50_v1.5"
,
"-optimizer"
,
"resnet50_default"
,
"-pruning_method"
,
"polynomial_decay"
,
],
"mobilenet"
:
[
"-batch_size"
,
"4"
,
"-train_steps"
,
"1"
,
"-use_synthetic_data"
,
"true"
"-model"
,
"mobilenet"
,
"-optimizer"
,
"mobilenet_default"
,
],
"mobilenet_polynomial_decay"
:
[
"-batch_size"
,
"4"
,
"-train_steps"
,
"1"
,
"-use_synthetic_data"
,
"true"
,
"-model"
,
"mobilenet"
,
"-optimizer"
,
"mobilenet_default"
,
"-pruning_method"
,
"polynomial_decay"
,
],
}
_tempdir
=
None
@
classmethod
...
...
@@ -50,7 +82,7 @@ class KerasImagenetTest(tf.test.TestCase):
super
(
KerasImagenetTest
,
self
).
tearDown
()
tf
.
io
.
gfile
.
rmtree
(
self
.
get_temp_dir
())
def
test_end_to_end_no_dist_strat
(
self
):
def
test_end_to_end_no_dist_strat
(
self
,
flags_key
):
"""Test Keras model with 1 GPU, no distribution strategy."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -59,7 +91,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-distribution_strategy"
,
"off"
,
"-data_format"
,
"channels_last"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
_dict
[
flags_key
]
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -67,14 +99,14 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_graph_no_dist_strat
(
self
):
def
test_end_to_end_graph_no_dist_strat
(
self
,
flags_key
):
"""Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
extra_flags
=
[
"-enable_eager"
,
"false"
,
"-distribution_strategy"
,
"off"
,
"-data_format"
,
"channels_last"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
_dict
[
flags_key
]
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -82,7 +114,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_1_gpu
(
self
):
def
test_end_to_end_1_gpu
(
self
,
flags_key
):
"""Test Keras model with 1 GPU."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -98,7 +130,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-data_format"
,
"channels_last"
,
"-enable_checkpoint_and_export"
,
"1"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
_dict
[
flags_key
]
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -106,7 +138,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_1_gpu_fp16
(
self
):
def
test_end_to_end_1_gpu_fp16
(
self
,
flags_key
):
"""Test Keras model with 1 GPU and fp16."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -122,7 +154,10 @@ class KerasImagenetTest(tf.test.TestCase):
"-distribution_strategy"
,
"mirrored"
,
"-data_format"
,
"channels_last"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags_dict
[
flags_key
]
if
"polynomial_decay"
in
extra_flags
:
self
.
skipTest
(
"Pruning with fp16 is not currently supported."
)
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -130,8 +165,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_2_gpu
(
self
):
def
test_end_to_end_2_gpu
(
self
,
flags_key
):
"""Test Keras model with 2 GPUs."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -145,7 +179,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-num_gpus"
,
"2"
,
"-distribution_strategy"
,
"mirrored"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
_dict
[
flags_key
]
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -153,7 +187,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_xla_2_gpu
(
self
):
def
test_end_to_end_xla_2_gpu
(
self
,
flags_key
):
"""Test Keras model with XLA and 2 GPUs."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -168,7 +202,7 @@ class KerasImagenetTest(tf.test.TestCase):
"-enable_xla"
,
"true"
,
"-distribution_strategy"
,
"mirrored"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
_dict
[
flags_key
]
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -176,7 +210,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_2_gpu_fp16
(
self
):
def
test_end_to_end_2_gpu_fp16
(
self
,
flags_key
):
"""Test Keras model with 2 GPUs and fp16."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -191,7 +225,10 @@ class KerasImagenetTest(tf.test.TestCase):
"-dtype"
,
"fp16"
,
"-distribution_strategy"
,
"mirrored"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags_dict
[
flags_key
]
if
"polynomial_decay"
in
extra_flags
:
self
.
skipTest
(
"Pruning with fp16 is not currently supported."
)
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
@@ -199,7 +236,7 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags
=
extra_flags
)
def
test_end_to_end_xla_2_gpu_fp16
(
self
):
def
test_end_to_end_xla_2_gpu_fp16
(
self
,
flags_key
):
"""Test Keras model with XLA, 2 GPUs and fp16."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
...
...
@@ -215,7 +252,10 @@ class KerasImagenetTest(tf.test.TestCase):
"-enable_xla"
,
"true"
,
"-distribution_strategy"
,
"mirrored"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags_dict
[
flags_key
]
if
"polynomial_decay"
in
extra_flags
:
self
.
skipTest
(
"Pruning with fp16 is not currently supported."
)
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录