提交 2a1981cf 编写于 作者: S ShawnXuan

support image target resize

上级 95108e0e
...@@ -97,6 +97,7 @@ def main(): ...@@ -97,6 +97,7 @@ def main():
save_summary_steps=num_val_steps, batch_size=val_batch_size) save_summary_steps=num_val_steps, batch_size=val_batch_size)
for i in range(num_val_steps): for i in range(num_val_steps):
InferenceNet().async_get(metric.metric_cb(epoch, i)) InferenceNet().async_get(metric.metric_cb(epoch, i))
snapshot.save('epoch_{}'.format(epoch))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -46,7 +46,7 @@ def load_imagenet_for_training(args): ...@@ -46,7 +46,7 @@ def load_imagenet_for_training(args):
flow.data.ImageResizePreprocessor(args.image_size, args.image_size), flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
flow.data.ImagePreprocessor('mirror'), flow.data.ImagePreprocessor('mirror'),
]) ])
return load_imagenet(args, train_batch_size, args.train_data_dir, args.train_data_part_num, return load_imagenet(args, train_batch_size, args.train_data_dir, args.train_data_part_num,
codec) codec)
...@@ -55,7 +55,9 @@ def load_imagenet_for_validation(args): ...@@ -55,7 +55,9 @@ def load_imagenet_for_validation(args):
val_batch_size = total_device_num * args.val_batch_size_per_device val_batch_size = total_device_num * args.val_batch_size_per_device
codec=flow.data.ImageCodec( codec=flow.data.ImageCodec(
[ [
flow.data.ImageResizePreprocessor(args.image_size, args.image_size), flow.data.ImageTargetResizePreprocessor(resize_shorter=256),
flow.data.ImageCenterCropPreprocessor(args.image_size, args.image_size),
#flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
] ]
) )
return load_imagenet(args, val_batch_size, args.val_data_dir, args.val_data_part_num, codec) return load_imagenet(args, val_batch_size, args.val_data_dir, args.val_data_part_num, codec)
...@@ -75,4 +77,4 @@ def load_synthetic(args): ...@@ -75,4 +77,4 @@ def load_synthetic(args):
shape=(args.image_size, args.image_size, 3), dtype=flow.float, batch_size=batch_size shape=(args.image_size, args.image_size, 3), dtype=flow.float, batch_size=batch_size
) )
return label, image return label, image
\ No newline at end of file
...@@ -22,7 +22,7 @@ def _conv2d( ...@@ -22,7 +22,7 @@ def _conv2d(
trainable=True, trainable=True,
weight_initializer=flow.variance_scaling_initializer(data_format="NCHW"), weight_initializer=flow.variance_scaling_initializer(data_format="NCHW"),
#weight_initializer=flow.variance_scaling_initializer(3, 'fan_in', 'random_normal', data_format="NCHW"), #weight_initializer=flow.variance_scaling_initializer(3, 'fan_in', 'random_normal', data_format="NCHW"),
weight_regularizer=flow.regularizers.l2(1e-4), weight_regularizer=flow.regularizers.l2(1.0/32768),
): ):
weight = flow.get_variable( weight = flow.get_variable(
name + "-weight", name + "-weight",
...@@ -143,7 +143,7 @@ def resnet50(images, trainable=True): ...@@ -143,7 +143,7 @@ def resnet50(images, trainable=True):
#kernel_initializer=flow.variance_scaling_initializer(3, 'fan_in', 'random_normal'), #kernel_initializer=flow.variance_scaling_initializer(3, 'fan_in', 'random_normal'),
kernel_initializer=flow.xavier_uniform_initializer(), kernel_initializer=flow.xavier_uniform_initializer(),
bias_initializer=flow.zeros_initializer(), bias_initializer=flow.zeros_initializer(),
kernel_regularizer=flow.regularizers.l2(1e-4), kernel_regularizer=flow.regularizers.l2(1.0/32768),
trainable=trainable, trainable=trainable,
name="fc1001", name="fc1001",
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册