diff --git a/example/resnet101_imagenet2012/dataset.py b/example/resnet101_imagenet2012/dataset.py index 27d93dc08601dc6e399fae2b20d7a7811d51009c..31377cfc1220ca13a92c473f049c97eea367dfb5 100755 --- a/example/resnet101_imagenet2012/dataset.py +++ b/example/resnet101_imagenet2012/dataset.py @@ -76,8 +76,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): type_cast_op = C2.TypeCast(mstype.int32) - ds = ds.map(input_columns="image", operations=trans) - ds = ds.map(input_columns="label", operations=type_cast_op) + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8) + ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) # apply shuffle operations ds = ds.shuffle(buffer_size=config.buffer_size)