提交 e3d910f4 编写于 作者: J jerrywgz

clean code

上级 73d5f419
...@@ -76,4 +76,3 @@ def cosine_with_warmup_decay(learning_rate, lr_min, steps_one_epoch, ...@@ -76,4 +76,3 @@ def cosine_with_warmup_decay(learning_rate, lr_min, steps_one_epoch,
fluid.layers.assign(cosine_lr, lr) fluid.layers.assign(cosine_lr, lr)
return lr return lr
...@@ -175,7 +175,6 @@ def StemConv(input, C_out, kernel_size, padding): ...@@ -175,7 +175,6 @@ def StemConv(input, C_out, kernel_size, padding):
return bn_a return bn_a
class NetworkCIFAR(object): class NetworkCIFAR(object):
def __init__(self, C, class_num, layers, auxiliary, genotype): def __init__(self, C, class_num, layers, auxiliary, genotype):
self._layers = layers self._layers = layers
......
...@@ -52,6 +52,7 @@ half_length = 8 ...@@ -52,6 +52,7 @@ half_length = 8
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
def generate_reshape_label(label, batch_size, CIFAR_CLASSES=10): def generate_reshape_label(label, batch_size, CIFAR_CLASSES=10):
reshape_label = np.zeros((batch_size, 1), dtype='int32') reshape_label = np.zeros((batch_size, 1), dtype='int32')
reshape_non_label = np.zeros( reshape_non_label = np.zeros(
...@@ -88,7 +89,7 @@ def preprocess(sample, is_training, args): ...@@ -88,7 +89,7 @@ def preprocess(sample, is_training, args):
image_array = sample.reshape(3, image_size, image_size) image_array = sample.reshape(3, image_size, image_size)
rgb_array = np.transpose(image_array, (1, 2, 0)) rgb_array = np.transpose(image_array, (1, 2, 0))
img = Image.fromarray(rgb_array, 'RGB') img = Image.fromarray(rgb_array, 'RGB')
if is_training: if is_training:
# pad and ramdom crop # pad and ramdom crop
img = ImageOps.expand(img, (4, 4, 4, 4), fill=0) # pad to 40 * 40 * 3 img = ImageOps.expand(img, (4, 4, 4, 4), fill=0) # pad to 40 * 40 * 3
...@@ -97,13 +98,13 @@ def preprocess(sample, is_training, args): ...@@ -97,13 +98,13 @@ def preprocess(sample, is_training, args):
left_top[1] + image_size)) left_top[1] + image_size))
if np.random.randint(2): if np.random.randint(2):
img = img.transpose(Image.FLIP_LEFT_RIGHT) img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = np.array(img).astype(np.float32) img = np.array(img).astype(np.float32)
# per_image_standardization # per_image_standardization
img_float = img / 255.0 img_float = img / 255.0
img = (img_float - CIFAR_MEAN) / CIFAR_STD img = (img_float - CIFAR_MEAN) / CIFAR_STD
if is_training and args.cutout: if is_training and args.cutout:
center = np.random.randint(image_size, size=2) center = np.random.randint(image_size, size=2)
offset_width = max(0, center[0] - half_length) offset_width = max(0, center[0] - half_length)
...@@ -114,7 +115,7 @@ def preprocess(sample, is_training, args): ...@@ -114,7 +115,7 @@ def preprocess(sample, is_training, args):
for i in range(offset_height, target_height): for i in range(offset_height, target_height):
for j in range(offset_width, target_width): for j in range(offset_width, target_width):
img[i][j][:] = 0.0 img[i][j][:] = 0.0
img = np.transpose(img, (2, 0, 1)) img = np.transpose(img, (2, 0, 1))
return img return img
...@@ -153,10 +154,6 @@ def reader_creator_filepath(filename, sub_name, is_training, args): ...@@ -153,10 +154,6 @@ def reader_creator_filepath(filename, sub_name, is_training, args):
if len(batch_data) == args.batch_size: if len(batch_data) == args.batch_size:
batch_data = np.array(batch_data, dtype='float32') batch_data = np.array(batch_data, dtype='float32')
batch_label = np.array(batch_label, dtype='int64') batch_label = np.array(batch_label, dtype='int64')
#
# batch_data = pickle.load(open('input.pkl'))
# batch_label = pickle.load(open('target.pkl')).reshape(-1,1)
#
if is_training: if is_training:
flatten_label, flatten_non_label = \ flatten_label, flatten_non_label = \
generate_reshape_label(batch_label, args.batch_size) generate_reshape_label(batch_label, args.batch_size)
......
...@@ -70,6 +70,7 @@ dataset_train_size = 50000. ...@@ -70,6 +70,7 @@ dataset_train_size = 50000.
image_size = 32 image_size = 32
genotypes.DARTS = genotypes.MY_DARTS_list[args.model_id] genotypes.DARTS = genotypes.MY_DARTS_list[args.model_id]
def main(): def main():
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
...@@ -79,7 +80,8 @@ def main(): ...@@ -79,7 +80,8 @@ def main():
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
args.auxiliary, genotype) args.auxiliary, genotype)
steps_one_epoch = math.ceil(dataset_train_size / (devices_num * args.batch_size)) steps_one_epoch = math.ceil(dataset_train_size /
(devices_num * args.batch_size))
train(model, args, image_shape, steps_one_epoch) train(model, args, image_shape, steps_one_epoch)
...@@ -136,13 +138,6 @@ def train(model, args, im_shape, steps_one_epoch): ...@@ -136,13 +138,6 @@ def train(model, args, im_shape, steps_one_epoch):
main_program=train_prog, main_program=train_prog,
predicate=if_exist) predicate=if_exist)
#if args.pretrained_model:
# def if_exist(var):
# return os.path.exists(os.path.join(args.pretrained_model, var.name))
# fluid.io.load_vars(exe, args.pretrained_model, main_program=train_prog, predicate=if_exist)
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1 exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
......
...@@ -34,10 +34,6 @@ def mixup_data(x, y, batch_size, alpha=1.0): ...@@ -34,10 +34,6 @@ def mixup_data(x, y, batch_size, alpha=1.0):
lam = 1. lam = 1.
index = np.random.permutation(batch_size) index = np.random.permutation(batch_size)
#
#lam = 0.5
#index = np.arange(batch_size-1, -1, -1)
#
mixed_x = lam * x + (1 - lam) * x[index, :] mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index] y_a, y_b = y, y[index]
return mixed_x.astype('float32'), y_a.astype('int64'),\ return mixed_x.astype('float32'), y_a.astype('int64'),\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册