of_cnn_train_val.py 3.9 KB
Newer Older
S
ShawnXuan 已提交
1 2 3 4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
S
ShawnXuan 已提交
5
import math
S
ShawnXuan 已提交
6
import oneflow as flow
M
mir-of 已提交
7 8 9 10
import ofrecord_util
import config as configs
from util import Snapshot, Summary, InitNodes, Metric
from job_function_util import get_train_config, get_val_config
S
ShawnXuan 已提交
11
import resnet_model
S
ShawnXuan 已提交
12
import vgg_model
I
iamyf 已提交
13
import alexnet_model
M
mir-of 已提交
14

M
mir-of 已提交
15 16 17 18
# Parse the command-line configuration once at import time; the job functions
# and main() below all read the module-level `args`.
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)

# Global batch sizes: the per-device batch size multiplied by the number of
# devices across all nodes.
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device
(C, H, W) = args.image_shape
# Training iterations per epoch, rounded up. The int() conversion matters
# because math.ceil returns a float on Python 2 and epoch_size is later used
# as a range() bound.
epoch_size = int(math.ceil(args.num_examples / train_batch_size))
# int() truncates, so a trailing partial validation batch is not counted.
num_val_steps = int(args.num_val_examples / val_batch_size)
S
ShawnXuan 已提交
25 26 27 28


# Maps the model name looked up via `args.model` to the function that builds
# that network.
model_dict = {
    "resnet50": resnet_model.resnet50,
    "vgg": vgg_model.vgg16bn,
    "alexnet": alexnet_model.alexnet,
}

S
ShawnXuan 已提交
33

S
ShawnXuan 已提交
34
# Process-wide OneFlow configuration, applied before the job functions below
# are defined.
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.config.enable_debug_mode(True)
M
mir-of 已提交
36 37


M
mir-of 已提交
38 39 40 41 42 43 44 45
def label_smoothing(labels, classes, eta, dtype):
    """Build label-smoothed one-hot targets from sparse labels.

    Args:
        labels: sparse labels forwarded to ``flow.one_hot``.
        classes: number of classes (one-hot depth); must be positive.
        eta: smoothing factor; must satisfy ``0 <= eta < 1``.
        dtype: dtype of the produced one-hot blob.

    Returns:
        The result of ``flow.one_hot`` with ``1 - eta + eta / classes`` as the
        "on" value and ``eta / classes`` as the "off" value.

    Raises:
        AssertionError: if ``classes`` or ``eta`` is out of range.
    """
    assert classes > 0
    assert 0.0 <= eta < 1.0

    # Probability mass spread uniformly over every class.
    uniform_share = eta / classes
    return flow.one_hot(
        labels,
        depth=classes,
        dtype=dtype,
        on_value=1 - eta + uniform_share,
        off_value=uniform_share,
    )


S
update  
ScXfjiang 已提交
46
@flow.global_function(get_train_config(args))
def TrainNet():
    """Training job: load a batch, run the selected model and register the loss.

    Data comes from ``ofrecord_util.load_imagenet_for_training`` when
    ``args.train_data_dir`` is set, otherwise from
    ``ofrecord_util.load_synthetic``.

    Returns:
        dict with keys ``"loss"``, ``"predictions"`` (softmax of the logits)
        and ``"labels"``.
    """
    use_real_data = bool(args.train_data_dir)
    if use_real_data:
        assert os.path.exists(args.train_data_dir)
        print("Loading data from {}".format(args.train_data_dir))
        labels, images = ofrecord_util.load_imagenet_for_training(args)
    else:
        print("Loading synthetic data.")
        labels, images = ofrecord_util.load_synthetic(args)

    # Only the synthetic-data path asks the model to transpose its input.
    build_model = model_dict[args.model]
    logits = build_model(images, need_transpose=not use_real_data)

    if args.label_smoothing > 0:
        # Smoothed targets are dense, so the dense cross-entropy op is used.
        soft_targets = label_smoothing(
            labels, args.num_classes, args.label_smoothing, logits.dtype)
        loss = flow.nn.softmax_cross_entropy_with_logits(
            soft_targets, logits, name="softmax_loss")
    else:
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
            labels, logits, name="softmax_loss")

    flow.losses.add_loss(loss)
    return {
        "loss": loss,
        "predictions": flow.nn.softmax(logits),
        "labels": labels,
    }


S
update  
ScXfjiang 已提交
71
@flow.global_function(get_val_config(args))
def InferenceNet():
    """Validation job: load a batch and run the selected model forward.

    Data comes from ``ofrecord_util.load_imagenet_for_validation`` when
    ``args.val_data_dir`` is set, otherwise from
    ``ofrecord_util.load_synthetic``.

    Returns:
        dict with keys ``"predictions"`` (softmax of the logits) and
        ``"labels"``.
    """
    use_real_data = bool(args.val_data_dir)
    if use_real_data:
        assert os.path.exists(args.val_data_dir)
        print("Loading data from {}".format(args.val_data_dir))
        labels, images = ofrecord_util.load_imagenet_for_validation(args)
    else:
        print("Loading synthetic data.")
        labels, images = ofrecord_util.load_synthetic(args)

    # Only the synthetic-data path asks the model to transpose its input.
    build_model = model_dict[args.model]
    logits = build_model(images, need_transpose=not use_real_data)

    return {"predictions": flow.nn.softmax(logits), "labels": labels}
S
ShawnXuan 已提交
87 88


S
ShawnXuan 已提交
89
def main():
    """Run the train/validate/snapshot loop for ``args.num_epochs`` epochs.

    Each epoch launches ``epoch_size`` training iterations; when
    ``args.val_data_dir`` is set it then launches ``num_val_steps`` validation
    iterations. A snapshot named ``epoch_<n>`` is saved at the end of every
    epoch.
    """
    InitNodes(args)
    flow.env.grpc_use_no_signal()
    flow.env.log_dir(args.log_dir)

    summary = Summary(args.log_dir, args)
    snapshot = Snapshot(args.model_save_dir, args.model_load_dir)

    for epoch in range(args.num_epochs):
        train_metric = Metric(
            desc='train',
            calculate_batches=args.loss_print_every_n_iter,
            summary=summary,
            save_summary_steps=epoch_size,
            batch_size=train_batch_size,
            loss_key='loss',
        )
        for step in range(epoch_size):
            # Results are consumed asynchronously through the metric callback.
            TrainNet().async_get(train_metric.metric_cb(epoch, step))

        if args.val_data_dir:
            val_metric = Metric(
                desc='validation',
                calculate_batches=num_val_steps,
                summary=summary,
                save_summary_steps=num_val_steps,
                batch_size=val_batch_size,
            )
            for step in range(num_val_steps):
                InferenceNet().async_get(val_metric.metric_cb(epoch, step))

        snapshot.save('epoch_{}'.format(epoch))
S
ShawnXuan 已提交
110 111 112 113


# Start the training/validation loop only when this file is run as a script.
if __name__ == "__main__":
    main()