提交 b6f52fe7 编写于 作者: S ShawnXuan

support image target resize

上级 b6e38605
......@@ -97,6 +97,7 @@ def main():
save_summary_steps=num_val_steps, batch_size=val_batch_size)
for i in range(num_val_steps):
InferenceNet().async_get(metric.metric_cb(epoch, i))
snapshot.save('epoch_{}'.format(epoch))
if __name__ == "__main__":
......
......@@ -46,7 +46,7 @@ def load_imagenet_for_training(args):
flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
flow.data.ImagePreprocessor('mirror'),
])
return load_imagenet(args, train_batch_size, args.train_data_dir, args.train_data_part_num,
return load_imagenet(args, train_batch_size, args.train_data_dir, args.train_data_part_num,
codec)
......@@ -55,7 +55,9 @@ def load_imagenet_for_validation(args):
val_batch_size = total_device_num * args.val_batch_size_per_device
codec=flow.data.ImageCodec(
[
flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
flow.data.ImageTargetResizePreprocessor(resize_shorter=256),
flow.data.ImageCenterCropPreprocessor(args.image_size, args.image_size),
#flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
]
)
return load_imagenet(args, val_batch_size, args.val_data_dir, args.val_data_part_num, codec)
......@@ -75,4 +77,4 @@ def load_synthetic(args):
shape=(args.image_size, args.image_size, 3), dtype=flow.float, batch_size=batch_size
)
return label, image
\ No newline at end of file
return label, image
......@@ -22,7 +22,7 @@ def _conv2d(
trainable=True,
weight_initializer=flow.variance_scaling_initializer(data_format="NCHW"),
#weight_initializer=flow.variance_scaling_initializer(3, 'fan_in', 'random_normal', data_format="NCHW"),
weight_regularizer=flow.regularizers.l2(1e-4),
weight_regularizer=flow.regularizers.l2(1.0/32768),
):
weight = flow.get_variable(
name + "-weight",
......@@ -143,7 +143,7 @@ def resnet50(images, trainable=True):
#kernel_initializer=flow.variance_scaling_initializer(3, 'fan_in', 'random_normal'),
kernel_initializer=flow.xavier_uniform_initializer(),
bias_initializer=flow.zeros_initializer(),
kernel_regularizer=flow.regularizers.l2(1e-4),
kernel_regularizer=flow.regularizers.l2(1.0/32768),
trainable=trainable,
name="fc1001",
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册