of_cnn_train_val.py 4.7 KB
Newer Older
S
ShawnXuan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
S
ShawnXuan 已提交
16
import os
S
ShawnXuan 已提交
17
import math
S
ShawnXuan 已提交
18
import oneflow as flow
M
mir-of 已提交
19
import ofrecord_util
F
refine  
Flowingsun007 已提交
20
import optimizer_util
F
refine  
Flowingsun007 已提交
21
import config as configs
M
mir-of 已提交
22 23
from util import Snapshot, Summary, InitNodes, Metric
from job_function_util import get_train_config, get_val_config
S
ShawnXuan 已提交
24
import resnet_model
N
nlqq 已提交
25
import resnext_model
S
ShawnXuan 已提交
26
import vgg_model
I
iamyf 已提交
27
import alexnet_model
I
iamyf 已提交
28
import inception_model
G
guo-ran 已提交
29
import mobilenet_v2_model
M
mir-of 已提交
30

M
mir-of 已提交
31 32 33 34
# Parse the command line and echo the resulting configuration.
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)

# Global (all-device) batch sizes: per-device batch size scaled by the
# total number of devices across all nodes.
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = args.batch_size_per_device * total_device_num
val_batch_size = args.val_batch_size_per_device * total_device_num

# Image shape triple; presumably (channels, height, width) — confirm against
# configs.get_parser().
C, H, W = args.image_shape

# Training iterations per epoch round UP, so a trailing partial batch still
# counts as a step; validation iterations truncate, so any remainder smaller
# than one validation batch is not iterated over by main().
epoch_size = math.ceil(args.num_examples / train_batch_size)
num_val_steps = int(args.num_val_examples / val_batch_size)
S
ShawnXuan 已提交
41 42 43 44


# Registry of selectable networks, keyed by the value of args.model. Each
# entry is called as model_dict[args.model](images, need_transpose=...) by
# the train/inference jobs and returns logits. Note that "vgg" selects the
# batch-norm variant vgg16bn.
model_dict = dict(
    resnet50=resnet_model.resnet50,
    vgg=vgg_model.vgg16bn,
    alexnet=alexnet_model.alexnet,
    inceptionv3=inception_model.inceptionv3,
    mobilenetv2=mobilenet_v2_model.Mobilenet,
    resnext50=resnext_model.resnext50,
)

S
ShawnXuan 已提交
52

S
ShawnXuan 已提交
53
# Tell OneFlow how many GPUs to use on each node (args.gpu_num_per_node).
# When troubleshooting, OneFlow's debug mode can be switched on at this point
# with flow.config.enable_debug_mode(True).
flow.config.gpu_device_num(args.gpu_num_per_node)
M
mir-of 已提交
55 56


M
mir-of 已提交
57 58 59 60 61 62 63
def smoothed_label_values(classes, eta):
    """Return the ``(on_value, off_value)`` pair used for label smoothing.

    Every class receives ``eta / classes`` of the probability mass and the
    true class additionally keeps ``1 - eta``, so a smoothed one-hot row
    still sums to 1.

    Args:
        classes: number of classes; must be > 0.
        eta: smoothing factor in [0, 1); 0 gives plain one-hot values (1, 0).

    Returns:
        Tuple ``(1 - eta + eta / classes, eta / classes)``.

    Raises:
        ValueError: if ``classes`` is not positive or ``eta`` is not in [0, 1).
    """
    # Explicit raises instead of ``assert``: asserts are removed under
    # ``python -O`` and would let invalid settings through silently.
    if classes <= 0:
        raise ValueError("classes must be positive, got {}".format(classes))
    if not 0.0 <= eta < 1.0:
        raise ValueError("eta must be in [0, 1), got {}".format(eta))
    off_value = eta / classes
    return 1 - eta + off_value, off_value


def label_smoothing(labels, classes, eta, dtype):
    """Encode ``labels`` as label-smoothed one-hot targets via ``flow.one_hot``.

    Args:
        labels: class-index labels to encode.
        classes: number of classes (the one-hot depth); must be > 0.
        eta: smoothing factor in [0, 1).
        dtype: dtype of the produced one-hot tensor.

    Returns:
        One-hot tensor holding ``1 - eta + eta / classes`` at the labelled
        class and ``eta / classes`` at every other class.

    Raises:
        ValueError: if ``classes`` or ``eta`` is out of range.
    """
    on_value, off_value = smoothed_label_values(classes, eta)
    return flow.one_hot(labels, depth=classes, dtype=dtype,
                        on_value=on_value, off_value=off_value)


64
@flow.global_function("train", get_train_config(args))
def TrainNet():
    """Training job: load a batch, run the selected model and optimize the loss.

    Data comes from ``ofrecord_util.load_imagenet_for_training`` when
    ``args.train_data_dir`` is set, otherwise from
    ``ofrecord_util.load_synthetic``. The loss is softmax cross-entropy:
    against label-smoothed one-hot targets when ``args.label_smoothing > 0``,
    otherwise the sparse form against the raw labels. Warmup, learning rate
    and optimizer are configured by ``optimizer_util.set_up_optimizer``.

    Returns:
        dict with ``"loss"`` (mean loss), ``"predictions"`` (softmax of the
        logits) and ``"labels"``.

    Raises:
        FileNotFoundError: if ``args.train_data_dir`` is set but does not exist.
    """
    if args.train_data_dir:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not os.path.exists(args.train_data_dir):
            raise FileNotFoundError(
                "train_data_dir does not exist: {}".format(args.train_data_dir))
        print("Loading data from {}".format(args.train_data_dir))
        (labels, images) = ofrecord_util.load_imagenet_for_training(args)
    else:
        print("Loading synthetic data.")
        (labels, images) = ofrecord_util.load_synthetic(args)

    # need_transpose is requested only for synthetic data (no train_data_dir).
    logits = model_dict[args.model](images,
                                    need_transpose=not args.train_data_dir)

    if args.label_smoothing > 0:
        one_hot_labels = label_smoothing(labels, args.num_classes,
                                         args.label_smoothing, logits.dtype)
        loss = flow.nn.softmax_cross_entropy_with_logits(
            one_hot_labels, logits, name="softmax_loss")
    else:
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
            labels, logits, name="softmax_loss")

    # Average over all elements so a single scalar loss is registered/reported.
    loss = flow.math.reduce_mean(loss)
    flow.losses.add_loss(loss)
    predictions = flow.nn.softmax(logits)
    outputs = {"loss": loss, "predictions": predictions, "labels": labels}

    # set up warmup, learning rate and optimizer
    optimizer_util.set_up_optimizer(loss, args)
    return outputs


93
@flow.global_function("predict", get_val_config(args))
def InferenceNet():
    """Validation job: load a batch and return the model's softmax predictions.

    Data comes from ``ofrecord_util.load_imagenet_for_validation`` when
    ``args.val_data_dir`` is set, otherwise from
    ``ofrecord_util.load_synthetic``.

    Returns:
        dict with ``"predictions"`` (softmax of the logits) and ``"labels"``.

    Raises:
        FileNotFoundError: if ``args.val_data_dir`` is set but does not exist.
    """
    if args.val_data_dir:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not os.path.exists(args.val_data_dir):
            raise FileNotFoundError(
                "val_data_dir does not exist: {}".format(args.val_data_dir))
        print("Loading data from {}".format(args.val_data_dir))
        (labels, images) = ofrecord_util.load_imagenet_for_validation(args)
    else:
        print("Loading synthetic data.")
        (labels, images) = ofrecord_util.load_synthetic(args)

    # need_transpose is requested only for synthetic data (no val_data_dir).
    logits = model_dict[args.model](images,
                                    need_transpose=not args.val_data_dir)
    predictions = flow.nn.softmax(logits)
    outputs = {"predictions": predictions, "labels": labels}
    return outputs
S
ShawnXuan 已提交
109 110


S
ShawnXuan 已提交
111
def _train_one_epoch(epoch, summary):
    """Run ``epoch_size`` TrainNet steps, reporting each to a fresh 'train' Metric."""
    train_metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
                          summary=summary, save_summary_steps=epoch_size,
                          batch_size=train_batch_size, loss_key='loss')
    for step in range(epoch_size):
        # Each step's outputs are delivered to the metric callback via async_get.
        TrainNet().async_get(train_metric.metric_cb(epoch, step))


def _validate(epoch, summary):
    """Run ``num_val_steps`` InferenceNet steps, reporting each to a 'validation' Metric."""
    val_metric = Metric(desc='validation', calculate_batches=num_val_steps, summary=summary,
                        save_summary_steps=num_val_steps, batch_size=val_batch_size)
    for step in range(num_val_steps):
        InferenceNet().async_get(val_metric.metric_cb(epoch, step))


def main():
    """Set up nodes and the OneFlow environment, then train for ``args.num_epochs`` epochs.

    Each epoch runs the training job, then the validation job when
    ``args.val_data_dir`` is set, and finally saves a snapshot tagged
    ``epoch_<n>``.
    """
    InitNodes(args)
    flow.env.grpc_use_no_signal()
    flow.env.log_dir(args.log_dir)

    summary = Summary(args.log_dir, args)
    snapshot = Snapshot(args.model_save_dir, args.model_load_dir)

    for epoch in range(args.num_epochs):
        _train_one_epoch(epoch, summary)
        # Validation is skipped entirely when no validation data dir is given.
        if args.val_data_dir:
            _validate(epoch, summary)
        snapshot.save('epoch_{}'.format(epoch))


if __name__ == "__main__":
    main()