提交 89316571 编写于 作者: M mmglove 提交者: whs

add ce for slim (#4119)

上级 68d17711
......@@ -34,12 +34,20 @@ add_arg('pretrained_model', str, None, "Whether to use pretraine
add_arg('teacher_model', str, None, "Set the teacher network to use.")
add_arg('teacher_pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('compress_config', str, None, "The config file for compression with yaml format.")
add_arg('enable_ce', bool, False, "If set, run the task with continuous evaluation logs.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def compress(args):
# add ce
if args.enable_ce:
SEED = 1
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
image_shape = [int(m) for m in args.image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
......
......@@ -33,6 +33,7 @@ add_arg('num_epochs', int, 120, "The number of total epochs
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('enable_ce', bool, False, "If set, run the task with continuous evaluation logs.")
# yapf: enable
......@@ -68,6 +69,12 @@ def create_optimizer(args):
return cosine_decay(args)
def compress(args):
# add ce
if args.enable_ce:
SEED = 1
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
class_dim=1000
image_shape="3,224,224"
image_shape = [int(m) for m in image_shape.split(",")]
......
......@@ -33,12 +33,19 @@ add_arg('teacher_model', str, None, "Set the teacher network to use
add_arg('teacher_pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('compress_config', str, None, "The config file for compression with yaml format.")
add_arg('quant_only', bool, False, "Only do quantization-aware training.")
add_arg('enable_ce', bool, False, "If set, run the task with continuous evaluation logs.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def compress(args):
# add ce
if args.enable_ce:
SEED = 1
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
image_shape = [int(m) for m in args.image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册