Commit b084dfab authored by Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into parallel_bcast

......@@ -20,6 +20,7 @@ import functools
import numpy as np
import time
import os
import math
import cProfile, pstats, StringIO
......@@ -27,128 +28,120 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
# from recordio_converter import imagenet_train, imagenet_test
from imagenet_reader import train, val
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, is_train=True):
self.params = train_parameters
self.layers = layers
self.is_train = is_train
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(
input=conv, act=act, is_test=not self.is_train)
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
is_train=True):
conv1 = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv1, act=act, is_test=not is_train)
def shortcut(input, ch_out, stride, is_train=True):
ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1]
if ch_in != ch_out:
return conv_bn_layer(
input, ch_out, 1, stride, 0, None, is_train=is_train)
else:
return input
def basicblock(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 3, stride, 1, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def bottleneck(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out * 4, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 1, stride, 0, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, is_train=is_train)
conv3 = conv_bn_layer(
conv2, ch_out * 4, 1, 1, 0, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv3, act='relu')
def layer_warp(block_func, input, ch_out, count, stride):
res_out = block_func(input, ch_out, stride)
for i in range(1, count):
res_out = block_func(res_out, ch_out, 1)
return res_out
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
def resnet_imagenet(input,
class_dim,
depth=50,
data_format='NCHW',
is_train=True):
short = self.shortcut(input, num_filters * 4, stride)
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}
stages, block_func = cfg[depth]
conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3)
pool1 = fluid.layers.pool2d(
input=conv1, pool_type='avg', pool_size=3, pool_stride=2)
res1 = layer_warp(block_func, pool1, 64, stages[0], 1)
res2 = layer_warp(block_func, res1, 128, stages[1], 2)
res3 = layer_warp(block_func, res2, 256, stages[2], 2)
res4 = layer_warp(block_func, res3, 512, stages[3], 2)
pool2 = fluid.layers.pool2d(
input=res4,
pool_size=7,
pool_type='avg',
pool_stride=1,
global_pooling=True)
out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax')
return out
def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
assert (depth - 2) % 6 == 0
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, n, 1)
res2 = layer_warp(basicblock, res1, 32, n, 2)
res3 = layer_warp(basicblock, res2, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
out = fluid.layers.fc(input=pool, size=class_dim, act='softmax')
return out
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def _model_reader_dshape_classdim(args, is_train):
model = resnet_cifar10
model = None
reader = None
if args.data_set == "cifar10":
class_dim = 10
if args.data_format == 'NCHW':
dshape = [3, 32, 32]
else:
dshape = [32, 32, 3]
model = resnet_cifar10
if is_train:
reader = paddle.dataset.cifar.train10()
else:
reader = paddle.dataset.cifar.test10()
elif args.data_set == "flowers":
if args.data_set == "flowers":
class_dim = 102
if args.data_format == 'NCHW':
dshape = [3, 224, 224]
else:
dshape = [224, 224, 3]
model = resnet_imagenet
if is_train:
reader = paddle.dataset.flowers.train()
else:
......@@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train):
dshape = [3, 224, 224]
else:
dshape = [224, 224, 3]
model = resnet_imagenet
if not args.data_path:
raise Exception(
"Must specify --data_path when training with imagenet")
......@@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train):
reader = train(xmap=False)
else:
reader = val(xmap=False)
return model, reader, dshape, class_dim
return reader, dshape, class_dim
def get_model(args, is_train, main_prog, startup_prog):
model, reader, dshape, class_dim = _model_reader_dshape_classdim(args,
is_train)
reader, dshape, class_dim = _model_reader_dshape_classdim(args, is_train)
pyreader = None
trainer_count = int(os.getenv("PADDLE_TRAINERS"))
......@@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog):
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
predict = model(input, class_dim, is_train=is_train)
model = ResNet(is_train=is_train)
predict = model.net(input, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
......@@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog):
total_images = 1281167 / trainer_count
step = int(total_images / args.batch_size + 1)
epochs = [30, 60, 80, 90]
epochs = [30, 60, 90]
bd = [step * e for e in epochs]
base_lr = args.learning_rate
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=base_lr,
#learning_rate=fluid.layers.piecewise_decay(
# boundaries=bd, values=lr),
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost)
......
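The optimizer change above re-enables piecewise_decay instead of a constant rate: bd holds the step indices of the epoch boundaries [30, 60, 90], and the learning rate is divided by 10 at each one. A minimal standalone sketch of the resulting schedule (plain C++, not the Paddle API; base_lr = 0.1 and batch_size = 256 are illustrative values taken from train_parameters above):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors piecewise_decay: start at base_lr, divide by 10 at each boundary.
double lr_at(int64_t step, double base_lr, const std::vector<int64_t>& bd) {
  double lr = base_lr;
  for (int64_t b : bd) {
    if (step >= b) lr *= 0.1;
  }
  return lr;
}

int main() {
  const int64_t steps_per_epoch = 1281167 / 256 + 1;  // total_images / batch
  const std::vector<int64_t> bd = {30 * steps_per_epoch, 60 * steps_per_epoch,
                                   90 * steps_per_epoch};
  for (int epoch : {0, 30, 60, 90}) {
    std::printf("epoch %3d -> lr %g\n", epoch,
                lr_at(epoch * steps_per_epoch, 0.1, bd));
  }
  return 0;
}

Running this prints 0.1, 0.01, 0.001, 0.0001 at the four boundaries, matching lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)].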
......@@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch,
if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) {
buffer.Resize(sizeof(T) * data.size());
}
std::memcpy(buffer.data(), data.data(), buffer.length());
std::memcpy(buffer.data(), data.data(), sizeof(T) * data.size());
// copy LoD
for (const auto &level : fetch.lod()) {
output->lod.emplace_back(level);
......
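The memcpy fix above matters because the buffer is only resized when it is too small: after a large fetch, buffer.length() can exceed the current payload size, so copying buffer.length() bytes reads past the end of data. A standalone sketch of the failure mode (simplified Buffer stand-in, not the Paddle class):

#include <cstddef>
#include <cstring>
#include <vector>

// Simplified stand-in: like the real code, callers only grow the buffer.
struct Buffer {
  std::vector<char> storage;
  size_t length() const { return storage.size(); }
  char* data() { return storage.data(); }
  bool empty() const { return storage.empty(); }
  void Resize(size_t n) { storage.resize(n); }
};

template <typename T>
void CopyPayload(Buffer* buffer, const std::vector<T>& data) {
  if (buffer->empty() || buffer->length() < sizeof(T) * data.size()) {
    buffer->Resize(sizeof(T) * data.size());
  }
  // The old code copied buffer->length() bytes: after a 400-byte fetch, a
  // 40-byte fetch would read 400 bytes from a 40-byte source (undefined
  // behavior). The fix copies exactly the current payload size:
  std::memcpy(buffer->data(), data.data(), sizeof(T) * data.size());
}

int main() {
  Buffer buffer;
  CopyPayload(&buffer, std::vector<float>(100));  // buffer grows to 400 bytes
  CopyPayload(&buffer, std::vector<float>(10));   // length() stays 400
  return 0;
}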
......@@ -117,34 +117,6 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
input_slots->assign({input_tensor});
}
void BenchAllData(const std::string &model_path, const std::string &data_file,
const int batch_size, const int repeat) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
config.device = 0;
config.specify_input_name = true;
std::vector<PaddleTensor> input_slots, outputs_slots;
DataRecord data(data_file, batch_size);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
GetOneBatch(&input_slots, &data, batch_size);
for (int i = 0; i < FLAGS_burning; i++) {
predictor->Run(input_slots, &outputs_slots);
}
Timer timer;
double sum = 0;
for (int i = 0; i < repeat; i++) {
for (size_t bid = 0; bid < data.batched_datas.size(); ++bid) {
GetOneBatch(&input_slots, &data, batch_size);
timer.tic();
predictor->Run(input_slots, &outputs_slots);
sum += timer.toc();
}
}
PrintTime(batch_size, repeat, 1, 0, sum / repeat);
}
const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25,
25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43,
44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39,
......
......@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p",
pipeline);
pipeline, is_persistent);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
......@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
int groups = ctx.Attr<int>("groups");
// TODO(pzelazko-intel) add support for group convolution and dilation
......@@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine);
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu);
} else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine);
paddings, mkldnn_engine, fuse_relu);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline);
user_weights_memory_p, pipeline, is_test);
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
......@@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
private:
mkldnn::primitive_attr AddRelu() const {
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
mkldnn::primitive_attr conv_attr;
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
mkldnn::post_ops post_operations;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
mkldnn::primitive_attr conv_attr;
if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
......@@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
mkldnn::primitive_attr conv_attr;
if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
......
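Conceptually, the eltwise_relu post-op added above folds max(0, x) into the convolution's own store: the activation is applied while each result is still in registers, so the output buffer is written once instead of being re-read by a separate relu kernel. A schematic sketch of the difference (plain C++, not MKL-DNN; accumulate() stands in for the convolution's per-element sum):

#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for the convolution's per-element accumulation.
static float accumulate(int i) { return (i % 2 ? 1.0f : -1.0f) * i; }

int main() {
  const int n = 8;
  std::vector<float> out(n);

  // Unfused: one pass stores the conv results, then a second pass re-reads
  // the whole buffer just to clamp it.
  for (int i = 0; i < n; ++i) out[i] = accumulate(i);
  for (int i = 0; i < n; ++i) out[i] = std::max(out[i], 0.0f);

  // Fused (what the post-op expresses): relu is applied to the accumulator
  // before the single store, so the buffer is touched once.
  for (int i = 0; i < n; ++i) out[i] = std::max(accumulate(i), 0.0f);

  for (float v : out) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}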
......@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", "").SetDefault(false);
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
......@@ -20,6 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc)
return()
endif()
......
......@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() {
}
channels_.clear();
}
client_thread_->join();
}
bool GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
this] {
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
// varhandle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Send";
VLOG(3) << var_h.String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
......@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
});
req_count_++;
return true;
return h;
}
void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
}
template <typename T>
......@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result->Swap(&tmp);
}
bool GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
this] {
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
// var handle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Get";
VLOG(3) << var_h.String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall(
......@@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
req_count_++;
return true;
return h;
}
bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {
time_out, s, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
// var handle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = out_var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Prefetch";
VLOG(3) << var_h.String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall(
......@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
});
req_count_++;
return true;
return h;
}
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
return h;
}
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
return h;
}
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
VarHandlePtr h(
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
return h;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out);
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
......@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
return h;
}
bool GRPCClient::Wait() {
......@@ -276,25 +268,28 @@ void GRPCClient::Proceed() {
void* tag = nullptr;
bool ok = false;
VLOG(3) << "GRPCClient Proceed begin";
while (!stopped_ && cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok);
PADDLE_ENFORCE(c);
if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process";
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
LOG(ERROR) << c->var_h_.String()
LOG(ERROR) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message();
{
std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false;
}
sync_cond_.notify_all();
c->Finish(false);
} else {
LOG(FATAL) << c->var_h_.String()
LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message();
c->Finish(false);
}
delete c;
{
std::lock_guard<std::mutex> lk(sync_mutex_);
......@@ -302,6 +297,7 @@ void GRPCClient::Proceed() {
}
sync_cond_.notify_all();
}
VLOG(3) << "GRPCClient Proceed end";
}
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
......
......@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) {
context_ = nullptr;
}
BaseProcessor() { context_ = nullptr; }
virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
virtual void Prepare(VarHandlePtr h, int64_t time_out) {
var_h_ = h;
context_.reset(new grpc::ClientContext());
var_h_ = var_info;
context_->set_wait_for_ready(true);
if (time_out) {
std::chrono::system_clock::time_point deadline =
......@@ -71,21 +70,21 @@ class BaseProcessor {
}
}
virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());
context_->set_wait_for_ready(true);
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
void Process() {
ProcessImpl();
var_h_->Finish(true);
}
virtual void Process() = 0;
VarHandlePtr GetVarHandlePtr() { return var_h_; }
bool Wait() { return var_h_->Wait(); }
void Finish(bool ok) { return var_h_->Finish(ok); }
virtual void ProcessImpl() = 0;
std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_;
VarHandle var_h_;
protected:
VarHandlePtr var_h_;
};
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
......@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class SendProcessor : public BaseProcessor {
public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {}
: BaseProcessor(), stub_g_(ch) {}
virtual ~SendProcessor() {}
virtual void Process() {
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(var_h_, reply_);
response_call_back_(*var_h_.get(), reply_);
}
}
......@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class GetProcessor : public BaseProcessor {
public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {}
: BaseProcessor(), stub_g_(ch) {}
virtual ~GetProcessor() {}
virtual void Process() {
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(var_h_, reply_);
response_call_back_(*var_h_.get(), reply_);
}
}
......@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor {
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~BatchBarrierProcessor() {}
virtual void Process() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
......@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor {
class FetchBarrierProcessor : public BaseProcessor {
public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~FetchBarrierProcessor() {}
virtual void Process() {}
void ProcessImpl() override {}
sendrecv::VariableMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
......@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor {
class CheckpointNotifyProcessor : public BaseProcessor {
public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~CheckpointNotifyProcessor() {}
virtual void Process() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
......@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient {
GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
......
......@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
......@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class RPCServer;
struct VarHandle {
// RPC endpoint.
std::string ep;
const platform::DeviceContext* ctx;
const framework::Scope* scope;
// Variable name.
std::string name;
// RPC method name.
std::string method;
class VarHandle {
public:
VarHandle(const std::string ep, const std::string& method,
const std::string& name,
const platform::DeviceContext* p_ctx = nullptr,
const framework::Scope* p_scope = nullptr)
: ok_(kVarHandleDefaultState) {
ep_ = ep;
ctx_ = p_ctx;
scope_ = p_scope;
name_ = name;
method_ = method;
}
virtual ~VarHandle() {}
public:
bool Wait() {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; });
}
VLOG(7) << "VarHandle wait:" << ok_;
return ok_ != 0;
}
void Finish(bool ok) {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
ok_ = ok;
}
VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all();
}
std::string String() const {
std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]";
s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_
<< "]";
return s.str();
}
std::string ep() const { return ep_; }
const platform::DeviceContext* ctx() const { return ctx_; }
const framework::Scope* scope() const { return scope_; }
std::string name() const { return name_; }
std::string method() const { return method_; }
protected:
// RPC endpoint.
std::string ep_;
const platform::DeviceContext* ctx_;
const framework::Scope* scope_;
// Variable name.
std::string name_;
// RPC method name.
std::string method_;
protected:
std::mutex sync_mutex_;
std::condition_variable wait_cond_;
int ok_;
static const int kVarHandleDefaultState = -1;
private:
DISABLE_COPY_AND_ASSIGN(VarHandle);
};
typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler {
public:
explicit RequestHandler(bool sync_mode)
......
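The reworked VarHandle is effectively a one-shot latch: ok_ starts at the sentinel kVarHandleDefaultState (-1), Finish() publishes 0 or 1 under the mutex, and Wait() blocks until the sentinel is replaced, returning whether the RPC succeeded. A minimal standalone model of that handshake (simplified, not the Paddle class; varhandle_test.cc below exercises the real one):

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

// One-shot latch with a tri-state flag: -1 = pending, 0 = failed, 1 = ok.
class Latch {
 public:
  bool Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return state_ != -1; });
    return state_ != 0;
  }
  void Finish(bool ok) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      state_ = ok ? 1 : 0;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int state_ = -1;
};

int main() {
  Latch h;
  std::thread waiter([&] { std::printf("rpc ok: %d\n", (int)h.Wait()); });
  h.Finish(true);  // the completion-queue thread does this in Proceed()
  waiter.join();
  return 0;
}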
......@@ -14,12 +14,14 @@
#pragma once
#include <condition_variable> // NOLINT
#include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline);
......@@ -31,37 +33,36 @@ class RPCClient {
public:
RPCClient() {}
virtual ~RPCClient() {}
virtual bool AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncPrefetchVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
// Complete tells all the pserver instances that the trainer has finished
// training, so the pserver can reduce its barrier count and continue to train
......
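With every Async* method now returning a VarHandlePtr, call sites can wait on each request individually instead of relying on one global Wait(), so a failure can be attributed to a specific RPC. A stubbed sketch of the new calling pattern (stand-in types and names so it compiles alone; the op changes below are the real call sites):

#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Stand-ins for distributed::VarHandlePtr and distributed::RPCClient.
struct VarHandle {
  bool Wait() { return true; }  // the real Wait blocks on the RPC result
};
using VarHandlePtr = std::shared_ptr<VarHandle>;

struct RPCClient {
  VarHandlePtr AsyncGetVar(const std::string& ep, const std::string& var) {
    return std::make_shared<VarHandle>();
  }
};

int main() {
  RPCClient client;
  std::vector<std::string> eps = {"127.0.0.1:6174", "127.0.0.1:6175"};
  std::vector<VarHandlePtr> rets;
  for (auto& ep : eps) rets.push_back(client.AsyncGetVar(ep, "w@GRAD"));
  // New interface: wait on each handle individually instead of one global
  // Wait(), so a single failed RPC can be reported precisely.
  for (auto& h : rets) assert(h->Wait());
  return 0;
}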
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
using paddle::operators::distributed::VarHandlePtr;
using paddle::operators::distributed::VarHandle;
void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); }
void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); }
TEST(VarHandle, Run) {
std::vector<VarHandlePtr> a;
for (int i = 0; i < 12; i++) {
VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr));
a.push_back(s);
}
std::vector<std::unique_ptr<std::thread>> t;
for (int i = 0; i < 6; i++) {
t.emplace_back(new std::thread(WaitFalse, a[i]));
}
for (int i = 0; i < 6; i++) {
a[i]->Finish(false);
t[i]->join();
}
for (int i = 6; i < 12; i++) {
t.emplace_back(new std::thread(WaitTrue, a[i]));
}
for (int i = 6; i < 12; i++) {
a[i]->Finish(true);
t[i]->join();
}
}
......@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]);
rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope,
ins[i], outs[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
};
......
......@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
}
if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
}
};
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
}
};
......
......@@ -192,7 +192,8 @@ class MKLDNNHandler {
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
......@@ -213,7 +214,7 @@ class MKLDNNHandler {
pipeline.push_back(*reorder_p);
}
dev_ctx_.SetBlob(local_key, target_memory_p);
} else {
} else if (!is_persistent) {
// Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
......
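The is_persistent flag threaded through AcquireMemory distinguishes weights, which are reordered once and then stay valid across inference runs, from activations, which change every run: on a cache hit the handler now skips re-enqueuing the reorder primitive for persistent blobs. A schematic sketch of that caching rule (simplified, not the MKLDNN handler; the blob key is illustrative):

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Memory {};  // stand-in for mkldnn::memory
using MemoryPtr = std::shared_ptr<Memory>;

std::unordered_map<std::string, MemoryPtr> blob_cache;

MemoryPtr AcquireMemory(const std::string& key,
                        std::vector<std::string>* pipeline,
                        bool is_persistent) {
  auto it = blob_cache.find(key);
  if (it == blob_cache.end()) {
    // First use: create the target memory and enqueue its reorder once.
    pipeline->push_back(key + "reorder_p");
    auto target = std::make_shared<Memory>();
    blob_cache[key] = target;
    return target;
  }
  if (!is_persistent) {
    // Transient data (activations) may have changed: reorder again.
    pipeline->push_back(key + "reorder_p");
  }
  // Persistent data (inference weights) is reused as-is.
  return it->second;
}

int main() {
  std::vector<std::string> pipeline;
  AcquireMemory("@weights_mem_p", &pipeline, true);  // reorder enqueued
  AcquireMemory("@weights_mem_p", &pipeline, true);  // cache hit, skipped
  return pipeline.size() == 1 ? 0 : 1;
}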
......@@ -128,6 +128,13 @@ class ParallelExecutor(object):
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exec_strategy.num_threads = cpu_num * 2
# Set thread num to 1 under the nccl2 distributed env to make
# sure all GPUs run ops in the same order.
if num_trainers > 1:
assert (use_cuda)
# FIXME(gongwb): avoid this set.
exec_strategy.num_threads = 1
if build_strategy is None:
build_strategy = BuildStrategy()
......
......@@ -60,12 +60,46 @@ class InferenceTranspiler(object):
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
self._fuse_batch_norm(program, place, scope)
if use_mkldnn:
self._fuse_relu_mkldnn(program)
self._fuse_conv_bias_mkldnn(program)
self._fuse_conv_relu_mkldnn(program)
self._fuse_bn_relu_mkldnn(program)
def _fuse_conv_relu_mkldnn(self, program):
'''
Transpile the program by fusing the relu activation into the preceding
convolution OP for an MKLDNN program.
A relu activation that follows a convolution OP can be fused by adding
the 'fuse_relu' attribute to that convolution OP.
The result of fuse is:
- before:
- conv->relu->any_other_op
- after:
- conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'relu':
# modify conv OP to include relu
current_op.set_attr("fuse_relu", True)
# remove the fused relu OP
self.block._remove_op(i + 1)
i = i + 1
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_relu_mkldnn(self, program):
def _fuse_bn_relu_mkldnn(self, program):
'''
Transpile the program by fusing relu activation for an MKLDNN program.
......@@ -159,7 +193,6 @@ class InferenceTranspiler(object):
self._fuse_conv_bias(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
i = i + 1
self._remove_unused_var()
......