# nas_net.py
import paddle
import paddle.fluid as fluid
from .operations import OPS


def AuxiliaryHeadCIFAR(inputs, C, class_num):
    """Attach an auxiliary softmax classifier to an intermediate feature map.

    The head applies, in order: ReLU, a 5x5 average pool (stride 3, no
    padding), a 1x1 convolution to 128 filters (stride 1) followed by batch
    norm + ReLU, a 1x1 convolution to 768 filters (stride 2) followed by
    batch norm + ReLU, and finally a fully-connected softmax layer.

    Args:
        inputs: Feature-map variable the head is built on; its ``shape`` is
            printed and it is fed to ``fluid.layers.relu``.
        C: Not used by this function; kept so the call signature stays
            unchanged (presumably the channel count of ``inputs``).
        class_num: Output size of the final fully-connected layer.

    Returns:
        The output of the softmax fully-connected layer.
    """
    print('AuxiliaryHeadCIFAR : inputs-shape : {:}'.format(inputs.shape))
    feat = fluid.layers.pool2d(
        fluid.layers.relu(inputs),
        pool_type='avg',
        pool_size=5,
        pool_stride=3,
        pool_padding=0)
    # Two bias-free 1x1 conv stages, each followed by batch norm with a fused
    # ReLU: (number of filters, stride).
    for width, step in ((128, 1), (768, 2)):
        feat = fluid.layers.conv2d(
            feat,
            num_filters=width,
            filter_size=1,
            stride=step,
            padding=0,
            act=None,
            bias_attr=False)
        feat = fluid.layers.batch_norm(input=feat, act='relu', bias_attr=None)
    print('AuxiliaryHeadCIFAR : last---shape : {:}'.format(feat.shape))
    return fluid.layers.fc(input=feat, size=class_num, act='softmax')


def InferCell(name, inputs_prev_prev, inputs_prev, genotype, C_prev_prev,
              C_prev, C, reduction, reduction_prev):
    """Build one cell of the network from a genotype description.

    Both cell inputs are first passed through ``OPS['skip_connect']`` with
    ``C`` as the output-channel argument. Each step of the genotype then
    applies two operations to previously computed states and adds the
    results; the new state is appended so later steps can reference it.
    The states selected by the genotype's concat list are concatenated
    along axis 1 and returned.

    Args:
        name: Label used only in the progress messages printed here.
        inputs_prev_prev: Output of the cell two positions back.
        inputs_prev: Output of the immediately preceding cell.
        genotype: Object exposing ``normal`` / ``normal_concat`` and
            ``reduce`` / ``reduce_concat``. Each step is a pair of
            ``(op_name, state_index)`` entries, where ``op_name`` is a key
            of ``OPS`` and ``state_index`` indexes the running state list.
        C_prev_prev: Channel argument passed when preprocessing
            ``inputs_prev_prev``.
        C_prev: Channel argument passed when preprocessing ``inputs_prev``.
        C: Channel argument used for every operation inside the cell.
        reduction: Whether this is a reduction cell.
        reduction_prev: Whether the preceding cell was a reduction cell.

    Returns:
        The concatenation (axis 1) of the states listed in the genotype's
        concat specification.

    Raises:
        AssertionError: If the two branches of a step produce tensors of
            different shapes.
    """
    print(
        '[{:}] C_prev_prev={:} C_prev={:}, C={:}, reduction_prev={:}, reduction={:}'.
        format(name, C_prev_prev, C_prev, C, reduction_prev, reduction))
    print('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape))
    print('inputs_prev      : {:}'.format(inputs_prev.shape))
    # The older input gets stride 2 when the previous cell was a reduction
    # cell (presumably so both inputs end up with the same spatial size --
    # depends on how OPS['skip_connect'] is implemented in .operations).
    inputs_prev_prev = OPS['skip_connect'](inputs_prev_prev, C_prev_prev, C,
                                           2 if reduction_prev else 1)
    inputs_prev = OPS['skip_connect'](inputs_prev, C_prev, C, 1)
    print('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape))
    print('inputs_prev      : {:}'.format(inputs_prev.shape))
    if reduction:
        step_ops, concat = genotype.reduce, genotype.reduce_concat
    else:
        step_ops, concat = genotype.normal, genotype.normal_concat
    states = [inputs_prev_prev, inputs_prev]
    for istep, (op_a, op_b) in enumerate(step_ops):
        # In a reduction cell only edges that read one of the two cell inputs
        # (state indices 0 and 1) use stride 2; edges between intermediate
        # states keep stride 1.
        stride = 2 if reduction and op_a[1] < 2 else 1
        tensor1 = OPS[op_a[0]](states[op_a[1]], C, C, stride)
        stride = 2 if reduction and op_b[1] < 2 else 1
        tensor2 = OPS[op_b[0]](states[op_b[1]], C, C, stride)
        # Validate shapes BEFORE adding: otherwise a mismatch either fails
        # inside elementwise_add with an unrelated message or gets broadcast
        # silently before this check is reached.
        assert tensor1.shape == tensor2.shape, 'invalid shape {:} vs. {:}'.format(
            tensor1.shape, tensor2.shape)
        state = fluid.layers.elementwise_add(x=tensor1, y=tensor2, act=None)
        print('-->>[{:}/{:}] tensor={:} from {:} + {:}'.format(
            istep, len(step_ops), state.shape, tensor1.shape, tensor2.shape))
        states.append(state)
    states_to_cat = [states[x] for x in concat]
    outputs = fluid.layers.concat(states_to_cat, axis=1)
    print('-->> output-shape : {:} from concat={:}'.format(outputs.shape,
                                                           concat))
    return outputs


# NASCifarNet(inputs, 36, 6, 3, 10, 'xxx', True)
def NASCifarNet(ipt, C, N, stem_multiplier, class_num, genotype, auxiliary):
    """Build the full cell-based classification network.

    Layout: a 3x3 conv + batch-norm stem with ``stem_multiplier * C``
    filters, then ``N`` normal cells at ``C`` channels, a reduction cell at
    ``2 * C``, ``N`` normal cells at ``2 * C``, a reduction cell at
    ``4 * C``, ``N`` normal cells at ``4 * C``, an 8x8 average pool and a
    softmax fully-connected classifier.

    Args:
        ipt: Input variable fed to the stem convolution.
        C: Base channel count of the first group of cells.
        N: Number of normal cells in each of the three groups.
        stem_multiplier: Factor applied to ``C`` to get the stem's filters.
        class_num: Number of classes for the classifier(s).
        genotype: Cell description forwarded unchanged to ``InferCell``.
        auxiliary: If truthy, an auxiliary head is attached to the output
            of the reduction cell that runs at ``4 * C`` channels.

    Returns:
        The classifier output alone when no auxiliary head was built,
        otherwise the list ``[predicts, auxiliary_pred]``.
    """
    # cifar head module
    stem_width = stem_multiplier * C
    stem_conv = fluid.layers.conv2d(
        ipt,
        num_filters=stem_width,
        filter_size=3,
        stride=1,
        padding=1,
        act=None,
        bias_attr=False)
    stem = fluid.layers.batch_norm(input=stem_conv, act=None, bias_attr=None)
    print('stem-shape : {:}'.format(stem.shape))

    # N + 1 + N + 1 + N cells, as (channels, is_reduction) pairs; every group
    # after the first is introduced by one reduction cell.
    schedule = []
    for stage, scale in enumerate((1, 2, 4)):
        if stage:
            schedule.append((C * scale, True))
        schedule.extend([(C * scale, False)] * N)
    num_cells = len(schedule)

    # Only the two most recent cell outputs (and their channel counts) are
    # needed; the stem acts as both inputs of the first cell.
    older, newer = stem, stem
    width_older, width_newer = stem_width, stem_width
    was_reduction = False
    aux_out = None
    for idx, (width, is_reduction) in enumerate(schedule):
        tag = '{:02d}/{:02d}'.format(idx, num_cells)
        cell = InferCell(tag, older, newer, genotype, width_older,
                         width_newer, width, is_reduction, was_reduction)
        was_reduction = is_reduction
        older, newer = newer, cell
        # The cell's real output width comes from its concat, not from `width`.
        width_older, width_newer = width_newer, cell.shape[1]
        if auxiliary and is_reduction and width == C * 4:
            aux_out = AuxiliaryHeadCIFAR(cell, width_newer, class_num)

    # NOTE(review): pool_size=8 presumably assumes an 8x8 final feature map
    # (e.g. 32x32 inputs after two reductions) -- confirm against callers.
    pooled = fluid.layers.pool2d(
        input=newer, pool_type='avg', pool_size=8, pool_stride=1)
    predicts = fluid.layers.fc(input=pooled, size=class_num, act='softmax')
    print('predict-shape : {:}'.format(predicts.shape))
    if aux_out is not None:
        return [predicts, aux_out]
    return predicts