Commit 68901d68 authored by M mir-of

use argpartition

Parent a72650a5
......@@ -66,6 +66,7 @@ def get_parser(parser=None):
parser.add_argument("--val_batch_size_per_device", type=int, default=8)
# for data process
parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes")
parser.add_argument("--num_examples", type=int,
default=1281167, help="train pic number")
parser.add_argument("--num_val_examples", type=int,
......@@ -78,6 +79,7 @@ def get_parser(parser=None):
default='NHWC', help="NCHW or NHWC")
parser.add_argument('--image-shape', type=int_list, default=[3, 224, 224],
help='the image shape feed into the network')
parser.add_argument('--label-smoothing', type=float, default=0.1, help='label smoothing factor')
# snapshot
parser.add_argument("--model_save_dir", type=str,
......
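This first file only touches argument parsing: the CNN config parser gains a `--label-smoothing` option (default 0.1) that the training job reads as `args.label_smoothing`. Below is a minimal, self-contained sketch of how these flags behave once parsed; the `int_list` helper shown here is an assumption (a comma-separated-string parser), since its real definition in config.py is outside this diff.

```python
import argparse

def int_list(text):
    # Hypothetical stand-in for the repo's int_list type: "3,224,224" -> [3, 224, 224].
    return [int(x) for x in text.split(",")]

parser = argparse.ArgumentParser()
parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes")
parser.add_argument('--image-shape', type=int_list, default=[3, 224, 224],
                    help='the image shape feed into the network')
parser.add_argument('--label-smoothing', type=float, default=0.1,
                    help='label smoothing factor')

# argparse turns dashes into underscores, so the flag surfaces as args.label_smoothing.
args = parser.parse_args(["--label-smoothing", "0.1", "--image-shape", "3,224,224"])
print(args.num_classes, args.label_smoothing, args.image_shape)
# 1000 0.1 [3, 224, 224]
```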
......@@ -14,7 +14,6 @@ from job_function_util import get_train_config, get_val_config
import resnet_model
parser = configs.get_parser()
args = parser.parse_args()
configs.print_args(args)
......@@ -41,6 +40,14 @@ if args.use_boxing_v2:
flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
def label_smoothing(labels, classes, eta, dtype):
assert classes > 0
assert eta >= 0.0 and eta < 1.0
return flow.one_hot(labels, depth=classes, dtype=dtype,
on_value=1 - eta + eta / classes, off_value=eta/classes)
@flow.global_function(get_train_config(args))
def TrainNet():
if args.train_data_dir:
......@@ -54,9 +61,12 @@ def TrainNet():
logits = model_dict[args.model](
images, need_transpose=False if args.train_data_dir else True)
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
labels, logits, name="softmax_loss")
loss = flow.math.reduce_mean(loss)
# loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
# labels, logits, name="softmax_loss")
# loss = flow.math.reduce_mean(loss)
one_hot_labels = label_smoothing(labels, args.num_classes, args.label_smoothing, logits.dtype)
loss = flow.nn.softmax_cross_entropy_with_logits(one_hot_labels, logits, name="softmax_loss")
flow.losses.add_loss(loss)
predictions = flow.nn.softmax(logits)
outputs = {"loss": loss, "predictions": predictions, "labels": labels}
......
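The training job swaps the sparse loss for a dense `softmax_cross_entropy_with_logits` over smoothed one-hot targets built by `label_smoothing`: on-value `1 - eta + eta/classes` for the true class, off-value `eta/classes` elsewhere. The NumPy sketch below only illustrates that math, not OneFlow's implementation; it also checks that with `eta = 0` the dense loss collapses back to the sparse formulation being replaced.

```python
import numpy as np

def smooth_one_hot(labels, classes, eta):
    # Mirrors label_smoothing(): off_value everywhere, on_value on the true class.
    out = np.full((labels.shape[0], classes), eta / classes, dtype=np.float32)
    out[np.arange(labels.shape[0]), labels] = 1.0 - eta + eta / classes
    return out

def softmax_xent_dense(targets, logits):
    # Dense cross entropy: -sum_c targets_c * log softmax(logits)_c, per row.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(targets * log_probs).sum(axis=1)

rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=4)
logits = rng.normal(size=(4, 10)).astype(np.float32)

smoothed = smooth_one_hot(labels, classes=10, eta=0.1)
print(smoothed.sum(axis=1))  # each row still sums to ~1.0

smoothed_loss = softmax_xent_dense(smoothed, logits).mean()
hard_loss = softmax_xent_dense(smooth_one_hot(labels, 10, eta=0.0), logits).mean()
print(smoothed_loss, hard_loss)  # with eta = 0 the dense loss matches the sparse one
```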
......@@ -48,6 +48,7 @@ class Summary(object):
def __init__(self, log_dir, config, filename='summary.csv'):
self._filename = filename
self._log_dir = log_dir
if not os.path.exists(log_dir): os.makedirs(log_dir)
self._metrics = pd.DataFrame({"epoch":0, "iter": 0, "legend": "cfg", "note": str(config)}, index=[0])
def scalar(self, legend, value, epoch, step=-1):
......@@ -84,7 +85,7 @@ class StopWatch(object):
def match_top_k(predictions, labels, top_k=1):
max_k_preds = predictions.argsort(axis=1)[:, -top_k:][:, ::-1]
max_k_preds = np.argpartition(predictions.ndarray(), -top_k)[:, -top_k:]
match_array = np.logical_or.reduce(max_k_preds==labels.reshape((-1, 1)), axis=1)
num_matched = match_array.sum()
return num_matched, match_array.shape[0]
......
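The `match_top_k` change replaces a full `argsort` with `np.argpartition`, which selects the `top_k` candidate indices in roughly linear time per row instead of sorting the whole class dimension. The partitioned indices are not ordered among themselves, but that does not matter here because `logical_or.reduce` only tests whether the label appears among them. A sketch of the equivalence on plain NumPy arrays (the real code calls `.ndarray()` on an OneFlow blob first):

```python
import numpy as np

def match_top_k_argsort(predictions, labels, top_k=1):
    max_k_preds = predictions.argsort(axis=1)[:, -top_k:][:, ::-1]
    match_array = np.logical_or.reduce(max_k_preds == labels.reshape((-1, 1)), axis=1)
    return match_array.sum(), match_array.shape[0]

def match_top_k_argpartition(predictions, labels, top_k=1):
    # argpartition only guarantees the last top_k columns hold the k largest
    # entries (in arbitrary order), which is enough for a membership test.
    max_k_preds = np.argpartition(predictions, -top_k, axis=1)[:, -top_k:]
    match_array = np.logical_or.reduce(max_k_preds == labels.reshape((-1, 1)), axis=1)
    return match_array.sum(), match_array.shape[0]

rng = np.random.default_rng(1)
preds = rng.normal(size=(8, 1000)).astype(np.float32)
labels = rng.integers(0, 1000, size=8)
print(match_top_k_argsort(preds, labels, top_k=5))
print(match_top_k_argpartition(preds, labels, top_k=5))  # same counts
```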