提交 c32b8b15 编写于 作者: S shippingwang

refine code

上级 5ae03cfb
mode: 'valid'
ARCHITECTURE:
name: ""
name: "ResNet50_vd"
pretrained_model: ""
pretrained_model: "./pretrained_model/ResNet50_vd_pretrained"
classes_num: 1000
total_images: 1281167
topk: 5
......@@ -11,8 +11,8 @@ image_shape: [3, 224, 224]
VALID:
batch_size: 16
num_workers: 4
file_list: "../dataset/ILSVRC2012/val_list.txt"
data_dir: "../dataset/ILSVRC2012/"
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
......
mode: 'train'
architecture: 'ResNet50_vd'
ARCHITECTURE:
name: 'ResNet50_vd'
pretrained_model:
model_save_dir: "./output/"
classes_num: 102
......@@ -29,8 +30,8 @@ OPTIMIZER:
TRAIN:
batch_size: 32
num_workers: 1
file_list: "./dataset/flower102/train_list.txt"
data_dir: "./dataset/flower102"
file_list: "./dataset/flowers102/train_list.txt"
data_dir: "./dataset/flowers102"
shuffle_seed: 0
transforms:
- DecodeImage:
......@@ -54,8 +55,8 @@ TRAIN:
VALID:
batch_size: 64
num_workers: 1
file_list: "./dataset/flower102/val_list.txt"
data_dir: "./dataset/flower102/"
file_list: "./dataset/flowers102/val_list.txt"
data_dir: "./dataset/flowers102/"
shuffle_seed: 0
transforms:
- DecodeImage:
......
......@@ -67,7 +67,7 @@ def check_architecture(architecture):
similar_names = similar_architectures(architecture["name"],
get_architectures())
model_list = ', '.join(similar_names)
err = "{} is not exist! Maybe you want: [{}]" \
err = "Architecture [{}] is not exist! Maybe you want: [{}]" \
"".format(architecture["name"], model_list)
try:
assert architecture["name"] in similar_names
......
......@@ -63,7 +63,11 @@ def print_dict(d, delimiter=0):
Recursively visualize a dict and
indenting acrrording by the relationship of keys.
"""
dk = []
dv = []
for k, v in d.items():
if k in CONFIG_SECS:
logger.info("-" * 60)
......@@ -75,11 +79,16 @@ def print_dict(d, delimiter=0):
for value in v:
print_dict(value, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ", k, v))
dk.append(k)
dv.append(v)
if k in CONFIG_SECS:
logger.info("-" * 60)
for ki,vi in zip(dk,dv):
logger.info("{}{} : {}".format(delimiter * " ", ki, vi))
def print_config(config):
"""
......
......@@ -71,8 +71,9 @@ def main(args):
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
compiled_valid_prog = program.compile(config, valid_prog)
#compiled_valid_prog = program.compile(config, valid_prog)
compiled_valid_prog = valid_prog
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, 0,
'valid')
......
export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0" \
tools/eval.py \
-c ./configs/eval.yaml
......@@ -389,3 +389,4 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)])
logger.info("[epoch:%3d][%s][step:%4d]%s" %
(epoch, mode, idx, fetchs_str))
logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str))
......@@ -4,6 +4,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
--log_dir=log_ResNet50 \
tools/train.py \
-c ./configs/ResNet/ResNet50.yaml
......@@ -87,8 +87,8 @@ def main(args):
if config.validate:
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
compiled_valid_prog = program.compile(config, valid_prog)
#compiled_valid_prog = program.compile(config, valid_prog)
compiled_valid_prog = valid_prog
compiled_train_prog = fleet.main_program
for epoch_id in range(config.epochs):
# 1. train with train dataset
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册