提交 90565aa6 编写于 作者: F Flowingsun007

add 'nhwc' mode resnet50

上级 67585154
......@@ -6,4 +6,5 @@ MODEL_LOAD_DIR="resnet_v15_of_best_model_val_top1_77318/snapshot_epoch_88"
python3 of_cnn_inference.py \
--image_path="image_demo/tiger.jpg" \
--log_dir="inference_output" \
--model_load_dir=$MODEL_LOAD_DIR
--model_load_dir=$MODEL_LOAD_DIR
# --channel_last=True
......@@ -20,6 +20,7 @@ def load_image(image_path='image_demo/ILSVRC2012_val_00020287.JPEG'):
print(image_path)
im = Image.open(image_path)
im = im.resize((224, 224))
im = im.convert('RGB') # Some images are single-channel; without this conversion an error is raised
im = np.array(im).astype('float32')
im = (im - args.rgb_mean) / args.rgb_std
im = np.transpose(im, (2, 0, 1))
......@@ -29,16 +30,22 @@ def load_image(image_path='image_demo/ILSVRC2012_val_00020287.JPEG'):
@flow.global_function(flow.function_config())
def InferenceNet(images=flow.FixedTensorDef((1, 3, 224, 224), dtype=flow.float)):
    """Build the single-image ResNet-50 inference graph.

    Args:
        images: input placeholder declared with shape (1, 3, 224, 224).

    Returns:
        The softmax of the logits produced by ``resnet50``.
    """
    # Build the network exactly once; the layout choice is forwarded to
    # resnet50 through its ``channel_last`` argument.
    logits = resnet50(images, training=False, channel_last=args.channel_last)
    return flow.nn.softmax(logits)
def main():
flow.env.log_dir(args.log_dir)
assert os.path.isdir(args.model_load_dir)
check_point = flow.train.CheckPoint()
check_point.load(args.model_load_dir)
if args.channel_last:
print("Use 'NHWC' mode >> Channel last")
else:
print("Use 'NCHW' mode >> Channel first")
image = load_image()
predictions = InferenceNet(image).get()
clsidx = predictions.ndarray().argmax()
......
......@@ -58,9 +58,9 @@ def TrainNet():
else:
print("Loading synthetic data.")
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](
images, need_transpose=False if args.train_data_dir else True)
logits = model_dict[args.model](images,
need_transpose=False if args.train_data_dir else True,
channel_last=args.channel_last)
# loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
# labels, logits, name="softmax_loss")
# loss = flow.math.reduce_mean(loss)
......@@ -85,7 +85,7 @@ def InferenceNet():
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](
images, need_transpose=False if args.train_data_dir else True)
images, need_transpose=False if args.train_data_dir else True, channel_last=args.channel_last)
predictions = flow.nn.softmax(logits)
outputs = {"predictions": predictions, "labels": labels}
return outputs
......@@ -93,7 +93,10 @@ def InferenceNet():
def main():
InitNodes(args)
if args.channel_last:
print("Use 'NHWC' mode >> Channel last")
else:
print("Use 'NCHW' mode >> Channel first")
flow.env.grpc_use_no_signal()
flow.env.log_dir(args.log_dir)
......
......@@ -2,51 +2,55 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import oneflow as flow
BLOCK_COUNTS = [3, 4, 6, 3]
BLOCK_FILTERS = [256, 512, 1024, 2048]
BLOCK_FILTERS_INNER = [64, 128, 256, 512]
class ResnetBuilder(object):
def __init__(self, weight_regularizer, trainable=True, training=True, channel_last=False):
    """Store the settings shared by every layer this builder creates.

    Args:
        weight_regularizer: regularizer attached to the convolution weights.
        trainable: stored and passed on when weight variables are created.
        training: stored on the instance as ``self.training``.
        channel_last: when true the builder works in "NHWC" layout,
            otherwise in "NCHW".
    """
    # A single source of truth for the layout; the initializer and the
    # conv/pool ops all read it from here.
    self.data_format = "NHWC" if channel_last else "NCHW"
    self.weight_initializer = flow.variance_scaling_initializer(
        2, 'fan_in', 'random_normal', data_format=self.data_format
    )
    self.weight_regularizer = weight_regularizer
    self.trainable = trainable
    self.training = training
def _conv2d(
    self,
    name,
    input,
    filters,
    kernel_size,
    strides=1,
    padding="SAME",
    dilations=1,
):
    """Create a convolution weight variable and apply a 2-D convolution.

    Args:
        name: op name; the weight variable is named ``name + "-weight"``.
        input: input blob, laid out according to ``self.data_format``.
        filters: number of output channels.
        kernel_size: side length of the square kernel.
        strides: forwarded to ``flow.nn.conv2d``.
        padding: forwarded to ``flow.nn.conv2d``.
        dilations: forwarded to ``flow.nn.conv2d``.

    Returns:
        The result of ``flow.nn.conv2d``.
    """
    # The weight tensor shape differs between the two layouts: the input
    # channel count is the last dimension for "NHWC" and the second for "NCHW".
    if self.data_format == "NHWC":
        shape = (filters, kernel_size, kernel_size, input.shape[3])
    else:
        shape = (filters, input.shape[1], kernel_size, kernel_size)
    weight = flow.get_variable(
        name + "-weight",
        shape=shape,
        dtype=input.dtype,
        initializer=self.weight_initializer,
        regularizer=self.weight_regularizer,
        model_name="weight",
        trainable=self.trainable,
    )
    return flow.nn.conv2d(input, weight, strides, padding, self.data_format, dilations, name=name)
def _batch_norm(self, inputs, name=None, last=False):
initializer = flow.zeros_initializer() if last else flow.ones_initializer()
return flow.layers.batch_normalization(
inputs=inputs,
axis=1,
momentum=0.9,#97,
momentum=0.9, # 97,
epsilon=1e-5,
center=True,
scale=True,
......@@ -59,7 +63,6 @@ class ResnetBuilder(object):
name=name,
)
def conv2d_affine(self, input, name, filters, kernel_size, strides, activation=None, last=False):
# input data_format must be NCHW, cannot check now
padding = "SAME" if strides > 1 or kernel_size > 1 else "VALID"
......@@ -67,24 +70,21 @@ class ResnetBuilder(object):
output = self._batch_norm(output, name + "_bn", last=last)
if activation == "Relu":
output = flow.nn.relu(output)
return output
return output
def bottleneck_transformation(self, input, block_name, filters, filters_inner, strides):
    """Chain the three ``conv2d_affine`` layers of a bottleneck branch.

    The layers are named ``<block_name>_branch2a/2b/2c``: a 1x1 layer with
    ``filters_inner`` outputs, a 3x3 layer with ``filters_inner`` outputs
    that carries ``strides``, and a final 1x1 layer with ``filters``
    outputs. The first two use "Relu" activation; the last passes
    ``last=True`` and no activation.

    Returns:
        The output of the ``_branch2c`` layer.
    """
    reduced = self.conv2d_affine(
        input, block_name + "_branch2a", filters_inner, 1, 1, activation="Relu"
    )
    spatial = self.conv2d_affine(
        reduced, block_name + "_branch2b", filters_inner, 3, strides, activation="Relu"
    )
    return self.conv2d_affine(spatial, block_name + "_branch2c", filters, 1, 1, last=True)
def residual_block(self, input, block_name, filters, filters_inner, strides_init):
if strides_init != 1 or block_name == "res2_0":
shortcut = self.conv2d_affine(
......@@ -92,14 +92,12 @@ class ResnetBuilder(object):
)
else:
shortcut = input
bottleneck = self.bottleneck_transformation(
input, block_name, filters, filters_inner, strides_init,
input, block_name, filters, filters_inner, strides_init,
)
return flow.nn.relu(bottleneck + shortcut)
def residual_stage(self, input, stage_name, counts, filters, filters_inner, stride_init=2):
output = input
for i in range(counts):
......@@ -107,44 +105,44 @@ class ResnetBuilder(object):
output = self.residual_block(
output, block_name, filters, filters_inner, stride_init if i == 0 else 1
)
return output
return output
def resnet_conv_x_body(self, input):
    """Run ``input`` through the residual stages ``res2`` .. ``res5`` in order.

    Stage sizes come from the module-level ``BLOCK_COUNTS``,
    ``BLOCK_FILTERS`` and ``BLOCK_FILTERS_INNER`` lists.

    Returns:
        The output of the last residual stage.
    """
    output = input
    stage_specs = zip(BLOCK_COUNTS, BLOCK_FILTERS, BLOCK_FILTERS_INNER)
    for i, (counts, filters, filters_inner) in enumerate(stage_specs):
        stage_name = "res%d" % (i + 2)
        # The first stage is entered with stride 1, every later one with stride 2.
        output = self.residual_stage(
            output, stage_name, counts, filters, filters_inner, 1 if i == 0 else 2
        )
    return output
def resnet_stem(self, input):
    """Build the network stem: ``conv1`` -> batch norm -> ReLU -> ``pool1``.

    ``conv1`` is a 7x7 convolution with 64 filters and stride 2; ``pool1``
    is a 3x3 max pool with stride 2. Both follow ``self.data_format``.

    Returns:
        The output of the ``pool1`` max-pooling op.
    """
    conv1 = self._conv2d("conv1", input, 64, 7, 2)
    conv1_bn = flow.nn.relu(self._batch_norm(conv1, "conv1_bn"))
    # The pooling layout must match the layout used by the convolution above.
    return flow.nn.max_pool2d(
        conv1_bn,
        ksize=3,
        strides=2,
        padding="SAME",
        data_format=self.data_format,
        name="pool1",
    )
def resnet50(images, trainable=True, need_transpose=False, training=True, wd=1.0/32768):
def resnet50(images, trainable=True, need_transpose=False, training=True, wd=1.0 / 32768, channel_last=False):
weight_regularizer = flow.regularizers.l2(wd) if wd > 0.0 and wd < 1.0 else None
builder = ResnetBuilder(weight_regularizer, trainable, training)
builder = ResnetBuilder(weight_regularizer, trainable, training, channel_last)
# note: images.shape = (N C H W) in cc's new dataloader, transpose is not needed anymore
if need_transpose:
images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
if channel_last:
# if channel_last=True, then change mode from 'nchw' to 'nhwc'
images = flow.transpose(images, name="transpose", perm=[0, 2, 3, 1])
with flow.deprecated.variable_scope("Resnet"):
stem = builder.resnet_stem(images)
body = builder.resnet_conv_x_body(stem)
pool5 = flow.nn.avg_pool2d(
body, ksize=7, strides=1, padding="VALID", data_format="NCHW", name="pool5",
body, ksize=7, strides=1, padding="VALID", data_format=builder.data_format, name="pool5",
)
fc1001 = flow.layers.dense(
flow.reshape(pool5, (pool5.shape[0], -1)),
......@@ -157,5 +155,4 @@ def resnet50(images, trainable=True, need_transpose=False, training=True, wd=1.0
trainable=trainable,
name="fc1001",
)
return fc1001
return fc1001
\ No newline at end of file
......@@ -17,4 +17,5 @@ python3 of_cnn_train_val.py \
--batch_size_per_device=64 \
--val_batch_size_per_device=125 \
--num_epoch=90 \
--model="resnet50"
\ No newline at end of file
--model="resnet50"
# --channel_last=True
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册