of_cnn_train_val.py 4.2 KB
Newer Older
S
ShawnXuan 已提交
1 2 3 4 5
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
S
ShawnXuan 已提交
6
import math
S
ShawnXuan 已提交
7

S
ShawnXuan 已提交
8
import oneflow as flow
M
mir-of 已提交
9 10 11 12 13

import ofrecord_util
import config as configs
from util import Snapshot, Summary, InitNodes, Metric
from job_function_util import get_train_config, get_val_config
S
ShawnXuan 已提交
14
import resnet_model
S
ShawnXuan 已提交
15
import vgg_model
M
mir-of 已提交
16

S
ShawnXuan 已提交
17

M
mir-of 已提交
18 19 20 21
# ---------------------------------------------------------------------------
# Command-line configuration and run constants derived from it.
# ---------------------------------------------------------------------------
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)

# Global batch sizes: per-device batch size times the total device count
# (nodes x GPUs per node).
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device
# Image shape components; unpacked here but not used by the code below.
(C, H, W) = args.image_shape
# Training steps per epoch round UP, so a final partial batch still counts
# as a step.
epoch_size = math.ceil(args.num_examples / train_batch_size)
# Validation steps are truncated by int(): any trailing partial batch of
# validation examples is not evaluated.
num_val_steps = int(args.num_val_examples / val_batch_size)


# Maps the value of args.model to the network-builder function it selects.
model_dict = {
    "resnet50": resnet_model.resnet50,
    "vgg": vgg_model.vgg16bn,
}


flow.config.gpu_device_num(args.gpu_num_per_node)
# NOTE(review): debug mode is enabled unconditionally here — confirm this is
# intended for training/benchmark runs.
flow.config.enable_debug_mode(True)

if args.use_boxing_v2:
    # NCCL collective-boxing tuning, applied only when use_boxing_v2 is set.
    flow.config.collective_boxing.nccl_fusion_threshold_mb(8)
    flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)


M
mir-of 已提交
45 46 47 48 49 50 51 52
def label_smoothing(labels, classes, eta, dtype):
    """Build label-smoothed one-hot targets with ``flow.one_hot``.

    The true-class entry of each row is ``1 - eta + eta / classes`` and every
    other entry is ``eta / classes``.

    Args:
        labels: class-id labels (presumably integer) passed to ``flow.one_hot``.
        classes: one-hot depth (number of classes); must be positive.
        eta: smoothing factor; must lie in ``[0, 1)``.
        dtype: dtype of the produced one-hot blob.

    Raises:
        AssertionError: if ``classes`` or ``eta`` is out of range (these
            checks are skipped when Python runs with ``-O``).
    """
    assert classes > 0
    assert 0.0 <= eta < 1.0

    # Share the per-class smoothing mass between both values.
    off_value = eta / classes
    on_value = 1 - eta + off_value
    return flow.one_hot(labels, depth=classes, dtype=dtype,
                        on_value=on_value, off_value=off_value)


S
update  
ScXfjiang 已提交
53
@flow.global_function(get_train_config(args))
def TrainNet():
    """Training job: load a batch, run the selected model, and register the loss.

    Data comes from ``ofrecord_util.load_imagenet_for_training`` when
    ``args.train_data_dir`` is set, otherwise from
    ``ofrecord_util.load_synthetic``. The network builder is looked up in
    ``model_dict`` by ``args.model``. When ``args.label_smoothing`` is
    positive, the loss is softmax cross-entropy against label-smoothed one-hot
    targets; otherwise it is sparse softmax cross-entropy on the raw labels.

    Returns:
        dict with ``"loss"``, ``"predictions"`` (softmax of the logits) and
        ``"labels"``; ``main`` reads it through a ``Metric`` callback with
        ``loss_key='loss'``.
    """
    if args.train_data_dir:
        # NOTE: this existence check is an assert, so it is skipped under -O.
        assert os.path.exists(args.train_data_dir)
        print("Loading data from {}".format(args.train_data_dir))
        (labels, images) = ofrecord_util.load_imagenet_for_training(args)
    else:
        print("Loading synthetic data.")
        (labels, images) = ofrecord_util.load_synthetic(args)

    # need_transpose is True only when synthetic data is used.
    # NOTE(review): presumably the synthetic loader yields a different layout
    # than the OFRecord loader — confirm in ofrecord_util / the model code.
    logits = model_dict[args.model](images,
                                    need_transpose=not args.train_data_dir,
                                    channel_last=args.channel_last)
    if args.label_smoothing > 0:
        one_hot_labels = label_smoothing(labels, args.num_classes, args.label_smoothing, logits.dtype)
        loss = flow.nn.softmax_cross_entropy_with_logits(one_hot_labels, logits, name="softmax_loss")
    else:
        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")

    flow.losses.add_loss(loss)
    predictions = flow.nn.softmax(logits)
    outputs = {"loss": loss, "predictions": predictions, "labels": labels}
    return outputs


S
update  
ScXfjiang 已提交
78
@flow.global_function(get_val_config(args))
def InferenceNet():
    """Validation job: load a batch and return softmax predictions with labels.

    Data comes from ``ofrecord_util.load_imagenet_for_validation`` when
    ``args.val_data_dir`` is set, otherwise from
    ``ofrecord_util.load_synthetic``. The network builder is looked up in
    ``model_dict`` by ``args.model``.

    Returns:
        dict with ``"predictions"`` (softmax of the logits) and ``"labels"``.
    """
    if args.val_data_dir:
        # NOTE: this existence check is an assert, so it is skipped under -O.
        assert os.path.exists(args.val_data_dir)
        print("Loading data from {}".format(args.val_data_dir))
        (labels, images) = ofrecord_util.load_imagenet_for_validation(args)
    else:
        print("Loading synthetic data.")
        (labels, images) = ofrecord_util.load_synthetic(args)

    # Transpose only synthetic input, as the training job does. The decision
    # must follow the loader actually chosen above (val_data_dir). Keying it
    # on train_data_dir gave the wrong layout whenever exactly one of the
    # training and validation data directories was set.
    logits = model_dict[args.model](
        images, need_transpose=not args.val_data_dir, channel_last=args.channel_last)
    predictions = flow.nn.softmax(logits)
    outputs = {"predictions": predictions, "labels": labels}
    return outputs
S
ShawnXuan 已提交
94 95


S
ShawnXuan 已提交
96
def main():
    """Train for ``args.num_epochs`` epochs, validating and snapshotting each epoch.

    Each epoch runs ``epoch_size`` TrainNet steps. When ``args.val_data_dir``
    is set, it then runs ``num_val_steps`` InferenceNet steps. A snapshot
    named ``epoch_<n>`` is saved at the end of every epoch.
    """
    InitNodes(args)
    if args.channel_last:
        print("Use 'NHWC' mode >> Channel last")
    else:
        print("Use 'NCHW' mode >> Channel first")
    flow.env.grpc_use_no_signal()
    flow.env.log_dir(args.log_dir)

    summary = Summary(args.log_dir, args)
    snapshot = Snapshot(args.model_save_dir, args.model_load_dir)

    for epoch in range(args.num_epochs):
        # A fresh training Metric per epoch; each step's outputs are handed
        # to metric.metric_cb through async_get.
        metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
                        summary=summary, save_summary_steps=epoch_size,
                        batch_size=train_batch_size, loss_key='loss')
        for i in range(epoch_size):
            TrainNet().async_get(metric.metric_cb(epoch, i))

        # Validation runs only when a real validation dataset is configured.
        if args.val_data_dir:
            metric = Metric(desc='validation', calculate_batches=num_val_steps, summary=summary,
                            save_summary_steps=num_val_steps, batch_size=val_batch_size)
            for i in range(num_val_steps):
                InferenceNet().async_get(metric.metric_cb(epoch, i))
        snapshot.save('epoch_{}'.format(epoch))
S
ShawnXuan 已提交
121 122 123 124


if __name__ == "__main__":
    # Start the train/validate loop only when this file is run as a script.
    main()