提交 6999f340 编写于 作者: S ShawnXuan

Add support for the new dataloader and boxing v2 parameters

上级 080de09d
......@@ -39,18 +39,25 @@ model_dict = {
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.config.enable_debug_mode(True)
if args.use_boxing_v2:
flow.config.collective_boxing.nccl_fusion_threshold_mb(8)
flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
@flow.function(get_train_config(args))
def TrainNet():
if args.train_data_dir:
assert os.path.exists(args.train_data_dir)
print("Loading data from {}".format(args.train_data_dir))
if args.use_new_dataloader:
(labels, images) = ofrecord_util.load_imagenet_for_training2(args)
else:
(labels, images) = ofrecord_util.load_imagenet_for_training(args)
# Note: images.shape is (N, C, H, W) when using cc's new dataloader (load_imagenet_for_training2).
else:
print("Loading synthetic data.")
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](images)
logits = model_dict[args.model](images, need_transpose=not args.use_new_dataloader)
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")
loss = flow.math.reduce_mean(loss)
flow.losses.add_loss(loss)
......@@ -64,12 +71,15 @@ def InferenceNet():
if args.val_data_dir:
assert os.path.exists(args.val_data_dir)
print("Loading data from {}".format(args.val_data_dir))
if args.use_new_dataloader:
(labels, images) = ofrecord_util.load_imagenet_for_validation2(args)
else:
(labels, images) = ofrecord_util.load_imagenet_for_validation(args)
else:
print("Loading synthetic data.")
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](images)
logits = model_dict[args.model](images, need_transpose=not args.use_new_dataloader)
predictions = flow.nn.softmax(logits)
outputs = {"predictions":predictions, "labels": labels}
return outputs
......
......@@ -26,7 +26,7 @@ def _conv2d(
):
weight = flow.get_variable(
name + "-weight",
shape=(filters, input.static_shape[1], kernel_size, kernel_size),
shape=(filters, input.shape[1], kernel_size, kernel_size),
dtype=input.dtype,
initializer=weight_initializer,
regularizer=weight_regularizer,
......@@ -125,10 +125,11 @@ def resnet_stem(input):
return pool1
def resnet50(images, trainable=True):
def resnet50(images, trainable=True, need_transpose=False):
# Note: images.shape is already (N, C, H, W) with cc's new dataloader, so the transpose is only applied when need_transpose is True.
# images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
if need_transpose:
images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
with flow.deprecated.variable_scope("Resnet"):
stem = resnet_stem(images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册