Commit 6f7ce3f6 authored by ShawnXuan

fix dali cpu mode

Parent 288e19c4
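The fix keeps the whole DALI graph on the backend chosen by dali_cpu: operators are built on dali_device / dali_resize_device instead of a hard-coded "gpu", and define_graph no longer forces images.gpu(). Below is a minimal sketch of that pattern, assuming the legacy nvidia.dali "ops" API; the MXNetReader, its arguments, and the normalization constants are illustrative assumptions, not taken from this commit.

import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline

class SketchTrainPipe(Pipeline):  # hypothetical name, mirrors HybridTrainPipe
    def __init__(self, batch_size, num_threads, device_id,
                 rec_path, idx_path, crop_shape, dali_cpu=False):
        super(SketchTrainPipe, self).__init__(batch_size, num_threads, device_id,
                                              seed=12 + device_id)
        # CPU mode keeps every operator on the host; otherwise decode with the
        # "mixed" (GPU) decoder and resize/normalize on the GPU.
        dali_device = "cpu" if dali_cpu else "mixed"
        dali_resize_device = "cpu" if dali_cpu else "gpu"
        self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
                                     random_shuffle=True)
        self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
        self.resize = ops.RandomResizedCrop(device=dali_resize_device, size=crop_shape)
        # The normalize op follows dali_resize_device instead of being pinned
        # to "gpu", so a pure-CPU pipeline stays on the CPU end to end.
        self.cmnp = ops.CropMirrorNormalize(device=dali_resize_device,
                                            output_layout=types.NCHW,
                                            crop=crop_shape,
                                            mean=[123.68, 116.779, 103.939],
                                            std=[58.393, 57.12, 57.375])
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.resize(images)
        # No images.gpu() here: forcing a copy to the GPU would break CPU mode.
        output = self.cmnp(images, mirror=self.coin())
        return [output, self.labels]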
@@ -57,7 +57,8 @@ def get_parser(parser=None):
     parser.add_argument('--data_train_idx', type=str, default='', help='the index of training data')
     parser.add_argument('--data_val', type=str, help='the validation data')
     parser.add_argument('--data_val_idx', type=str, default='', help='the index of validation data')
-    parser.add_argument("--num_examples", type=int, default=1281167, help="imagenet pic number")
+    parser.add_argument("--num_examples", type=int, default=1281167, help="train pic number")
+    parser.add_argument("--num_val_examples", type=int, default=50000, help="validation pic number")
     ## snapshot
     parser.add_argument("--model_save_dir", type=str,
......
@@ -54,7 +54,6 @@ class HybridTrainPipe(Pipeline):
         dali_device = "cpu" if dali_cpu else "mixed"
         dali_resize_device = "cpu" if dali_cpu else "gpu"
-        print(dali_device, dali_resize_device)
         if args.dali_fuse_decoder:
             self.decode = ops.ImageDecoderRandomCrop(device=dali_device, output_type=types.RGB,
                                                      device_memory_padding=nvjpeg_padding,
@@ -68,8 +67,8 @@ class HybridTrainPipe(Pipeline):
             self.resize = ops.RandomResizedCrop(device=dali_resize_device, size=crop_shape)
-        #self.cmnp = ops.CropMirrorNormalize(device=dali_resize_device, #"gpu",
-        self.cmnp = ops.CropMirrorNormalize(device="gpu",
+        #self.cmnp = ops.CropMirrorNormalize(device="gpu",
+        self.cmnp = ops.CropMirrorNormalize(device=dali_resize_device, #"gpu",
                                             output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
                                             output_layout=output_layout, crop=crop_shape, pad_output=pad_output,
                                             image_type=types.RGB, mean=args.rgb_mean, std=args.rgb_std)
@@ -81,7 +80,7 @@ class HybridTrainPipe(Pipeline):
         images = self.decode(self.jpegs)
         images = self.resize(images)
-        output = self.cmnp(images.gpu(), mirror=rng)
+        output = self.cmnp(images, mirror=rng)
         return [output, self.labels]
@@ -102,10 +101,9 @@ class HybridValPipe(Pipeline):
         self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB,
                                        device_memory_padding=nvjpeg_padding,
                                        host_memory_padding=nvjpeg_padding)
-        print(dali_device)
         self.resize = ops.Resize(device=dali_device, resize_shorter=resize_shp) if resize_shp else None
-        #self.cmnp = ops.CropMirrorNormalize(device=dali_device,#"gpu",
-        self.cmnp = ops.CropMirrorNormalize(device="gpu",
+        #self.cmnp = ops.CropMirrorNormalize(device="gpu",
+        self.cmnp = ops.CropMirrorNormalize(device=dali_device,#"gpu",
                                             output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
                                             output_layout=output_layout, crop=crop_shape, pad_output=pad_output,
                                             image_type=types.RGB, mean=args.rgb_mean, std=args.rgb_std)
@@ -115,7 +113,7 @@ class HybridValPipe(Pipeline):
         images = self.decode(self.jpegs)
         if self.resize:
             images = self.resize(images)
-        output = self.cmnp(images.gpu())
+        output = self.cmnp(images)
         return [output, self.labels]
@@ -279,7 +277,7 @@ class DALIGenericIterator(object):
         print("DALI iterator does not support resetting while epoch is not finished. Ignoring...")
-def get_rec_iter(args, dali_cpu=False, todo=True):
+def get_rec_iter(args, train_batch_size, val_batch_size, dali_cpu=False, todo=True):
     # TBD dali_cpu only not work
     if todo:
         gpus = [0]
@@ -295,11 +293,6 @@ def get_rec_iter(args, dali_cpu=False, todo=True):
     # the input_layout w.r.t. the model is the output_layout of the image pipeline
     output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW
-    total_device_num = args.num_nodes * args.gpu_num_per_node
-    train_batch_size = total_device_num * args.batch_size_per_device
-    val_batch_size = total_device_num * args.val_batch_size_per_device
-    print(train_batch_size, val_batch_size)
     trainpipes = [HybridTrainPipe(args = args,
                                   batch_size = train_batch_size,
                                   num_threads = num_threads,
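With the new signature, get_rec_iter no longer derives the global batch sizes itself; the caller is expected to pass them in. A minimal sketch of the call site, reusing the argument names visible in this diff (the keyword form of dali_cpu is for clarity; the callers below pass it positionally):

# Global batch sizes: devices across all nodes times the per-device batch size.
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
val_batch_size = total_device_num * args.val_batch_size_per_device

train_data_iter, val_data_iter = get_rec_iter(args, train_batch_size,
                                              val_batch_size, dali_cpu=True)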
@@ -355,7 +348,7 @@ if __name__ == '__main__':
     parser = configs.get_parser()
     args = parser.parse_args()
     print_args(args)
-    train_data_iter, val_data_iter = get_rec_iter(args, True)
+    train_data_iter, val_data_iter = get_rec_iter(args, 256, 500, True)
     for epoch in range(args.num_epochs):
         tic = time.time()
         print('Starting epoch {}'.format(epoch))
......
@@ -30,6 +30,7 @@ epoch_size = math.ceil(args.num_examples / train_batch_size)
 num_train_batches = epoch_size * args.num_epochs
 num_warmup_batches = epoch_size * args.warmup_epochs
 decay_batches = num_train_batches - num_warmup_batches
+num_val_steps = args.num_val_examples / val_batch_size
 summary = Summary(args.log_dir, args)
 timer = StopWatch()
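Note that num_val_steps above is a float (true division), so the predict_step + 1 == num_val_steps check in do_predictions only fires when args.num_val_examples is an exact multiple of val_batch_size. A hedged alternative, not part of this commit, that yields an integer count and also covers a trailing partial batch:

import math

# Assumption: the validation iterator also yields the final partial batch.
num_val_steps = math.ceil(args.num_val_examples / val_batch_size)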
@@ -143,7 +144,7 @@ def train_callback(epoch, step):
 def do_predictions(epoch, predict_step, predictions):
     acc_acc(predict_step, predictions)
-    if predict_step + 1 == args.val_step_num:
+    if predict_step + 1 == num_val_steps:
         assert main.total > 0
         summary.scalar('top1_accuracy', main.correct/main.total, epoch)
         #summary.scalar('top1_correct', main.correct, epoch)
@@ -166,7 +167,7 @@ def main():
     snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
-    train_data_iter, val_data_iter = get_rec_iter(args, True)
+    train_data_iter, val_data_iter = get_rec_iter(args, train_batch_size, val_batch_size, True)
     timer.start()
     for epoch in range(args.num_epochs):
         tic = time.time()
@@ -186,8 +187,8 @@ def main():
         for i, batches in enumerate(val_data_iter):
             assert len(batches) == 1
             images, labels = batches[0]
-            #InferenceNet(images, labels.astype(np.int32)).async_get(predict_callback(epoch, i))
-            acc_acc(i, InferenceNet(images, labels.astype(np.int32)).get())
+            InferenceNet(images, labels.astype(np.int32)).async_get(predict_callback(epoch, i))
+            #acc_acc(i, InferenceNet(images, labels.astype(np.int32)).get())
         assert main.total > 0
         top1_accuracy = main.correct/main.total
......
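The validation loop now relies on predict_callback, which is not shown in this diff. Presumably it returns a closure that async_get invokes once the inference results are ready; a minimal sketch under that assumption:

def predict_callback(epoch, predict_step):
    # Hypothetical shape of the helper referenced above: async_get calls the
    # returned closure with the finished predictions, which are forwarded to
    # do_predictions together with the epoch and validation step index.
    def callback(predictions):
        do_predictions(epoch, predict_step, predictions)
    return callback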