# of_cnn_train_val.py
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
S
ShawnXuan 已提交
16
import os
S
ShawnXuan 已提交
17
import math
S
ShawnXuan 已提交
18
import oneflow as flow
M
mir-of 已提交
19
import ofrecord_util
F
refine  
Flowingsun007 已提交
20
import optimizer_util
F
refine  
Flowingsun007 已提交
21
import config as configs
S
ShawnXuan 已提交
22
from util import Snapshot, InitNodes, Metric
M
mir-of 已提交
23
from job_function_util import get_train_config, get_val_config
S
ShawnXuan 已提交
24
import resnet_model
N
nlqq 已提交
25
import resnext_model
S
ShawnXuan 已提交
26
import vgg_model
I
iamyf 已提交
27
import alexnet_model
I
iamyf 已提交
28
import inception_model
G
guo-ran 已提交
29
import mobilenet_v2_model
M
mir-of 已提交
30

M
mir-of 已提交
31 32 33 34
# Parse the command-line options once at import time; the job functions and
# main() defined later in this file read the module-level `args`.
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)

# Global batch sizes: per-device batch size times the device count over all nodes.
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device
# Unpacked image shape; not referenced again in the visible part of this file.
(C, H, W) = args.image_shape
# Training rounds up, so a trailing partial batch still counts as one step.
epoch_size = math.ceil(args.num_examples / train_batch_size)
# Validation truncates, so examples that do not fill a whole batch are skipped.
num_val_steps = int(args.num_val_examples / val_batch_size)
S
ShawnXuan 已提交
41 42 43 44


# Maps the value of args.model to the callable that builds that network.
# Each entry is invoked as builder(images, args) and returns the logits.
model_dict = {
    "resnet50": resnet_model.resnet50,
    "vgg": vgg_model.vgg16bn,
    "alexnet": alexnet_model.alexnet,
    "inceptionv3": inception_model.inceptionv3,
    "mobilenetv2": mobilenet_v2_model.Mobilenet,
    "resnext50": resnext_model.resnext50,
}

S
ShawnXuan 已提交
52

S
ShawnXuan 已提交
53
# Tell OneFlow how many GPUs to use on each node.
flow.config.gpu_device_num(args.gpu_num_per_node)
# Debug switch, left disabled:
# flow.config.enable_debug_mode(True)

# With fp16 on more than one device, switch off the fused all-reduce buffer.
# NOTE(review): the reason for this combination is not stated in this file.
if args.use_fp16 and total_device_num > 1:
    flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)

# The two NCCL fusion knobs are only applied when given a truthy (non-zero) value.
if args.nccl_fusion_threshold_mb:
    flow.config.collective_boxing.nccl_fusion_threshold_mb(args.nccl_fusion_threshold_mb)

if args.nccl_fusion_max_ops:
    flow.config.collective_boxing.nccl_fusion_max_ops(args.nccl_fusion_max_ops)
M
mir-of 已提交
64

M
mir-of 已提交
65 66 67 68 69 70 71
def label_smoothing(labels, classes, eta, dtype):
    """Build label-smoothed one-hot targets from class-index labels.

    The labelled class receives ``1 - eta + eta / classes`` and every other
    class receives ``eta / classes``.

    Args:
        labels: class-index blob, passed straight to ``flow.one_hot``.
        classes: number of classes; must be positive.
        eta: smoothing factor; must satisfy ``0 <= eta < 1``.
        dtype: dtype requested for the returned blob.

    Returns:
        The result of ``flow.one_hot`` with depth ``classes``.

    Raises:
        AssertionError: if ``classes`` or ``eta`` is out of range.
    """
    assert classes > 0
    assert 0.0 <= eta < 1.0
    off_value = eta / classes
    on_value = 1 - eta + off_value
    return flow.one_hot(
        labels,
        depth=classes,
        dtype=dtype,
        on_value=on_value,
        off_value=off_value,
    )


72
@flow.global_function("train", get_train_config(args))
def TrainNet():
    """Define the training job: input data, forward pass, loss and optimizer.

    Data comes from ``ofrecord_util.load_imagenet_for_training`` when
    ``args.train_data_dir`` is set and from ``ofrecord_util.load_synthetic``
    otherwise. The network is the ``model_dict`` entry chosen by ``args.model``.

    Returns:
        dict with keys ``"loss"`` (mean of the per-example loss),
        ``"predictions"`` (softmax of the logits) and ``"labels"``.

    Raises:
        AssertionError: if ``args.train_data_dir`` is set but does not exist.
    """
    if args.train_data_dir:
        assert os.path.exists(args.train_data_dir)
        print("Loading data from {}".format(args.train_data_dir))
        (labels, images) = ofrecord_util.load_imagenet_for_training(args)
    else:
        print("Loading synthetic data.")
        (labels, images) = ofrecord_util.load_synthetic(args)

    logits = model_dict[args.model](images, args)
    if args.label_smoothing > 0:
        # Smoothed targets are dense, so the dense cross-entropy op is used.
        one_hot_labels = label_smoothing(labels, args.num_classes, args.label_smoothing, logits.dtype)
        loss = flow.nn.softmax_cross_entropy_with_logits(one_hot_labels, logits, name="softmax_loss")
    else:
        # Without smoothing the class-index labels feed the sparse op directly.
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")

    loss = flow.math.reduce_mean(loss)
    predictions = flow.nn.softmax(logits)
    outputs = {"loss": loss, "predictions": predictions, "labels": labels}

    # set up warmup,learning rate and optimizer
    optimizer_util.set_up_optimizer(loss, args)
    return outputs


98
@flow.global_function("predict", get_val_config(args))
def InferenceNet():
    """Define the validation job: input data and forward pass only.

    Data comes from ``ofrecord_util.load_imagenet_for_validation`` when
    ``args.val_data_dir`` is set and from ``ofrecord_util.load_synthetic``
    otherwise. The network is the ``model_dict`` entry chosen by ``args.model``.

    Returns:
        dict with keys ``"predictions"`` (softmax of the logits) and
        ``"labels"``.

    Raises:
        AssertionError: if ``args.val_data_dir`` is set but does not exist.
    """
    if args.val_data_dir:
        assert os.path.exists(args.val_data_dir)
        print("Loading data from {}".format(args.val_data_dir))
        (labels, images) = ofrecord_util.load_imagenet_for_validation(args)
    else:
        print("Loading synthetic data.")
        (labels, images) = ofrecord_util.load_synthetic(args)

    logits = model_dict[args.model](images, args)
    predictions = flow.nn.softmax(logits)
    outputs = {"predictions": predictions, "labels": labels}
    return outputs
S
ShawnXuan 已提交
113 114


S
ShawnXuan 已提交
115
def main():
    """Run training for ``args.num_epochs`` epochs, validating and saving each epoch.

    Each epoch runs ``epoch_size`` training steps. When ``args.val_data_dir``
    is set it then runs ``num_val_steps`` validation steps. A snapshot named
    ``epoch_<n>`` is saved at the end of every epoch either way.
    """
    InitNodes(args)
    flow.env.log_dir(args.log_dir)

    snapshot = Snapshot(args.model_save_dir, args.model_load_dir)

    for epoch in range(args.num_epochs):
        metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
                        batch_size=train_batch_size, loss_key='loss')
        for i in range(epoch_size):
            # Results are consumed by the metric callback instead of being
            # returned to this loop.
            TrainNet().async_get(metric.metric_cb(epoch, i))

        # Validation is skipped when no validation data directory is given.
        if args.val_data_dir:
            metric = Metric(desc='validation', calculate_batches=num_val_steps,
                            batch_size=val_batch_size)
            for i in range(num_val_steps):
                InferenceNet().async_get(metric.metric_cb(epoch, i))
        snapshot.save('epoch_{}'.format(epoch))
S
ShawnXuan 已提交
133 134 135 136


# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()