未验证 提交 23928943 编写于 作者: Z Zhang Ting 提交者: GitHub

pad input to use tensor core (#4911)

上级 a25c0656
...@@ -43,7 +43,8 @@ class HybridTrainPipe(Pipeline): ...@@ -43,7 +43,8 @@ class HybridTrainPipe(Pipeline):
num_shards=1, num_shards=1,
random_shuffle=True, random_shuffle=True,
num_threads=4, num_threads=4,
seed=42): seed=42,
pad_output=False):
super(HybridTrainPipe, self).__init__( super(HybridTrainPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed) batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader( self.input = ops.FileReader(
...@@ -73,7 +74,8 @@ class HybridTrainPipe(Pipeline): ...@@ -73,7 +74,8 @@ class HybridTrainPipe(Pipeline):
crop=(crop, crop), crop=(crop, crop),
image_type=types.RGB, image_type=types.RGB,
mean=mean, mean=mean,
std=std) std=std,
pad_output=pad_output)
self.coin = ops.CoinFlip(probability=0.5) self.coin = ops.CoinFlip(probability=0.5)
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu") self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
...@@ -104,7 +106,8 @@ class HybridValPipe(Pipeline): ...@@ -104,7 +106,8 @@ class HybridValPipe(Pipeline):
num_shards=1, num_shards=1,
random_shuffle=False, random_shuffle=False,
num_threads=4, num_threads=4,
seed=42): seed=42,
pad_output=False):
super(HybridValPipe, self).__init__( super(HybridValPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed) batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader( self.input = ops.FileReader(
...@@ -123,7 +126,8 @@ class HybridValPipe(Pipeline): ...@@ -123,7 +126,8 @@ class HybridValPipe(Pipeline):
crop=(crop, crop), crop=(crop, crop),
image_type=types.RGB, image_type=types.RGB,
mean=mean, mean=mean,
std=std) std=std,
pad_output=pad_output)
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu") self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
def define_graph(self): def define_graph(self):
...@@ -169,6 +173,9 @@ def build(settings, mode='train'): ...@@ -169,6 +173,9 @@ def build(settings, mode='train'):
} }
assert interp in interp_map, "interpolation method not supported by DALI" assert interp in interp_map, "interpolation method not supported by DALI"
interp = interp_map[interp] interp = interp_map[interp]
pad_output = False
if settings.image_shape[0] == 4:
pad_output = True
if mode != 'train': if mode != 'train':
p = fluid.framework.cuda_places()[0] p = fluid.framework.cuda_places()[0]
...@@ -188,7 +195,8 @@ def build(settings, mode='train'): ...@@ -188,7 +195,8 @@ def build(settings, mode='train'):
interp, interp,
mean, mean,
std, std,
device_id=device_id) device_id=device_id,
pad_output=pad_output)
pipe.build() pipe.build()
return DALIGenericIterator( return DALIGenericIterator(
pipe, ['feed_image', 'feed_label'], pipe, ['feed_image', 'feed_label'],
...@@ -221,7 +229,8 @@ def build(settings, mode='train'): ...@@ -221,7 +229,8 @@ def build(settings, mode='train'):
device_id, device_id,
shard_id, shard_id,
num_shards, num_shards,
seed=42 + shard_id) seed=42 + shard_id,
pad_output=pad_output)
pipe.build() pipe.build()
pipelines = [pipe] pipelines = [pipe]
sample_per_shard = len(pipe) // num_shards sample_per_shard = len(pipe) // num_shards
...@@ -248,7 +257,8 @@ def build(settings, mode='train'): ...@@ -248,7 +257,8 @@ def build(settings, mode='train'):
device_id, device_id,
idx, idx,
num_shards, num_shards,
seed=42 + idx) seed=42 + idx,
pad_output=pad_output)
pipe.build() pipe.build()
pipelines.append(pipe) pipelines.append(pipe)
sample_per_shard = len(pipelines[0]) sample_per_shard = len(pipelines[0])
......
...@@ -245,6 +245,9 @@ def process_image(sample, settings, mode, color_jitter, rotate): ...@@ -245,6 +245,9 @@ def process_image(sample, settings, mode, color_jitter, rotate):
img_std = np.array(std).reshape((3, 1, 1)) img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean img -= img_mean
img /= img_std img /= img_std
if settings.image_shape[0] == 4:
pad0 = np.zeros((1, img.shape[1], img.shape[2]))
img = np.concatenate((img, pad0), axis=0)
# doing training (train.py) # doing training (train.py)
if mode == 'train' or (mode == 'val' and if mode == 'train' or (mode == 'val' and
not hasattr(settings, 'save_json_path')): not hasattr(settings, 'save_json_path')):
......
...@@ -19,7 +19,7 @@ python train.py \ ...@@ -19,7 +19,7 @@ python train.py \
--data_dir=${DATA_DIR} \ --data_dir=${DATA_DIR} \
--batch_size=256 \ --batch_size=256 \
--total_images=1281167 \ --total_images=1281167 \
--image_shape 3 224 224 \ --image_shape 4 224 224 \
--class_dim=1000 \ --class_dim=1000 \
--print_step=10 \ --print_step=10 \
--model_save_dir=output/ \ --model_save_dir=output/ \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册