提交 fd915b07 编写于 作者: C ceci3

fix nas demo

上级 4be8333e
......@@ -53,19 +53,27 @@ def build_program(main_program,
return data_loader, avg_cost, acc_top1, acc_top5
def search_mobilenetv2(config, args, image_size):
def search_mobilenetv2(config, args, image_size, is_server=True):
factory = SearchSpaceFactory()
space = factory.get_search_space(config)
### start a server and a client
sa_nas = SANAS(
config,
server_addr=("", 8883),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=True)
### start a client
#sa_nas = SANAS(config, server_addr=("10.255.125.38", 8889), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=True)
if is_server:
### start a server and a client
sa_nas = SANAS(
config,
server_addr=("", 8883),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=True)
else:
### start a client
sa_nas = SANAS(
config,
server_addr=("10.255.125.38", 8883),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=False)
image_shape = [3, image_size, image_size]
for step in range(args.search_steps):
......@@ -122,22 +130,22 @@ def search_mobilenetv2(config, args, image_size):
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator(test_reader, places=place)
#build_strategy = fluid.BuildStrategy()
#train_compiled_program = fluid.CompiledProgram(
# train_program).with_data_parallel(
# loss_name=avg_cost.name, build_strategy=build_strategy)
#for epoch_id in range(args.retain_epoch):
# for batch_id, data in enumerate(train_loader()):
# fetches = [avg_cost.name]
# s_time = time.time()
# outs = exe.run(train_compiled_program,
# feed=data,
# fetch_list=fetches)[0]
# batch_time = time.time() - s_time
# if batch_id % 10 == 0:
# _logger.info(
# 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
# format(step, epoch_id, batch_id, outs[0], batch_time))
build_strategy = fluid.BuildStrategy()
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
outs = exe.run(train_compiled_program,
feed=data,
fetch_list=fetches)[0]
batch_time = time.time() - s_time
if batch_id % 10 == 0:
_logger.info(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
format(step, epoch_id, batch_id, outs[0], batch_time))
reward = []
for batch_id, data in enumerate(test_loader()):
......@@ -158,7 +166,7 @@ def search_mobilenetv2(config, args, image_size):
finally_reward = np.mean(np.array(reward), axis=0)
_logger.info(
'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
step, finally_reward[0], finally_reward[1], finally_reward[2]))
finally_reward[0], finally_reward[1], finally_reward[2]))
sa_nas.reward(float(finally_reward[1]))
......@@ -188,6 +196,11 @@ if __name__ == '__main__':
type=float,
default=10.24,
help='init temperature.')
parser.add_argument(
'--is_server',
type=ast.literal_eval,
default=True,
help='Whether to start a server.')
# nas args
parser.add_argument(
'--max_flops', type=int, default=592948064, help='reduce rate.')
......@@ -260,4 +273,4 @@ if __name__ == '__main__':
}
config = [('MobileNetV2Space', config_info)]
search_mobilenetv2(config, args, image_size)
search_mobilenetv2(config, args, image_size, is_server=args.is_server)
......@@ -32,6 +32,7 @@ class MobileNetV1Space(SearchSpaceBase):
input_size,
output_size,
block_num,
block_mask,
scale=1.0,
class_dim=1000):
super(MobileNetV1Space, self).__init__(input_size, output_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册