提交 8709128b 编写于 作者: S ShawnXuan 提交者: GitHub

Merge pull request #23 from Oneflow-Inc/refine_vgg16

label_smoothing
......@@ -12,6 +12,7 @@ import config as configs
from util import Snapshot, Summary, InitNodes, Metric
from job_function_util import get_train_config, get_val_config
import resnet_model
import vgg_model
parser = configs.get_parser()
......@@ -29,6 +30,7 @@ num_val_steps = int(args.num_val_examples / val_batch_size)
model_dict = {
"resnet50": resnet_model.resnet50,
"vgg": vgg_model.vgg16bn,
}
......@@ -61,12 +63,12 @@ def TrainNet():
logits = model_dict[args.model](images,
need_transpose=False if args.train_data_dir else True,
channel_last=args.channel_last)
# loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
# labels, logits, name="softmax_loss")
# loss = flow.math.reduce_mean(loss)
one_hot_labels = label_smoothing(labels, args.num_classes, args.label_smoothing, logits.dtype)
loss = flow.nn.softmax_cross_entropy_with_logits(one_hot_labels, logits, name="softmax_loss")
if args.label_smoothing > 0:
one_hot_labels = label_smoothing(labels, args.num_classes, args.label_smoothing, logits.dtype)
loss = flow.nn.softmax_cross_entropy_with_logits(one_hot_labels, logits, name="softmax_loss")
else:
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")
flow.losses.add_loss(loss)
predictions = flow.nn.softmax(logits)
outputs = {"loss": loss, "predictions": predictions, "labels": labels}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册