未验证 提交 23928943 编写于 作者: Z Zhang Ting 提交者: GitHub

pad input to use tensor core (#4911)

上级 a25c0656
......@@ -43,7 +43,8 @@ class HybridTrainPipe(Pipeline):
num_shards=1,
random_shuffle=True,
num_threads=4,
seed=42):
seed=42,
pad_output=False):
super(HybridTrainPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader(
......@@ -73,7 +74,8 @@ class HybridTrainPipe(Pipeline):
crop=(crop, crop),
image_type=types.RGB,
mean=mean,
std=std)
std=std,
pad_output=pad_output)
self.coin = ops.CoinFlip(probability=0.5)
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
......@@ -104,7 +106,8 @@ class HybridValPipe(Pipeline):
num_shards=1,
random_shuffle=False,
num_threads=4,
seed=42):
seed=42,
pad_output=False):
super(HybridValPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader(
......@@ -123,7 +126,8 @@ class HybridValPipe(Pipeline):
crop=(crop, crop),
image_type=types.RGB,
mean=mean,
std=std)
std=std,
pad_output=pad_output)
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
def define_graph(self):
......@@ -169,6 +173,9 @@ def build(settings, mode='train'):
}
assert interp in interp_map, "interpolation method not supported by DALI"
interp = interp_map[interp]
pad_output = False
if settings.image_shape[0] == 4:
pad_output = True
if mode != 'train':
p = fluid.framework.cuda_places()[0]
......@@ -188,7 +195,8 @@ def build(settings, mode='train'):
interp,
mean,
std,
device_id=device_id)
device_id=device_id,
pad_output=pad_output)
pipe.build()
return DALIGenericIterator(
pipe, ['feed_image', 'feed_label'],
......@@ -221,7 +229,8 @@ def build(settings, mode='train'):
device_id,
shard_id,
num_shards,
seed=42 + shard_id)
seed=42 + shard_id,
pad_output=pad_output)
pipe.build()
pipelines = [pipe]
sample_per_shard = len(pipe) // num_shards
......@@ -248,7 +257,8 @@ def build(settings, mode='train'):
device_id,
idx,
num_shards,
seed=42 + idx)
seed=42 + idx,
pad_output=pad_output)
pipe.build()
pipelines.append(pipe)
sample_per_shard = len(pipelines[0])
......
......@@ -245,6 +245,9 @@ def process_image(sample, settings, mode, color_jitter, rotate):
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
if settings.image_shape[0] == 4:
pad0 = np.zeros((1, img.shape[1], img.shape[2]))
img = np.concatenate((img, pad0), axis=0)
# doing training (train.py)
if mode == 'train' or (mode == 'val' and
not hasattr(settings, 'save_json_path')):
......
......@@ -19,7 +19,7 @@ python train.py \
--data_dir=${DATA_DIR} \
--batch_size=256 \
--total_images=1281167 \
--image_shape 3 224 224 \
--image_shape 4 224 224 \
--class_dim=1000 \
--print_step=10 \
--model_save_dir=output/ \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册