提交 43d30547 编写于 作者: T tensor-tang

Merge remote-tracking branch 'ups/develop' into refine/infershape

......@@ -20,6 +20,7 @@ import functools
import numpy as np
import time
import os
import math
import cProfile, pstats, StringIO
......@@ -27,128 +28,120 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
# from recordio_converter import imagenet_train, imagenet_test
from imagenet_reader import train, val
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, is_train=True):
self.params = train_parameters
self.layers = layers
self.is_train = is_train
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(
input=conv, act=act, is_test=not self.is_train)
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
is_train=True):
conv1 = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv1, act=act, is_test=not is_train)
def shortcut(input, ch_out, stride, is_train=True):
ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1]
if ch_in != ch_out:
return conv_bn_layer(
input, ch_out, 1, stride, 0, None, is_train=is_train)
else:
return input
def basicblock(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 3, stride, 1, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def bottleneck(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out * 4, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 1, stride, 0, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, is_train=is_train)
conv3 = conv_bn_layer(
conv2, ch_out * 4, 1, 1, 0, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv3, act='relu')
def layer_warp(block_func, input, ch_out, count, stride):
res_out = block_func(input, ch_out, stride)
for i in range(1, count):
res_out = block_func(res_out, ch_out, 1)
return res_out
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
def resnet_imagenet(input,
class_dim,
depth=50,
data_format='NCHW',
is_train=True):
short = self.shortcut(input, num_filters * 4, stride)
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}
stages, block_func = cfg[depth]
conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3)
pool1 = fluid.layers.pool2d(
input=conv1, pool_type='avg', pool_size=3, pool_stride=2)
res1 = layer_warp(block_func, pool1, 64, stages[0], 1)
res2 = layer_warp(block_func, res1, 128, stages[1], 2)
res3 = layer_warp(block_func, res2, 256, stages[2], 2)
res4 = layer_warp(block_func, res3, 512, stages[3], 2)
pool2 = fluid.layers.pool2d(
input=res4,
pool_size=7,
pool_type='avg',
pool_stride=1,
global_pooling=True)
out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax')
return out
def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
assert (depth - 2) % 6 == 0
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, n, 1)
res2 = layer_warp(basicblock, res1, 32, n, 2)
res3 = layer_warp(basicblock, res2, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
out = fluid.layers.fc(input=pool, size=class_dim, act='softmax')
return out
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def _model_reader_dshape_classdim(args, is_train):
model = resnet_cifar10
model = None
reader = None
if args.data_set == "cifar10":
class_dim = 10
if args.data_format == 'NCHW':
dshape = [3, 32, 32]
else:
dshape = [32, 32, 3]
model = resnet_cifar10
if is_train:
reader = paddle.dataset.cifar.train10()
else:
reader = paddle.dataset.cifar.test10()
elif args.data_set == "flowers":
if args.data_set == "flowers":
class_dim = 102
if args.data_format == 'NCHW':
dshape = [3, 224, 224]
else:
dshape = [224, 224, 3]
model = resnet_imagenet
if is_train:
reader = paddle.dataset.flowers.train()
else:
......@@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train):
dshape = [3, 224, 224]
else:
dshape = [224, 224, 3]
model = resnet_imagenet
if not args.data_path:
raise Exception(
"Must specify --data_path when training with imagenet")
......@@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train):
reader = train(xmap=False)
else:
reader = val(xmap=False)
return model, reader, dshape, class_dim
return reader, dshape, class_dim
def get_model(args, is_train, main_prog, startup_prog):
model, reader, dshape, class_dim = _model_reader_dshape_classdim(args,
is_train)
reader, dshape, class_dim = _model_reader_dshape_classdim(args, is_train)
pyreader = None
trainer_count = int(os.getenv("PADDLE_TRAINERS"))
......@@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog):
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
predict = model(input, class_dim, is_train=is_train)
model = ResNet(is_train=is_train)
predict = model.net(input, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
......@@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog):
total_images = 1281167 / trainer_count
step = int(total_images / args.batch_size + 1)
epochs = [30, 60, 80, 90]
epochs = [30, 60, 90]
bd = [step * e for e in epochs]
base_lr = args.learning_rate
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=base_lr,
#learning_rate=fluid.layers.piecewise_decay(
# boundaries=bd, values=lr),
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost)
......
......@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p",
pipeline);
pipeline, is_persistent);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
......@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
int groups = ctx.Attr<int>("groups");
// TODO(pzelazko-intel) add support for group convolution and dilation
......@@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine);
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu);
} else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine);
paddings, mkldnn_engine, fuse_relu);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline);
user_weights_memory_p, pipeline, is_test);
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
......@@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
private:
mkldnn::primitive_attr AddRelu() const {
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
mkldnn::primitive_attr conv_attr;
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
mkldnn::post_ops post_operations;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
mkldnn::primitive_attr conv_attr;
if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
......@@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
mkldnn::primitive_attr conv_attr;
if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
......
......@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", "").SetDefault(false);
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
......@@ -192,7 +192,8 @@ class MKLDNNHandler {
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
......@@ -213,7 +214,7 @@ class MKLDNNHandler {
pipeline.push_back(*reorder_p);
}
dev_ctx_.SetBlob(local_key, target_memory_p);
} else {
} else if (!is_persistent) {
// Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
......
......@@ -60,12 +60,46 @@ class InferenceTranspiler(object):
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
self._fuse_batch_norm(program, place, scope)
if use_mkldnn:
self._fuse_relu_mkldnn(program)
self._fuse_conv_bias_mkldnn(program)
self._fuse_conv_relu_mkldnn(program)
self._fuse_bn_relu_mkldnn(program)
def _fuse_conv_relu_mkldnn(self, program):
'''
Transpile the program by fused relu activation for MKLDNN program.
Relu activation following convolution OP can be fused by adding
'fuse_relu' attribute to convolution OP.
The result of fuse is:
- before:
- conv->relu->any_other_op
- after:
- conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'relu':
# modify conv OP to include relu
current_op.set_attr("fuse_relu", True)
# remove conv OP
self.block._remove_op(i + 1)
i = i + 1
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_relu_mkldnn(self, program):
def _fuse_bn_relu_mkldnn(self, program):
'''
Transpile the program by fused relu activation for MKLDNN program.
......@@ -159,7 +193,6 @@ class InferenceTranspiler(object):
self._fuse_conv_bias(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
i = i + 1
self._remove_unused_var()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册