From 4d0dfb0a301cac2d085d4a1eee908e5e825b04ea Mon Sep 17 00:00:00 2001 From: ceci3 <592712189@qq.com> Date: Wed, 11 Dec 2019 09:20:15 +0000 Subject: [PATCH] update --- demo/nas/block_sa_nas_mobilenetv2.py | 4 ++-- demo/nas/sa_nas_mobilenetv2.py | 10 +++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/demo/nas/block_sa_nas_mobilenetv2.py b/demo/nas/block_sa_nas_mobilenetv2.py index b58fc9ef..27fb1f4c 100644 --- a/demo/nas/block_sa_nas_mobilenetv2.py +++ b/demo/nas/block_sa_nas_mobilenetv2.py @@ -25,9 +25,9 @@ retain_epoch = 5 def create_data_loader(image_shape): - data_shape = [-1] + image_shape + data_shape = [None] + image_shape data = fluid.data(name='data', shape=data_shape, dtype='float32') - label = fluid.data(name='label', shape=[-1, 1], dtype='int64') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') data_loader = fluid.io.DataLoader.from_generator( feed_list=[data, label], capacity=1024, diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py index 5e60da0f..e6abe115 100644 --- a/demo/nas/sa_nas_mobilenetv2.py +++ b/demo/nas/sa_nas_mobilenetv2.py @@ -27,9 +27,9 @@ retain_epoch = 5 def create_data_loader(image_shape): - data_shape = [-1] + image_shape + data_shape = [None] + image_shape data = fluid.data(name='data', shape=data_shape, dtype='float32') - label = fluid.data(name='label', shape=[-1, 1], dtype='int64') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') data_loader = fluid.io.DataLoader.from_generator( feed_list=[data, label], capacity=1024, @@ -47,11 +47,7 @@ def build_program(main_program, with fluid.program_guard(main_program, startup_program): data_loader, data, label = create_data_loader(image_shape) output = archs(data) - output = fluid.layers.fc( - input=output, - size=args.class_dim, - param_attr=ParamAttr(name='mobilenetv2_fc_weights'), - bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) + output = fluid.layers.fc(input=output, size=args.class_dim) softmax_out = fluid.layers.softmax(input=output, use_cudnn=False) cost = fluid.layers.cross_entropy(input=softmax_out, label=label) -- GitLab