提交 c32b8b15 编写于 作者: S shippingwang

refine code

上级 5ae03cfb
mode: 'valid' mode: 'valid'
ARCHITECTURE: ARCHITECTURE:
name: "" name: "ResNet50_vd"
pretrained_model: "" pretrained_model: "./pretrained_model/ResNet50_vd_pretrained"
classes_num: 1000 classes_num: 1000
total_images: 1281167 total_images: 1281167
topk: 5 topk: 5
...@@ -11,8 +11,8 @@ image_shape: [3, 224, 224] ...@@ -11,8 +11,8 @@ image_shape: [3, 224, 224]
VALID: VALID:
batch_size: 16 batch_size: 16
num_workers: 4 num_workers: 4
file_list: "../dataset/ILSVRC2012/val_list.txt" file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "../dataset/ILSVRC2012/" data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0 shuffle_seed: 0
transforms: transforms:
- DecodeImage: - DecodeImage:
......
mode: 'train' mode: 'train'
architecture: 'ResNet50_vd' ARCHITECTURE:
name: 'ResNet50_vd'
pretrained_model: pretrained_model:
model_save_dir: "./output/" model_save_dir: "./output/"
classes_num: 102 classes_num: 102
...@@ -29,8 +30,8 @@ OPTIMIZER: ...@@ -29,8 +30,8 @@ OPTIMIZER:
TRAIN: TRAIN:
batch_size: 32 batch_size: 32
num_workers: 1 num_workers: 1
file_list: "./dataset/flower102/train_list.txt" file_list: "./dataset/flowers102/train_list.txt"
data_dir: "./dataset/flower102" data_dir: "./dataset/flowers102"
shuffle_seed: 0 shuffle_seed: 0
transforms: transforms:
- DecodeImage: - DecodeImage:
...@@ -54,8 +55,8 @@ TRAIN: ...@@ -54,8 +55,8 @@ TRAIN:
VALID: VALID:
batch_size: 64 batch_size: 64
num_workers: 1 num_workers: 1
file_list: "./dataset/flower102/val_list.txt" file_list: "./dataset/flowers102/val_list.txt"
data_dir: "./dataset/flower102/" data_dir: "./dataset/flowers102/"
shuffle_seed: 0 shuffle_seed: 0
transforms: transforms:
- DecodeImage: - DecodeImage:
......
...@@ -67,7 +67,7 @@ def check_architecture(architecture): ...@@ -67,7 +67,7 @@ def check_architecture(architecture):
similar_names = similar_architectures(architecture["name"], similar_names = similar_architectures(architecture["name"],
get_architectures()) get_architectures())
model_list = ', '.join(similar_names) model_list = ', '.join(similar_names)
err = "{} is not exist! Maybe you want: [{}]" \ err = "Architecture [{}] is not exist! Maybe you want: [{}]" \
"".format(architecture["name"], model_list) "".format(architecture["name"], model_list)
try: try:
assert architecture["name"] in similar_names assert architecture["name"] in similar_names
......
...@@ -63,7 +63,11 @@ def print_dict(d, delimiter=0): ...@@ -63,7 +63,11 @@ def print_dict(d, delimiter=0):
Recursively visualize a dict and Recursively visualize a dict and
indenting acrrording by the relationship of keys. indenting acrrording by the relationship of keys.
""" """
dk = []
dv = []
for k, v in d.items(): for k, v in d.items():
if k in CONFIG_SECS: if k in CONFIG_SECS:
logger.info("-" * 60) logger.info("-" * 60)
...@@ -75,11 +79,16 @@ def print_dict(d, delimiter=0): ...@@ -75,11 +79,16 @@ def print_dict(d, delimiter=0):
for value in v: for value in v:
print_dict(value, delimiter + 4) print_dict(value, delimiter + 4)
else: else:
logger.info("{}{} : {}".format(delimiter * " ", k, v)) dk.append(k)
dv.append(v)
if k in CONFIG_SECS: if k in CONFIG_SECS:
logger.info("-" * 60) logger.info("-" * 60)
for ki,vi in zip(dk,dv):
logger.info("{}{} : {}".format(delimiter * " ", ki, vi))
def print_config(config): def print_config(config):
""" """
......
...@@ -71,8 +71,9 @@ def main(args): ...@@ -71,8 +71,9 @@ def main(args):
valid_reader = Reader(config, 'valid')() valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place) valid_dataloader.set_sample_list_generator(valid_reader, place)
compiled_valid_prog = program.compile(config, valid_prog)
#compiled_valid_prog = program.compile(config, valid_prog)
compiled_valid_prog = valid_prog
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, 0, program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, 0,
'valid') 'valid')
......
export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0" \
tools/eval.py \
-c ./configs/eval.yaml
...@@ -389,3 +389,4 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'): ...@@ -389,3 +389,4 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)]) fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)])
logger.info("[epoch:%3d][%s][step:%4d]%s" % logger.info("[epoch:%3d][%s][step:%4d]%s" %
(epoch, mode, idx, fetchs_str)) (epoch, mode, idx, fetchs_str))
logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str))
...@@ -4,6 +4,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH ...@@ -4,6 +4,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \ python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \ --selected_gpus="0,1,2,3" \
--log_dir=log_ResNet50 \
tools/train.py \ tools/train.py \
-c ./configs/ResNet/ResNet50.yaml -c ./configs/ResNet/ResNet50.yaml
...@@ -87,8 +87,8 @@ def main(args): ...@@ -87,8 +87,8 @@ def main(args):
if config.validate: if config.validate:
valid_reader = Reader(config, 'valid')() valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place) valid_dataloader.set_sample_list_generator(valid_reader, place)
compiled_valid_prog = program.compile(config, valid_prog) #compiled_valid_prog = program.compile(config, valid_prog)
compiled_valid_prog = valid_prog
compiled_train_prog = fleet.main_program compiled_train_prog = fleet.main_program
for epoch_id in range(config.epochs): for epoch_id in range(config.epochs):
# 1. train with train dataset # 1. train with train dataset
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册