提交 b084dfab 编写于 作者: Y Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into parallel_bcast

...@@ -20,6 +20,7 @@ import functools ...@@ -20,6 +20,7 @@ import functools
import numpy as np import numpy as np
import time import time
import os import os
import math
import cProfile, pstats, StringIO import cProfile, pstats, StringIO
...@@ -27,128 +28,120 @@ import paddle ...@@ -27,128 +28,120 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
# from recordio_converter import imagenet_train, imagenet_test
from imagenet_reader import train, val from imagenet_reader import train, val
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, is_train=True):
self.params = train_parameters
self.layers = layers
self.is_train = is_train
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
def conv_bn_layer(input, def conv_bn_layer(self,
ch_out, input,
num_filters,
filter_size, filter_size,
stride, stride=1,
padding, groups=1,
act='relu', act=None):
is_train=True): conv = fluid.layers.conv2d(
conv1 = fluid.layers.conv2d(
input=input, input=input,
num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
num_filters=ch_out,
stride=stride, stride=stride,
padding=padding, padding=(filter_size - 1) // 2,
groups=groups,
act=None, act=None,
bias_attr=False) bias_attr=False)
return fluid.layers.batch_norm(input=conv1, act=act, is_test=not is_train) return fluid.layers.batch_norm(
input=conv, act=act, is_test=not self.is_train)
def shortcut(self, input, ch_out, stride):
def shortcut(input, ch_out, stride, is_train=True): ch_in = input.shape[1]
ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1] if ch_in != ch_out or stride != 1:
if ch_in != ch_out: return self.conv_bn_layer(input, ch_out, 1, stride)
return conv_bn_layer(
input, ch_out, 1, stride, 0, None, is_train=is_train)
else: else:
return input return input
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
def basicblock(input, ch_out, stride, is_train=True): short = self.shortcut(input, num_filters * 4, stride)
short = shortcut(input, ch_out, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 3, stride, 1, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def bottleneck(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out * 4, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 1, stride, 0, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, is_train=is_train)
conv3 = conv_bn_layer(
conv2, ch_out * 4, 1, 1, 0, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv3, act='relu')
def layer_warp(block_func, input, ch_out, count, stride):
res_out = block_func(input, ch_out, stride)
for i in range(1, count):
res_out = block_func(res_out, ch_out, 1)
return res_out
def resnet_imagenet(input,
class_dim,
depth=50,
data_format='NCHW',
is_train=True):
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}
stages, block_func = cfg[depth]
conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3)
pool1 = fluid.layers.pool2d(
input=conv1, pool_type='avg', pool_size=3, pool_stride=2)
res1 = layer_warp(block_func, pool1, 64, stages[0], 1)
res2 = layer_warp(block_func, res1, 128, stages[1], 2)
res3 = layer_warp(block_func, res2, 256, stages[2], 2)
res4 = layer_warp(block_func, res3, 512, stages[3], 2)
pool2 = fluid.layers.pool2d(
input=res4,
pool_size=7,
pool_type='avg',
pool_stride=1,
global_pooling=True)
out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax')
return out
def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
assert (depth - 2) % 6 == 0
n = (depth - 2) // 6
conv1 = conv_bn_layer( return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, n, 1)
res2 = layer_warp(basicblock, res1, 32, n, 2)
res3 = layer_warp(basicblock, res2, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
out = fluid.layers.fc(input=pool, size=class_dim, act='softmax')
return out
def _model_reader_dshape_classdim(args, is_train): def _model_reader_dshape_classdim(args, is_train):
model = resnet_cifar10 model = None
reader = None reader = None
if args.data_set == "cifar10": if args.data_set == "flowers":
class_dim = 10
if args.data_format == 'NCHW':
dshape = [3, 32, 32]
else:
dshape = [32, 32, 3]
model = resnet_cifar10
if is_train:
reader = paddle.dataset.cifar.train10()
else:
reader = paddle.dataset.cifar.test10()
elif args.data_set == "flowers":
class_dim = 102 class_dim = 102
if args.data_format == 'NCHW': if args.data_format == 'NCHW':
dshape = [3, 224, 224] dshape = [3, 224, 224]
else: else:
dshape = [224, 224, 3] dshape = [224, 224, 3]
model = resnet_imagenet
if is_train: if is_train:
reader = paddle.dataset.flowers.train() reader = paddle.dataset.flowers.train()
else: else:
...@@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train): ...@@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train):
dshape = [3, 224, 224] dshape = [3, 224, 224]
else: else:
dshape = [224, 224, 3] dshape = [224, 224, 3]
model = resnet_imagenet
if not args.data_path: if not args.data_path:
raise Exception( raise Exception(
"Must specify --data_path when training with imagenet") "Must specify --data_path when training with imagenet")
...@@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train): ...@@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train):
reader = train(xmap=False) reader = train(xmap=False)
else: else:
reader = val(xmap=False) reader = val(xmap=False)
return model, reader, dshape, class_dim return reader, dshape, class_dim
def get_model(args, is_train, main_prog, startup_prog): def get_model(args, is_train, main_prog, startup_prog):
model, reader, dshape, class_dim = _model_reader_dshape_classdim(args, reader, dshape, class_dim = _model_reader_dshape_classdim(args, is_train)
is_train)
pyreader = None pyreader = None
trainer_count = int(os.getenv("PADDLE_TRAINERS")) trainer_count = int(os.getenv("PADDLE_TRAINERS"))
...@@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog):
label = fluid.layers.data( label = fluid.layers.data(
name='label', shape=[1], dtype='int64') name='label', shape=[1], dtype='int64')
predict = model(input, class_dim, is_train=is_train) model = ResNet(is_train=is_train)
predict = model.net(input, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog):
total_images = 1281167 / trainer_count total_images = 1281167 / trainer_count
step = int(total_images / args.batch_size + 1) step = int(total_images / args.batch_size + 1)
epochs = [30, 60, 80, 90] epochs = [30, 60, 90]
bd = [step * e for e in epochs] bd = [step * e for e in epochs]
base_lr = args.learning_rate base_lr = args.learning_rate
lr = [] lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=base_lr, learning_rate=fluid.layers.piecewise_decay(
#learning_rate=fluid.layers.piecewise_decay( boundaries=bd, values=lr),
# boundaries=bd, values=lr),
momentum=0.9, momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4)) regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
......
...@@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch, ...@@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch,
if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) { if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) {
buffer.Resize(sizeof(T) * data.size()); buffer.Resize(sizeof(T) * data.size());
} }
std::memcpy(buffer.data(), data.data(), buffer.length()); std::memcpy(buffer.data(), data.data(), sizeof(T) * data.size());
// copy LoD // copy LoD
for (const auto &level : fetch.lod()) { for (const auto &level : fetch.lod()) {
output->lod.emplace_back(level); output->lod.emplace_back(level);
......
...@@ -117,34 +117,6 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data, ...@@ -117,34 +117,6 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
input_slots->assign({input_tensor}); input_slots->assign({input_tensor});
} }
void BenchAllData(const std::string &model_path, const std::string &data_file,
const int batch_size, const int repeat) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
config.device = 0;
config.specify_input_name = true;
std::vector<PaddleTensor> input_slots, outputs_slots;
DataRecord data(data_file, batch_size);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
GetOneBatch(&input_slots, &data, batch_size);
for (int i = 0; i < FLAGS_burning; i++) {
predictor->Run(input_slots, &outputs_slots);
}
Timer timer;
double sum = 0;
for (int i = 0; i < repeat; i++) {
for (size_t bid = 0; bid < data.batched_datas.size(); ++bid) {
GetOneBatch(&input_slots, &data, batch_size);
timer.tic();
predictor->Run(input_slots, &outputs_slots);
sum += timer.toc();
}
}
PrintTime(batch_size, repeat, 1, 0, sum / repeat);
}
const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25,
25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43, 25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43,
44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39, 44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39,
......
...@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p, const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd, return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p", user_weights_memory_p, "@weights_mem_p",
pipeline); pipeline, is_persistent);
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
...@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
// TODO(pzelazko-intel) add support for group convolution and dilation // TODO(pzelazko-intel) add support for group convolution and dilation
...@@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd =
strides, paddings, mkldnn_engine); ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu);
} else { } else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine); paddings, mkldnn_engine, fuse_relu);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline); user_weights_memory_p, pipeline, is_test);
auto dst_memory_p = auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
...@@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
private: private:
mkldnn::primitive_attr AddRelu() const {
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
mkldnn::primitive_attr conv_attr;
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
mkldnn::post_ops post_operations;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc> std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine) const { const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
auto p_conv_pd = mkldnn::primitive_attr conv_attr;
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>( return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd); p_conv_pd);
...@@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst, const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine) const { const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims, bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
auto p_conv_pd = mkldnn::primitive_attr conv_attr;
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>( return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd); p_conv_pd);
......
...@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
void Conv2DOpMaker::Make() { void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", "").SetDefault(false);
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
...@@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() { ...@@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
......
...@@ -20,6 +20,7 @@ if(WITH_GRPC) ...@@ -20,6 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc)
return() return()
endif() endif()
......
...@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() { ...@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() {
} }
channels_.clear(); channels_.clear();
} }
client_thread_->join(); client_thread_->join();
} }
bool GRPCClient::AsyncSendVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, int64_t time_out) { const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
this] {
auto* var = p_scope->FindVar(var_name_val); auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
// varhandle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Send";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = nullptr; s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
}); });
req_count_++; req_count_++;
return true; return h;
} }
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) { const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
} }
template <typename T> template <typename T>
...@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { ...@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result->Swap(&tmp); result->Swap(&tmp);
} }
bool GRPCClient::AsyncGetVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, int64_t time_out) { const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
this] {
// prepare input // prepare input
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(var_name_val); req.set_varname(var_name_val);
::grpc::ByteBuffer buf; ::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf); RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
// var handle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Get";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -160,10 +145,10 @@ bool GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -160,10 +145,10 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
req_count_++; req_count_++;
return true; return h;
} }
bool GRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
...@@ -175,27 +160,21 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -175,27 +160,21 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
const std::string out_var_name_val = out_var_name; const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] { time_out, s, this] {
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
// var handle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = out_var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Prefetch";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
}); });
req_count_++; req_count_++;
return true; return h;
} }
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncCheckpointNotify(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir, const std::string& dir,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE); req.set_varname(CHECKPOINT_SAVE_MESSAGE);
...@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
bool GRPCClient::Wait() { bool GRPCClient::Wait() {
...@@ -276,25 +268,28 @@ void GRPCClient::Proceed() { ...@@ -276,25 +268,28 @@ void GRPCClient::Proceed() {
void* tag = nullptr; void* tag = nullptr;
bool ok = false; bool ok = false;
VLOG(3) << "GRPCClient Proceed begin";
while (!stopped_ && cq_.Next(&tag, &ok)) { while (!stopped_ && cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(c); PADDLE_ENFORCE(c);
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process"; VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process(); c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
LOG(ERROR) << c->var_h_.String() LOG(ERROR) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false; ok_ = false;
} }
sync_cond_.notify_all(); c->Finish(false);
} else { } else {
LOG(FATAL) << c->var_h_.String() LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
c->Finish(false);
} }
delete c; delete c;
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
...@@ -302,6 +297,7 @@ void GRPCClient::Proceed() { ...@@ -302,6 +297,7 @@ void GRPCClient::Proceed() {
} }
sync_cond_.notify_all(); sync_cond_.notify_all();
} }
VLOG(3) << "GRPCClient Proceed end";
} }
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
......
...@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); ...@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
public: public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { BaseProcessor() { context_ = nullptr; }
context_ = nullptr;
}
virtual ~BaseProcessor() {} virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) { virtual void Prepare(VarHandlePtr h, int64_t time_out) {
var_h_ = h;
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
var_h_ = var_info;
context_->set_wait_for_ready(true); context_->set_wait_for_ready(true);
if (time_out) { if (time_out) {
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::time_point deadline =
...@@ -71,21 +70,21 @@ class BaseProcessor { ...@@ -71,21 +70,21 @@ class BaseProcessor {
} }
} }
virtual void Prepare(int64_t time_out) { void Process() {
context_.reset(new grpc::ClientContext()); ProcessImpl();
context_->set_wait_for_ready(true); var_h_->Finish(true);
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
} }
virtual void Process() = 0; VarHandlePtr GetVarHandlePtr() { return var_h_; }
bool Wait() { return var_h_->Wait(); }
void Finish(bool ok) { return var_h_->Finish(ok); }
virtual void ProcessImpl() = 0;
std::unique_ptr<grpc::ClientContext> context_; std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_; grpc::Status status_;
VarHandle var_h_;
protected:
VarHandlePtr var_h_;
}; };
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
...@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> ...@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class SendProcessor : public BaseProcessor { class SendProcessor : public BaseProcessor {
public: public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {} : BaseProcessor(), stub_g_(ch) {}
virtual ~SendProcessor() {} virtual ~SendProcessor() {}
virtual void Process() { void ProcessImpl() override {
if (response_call_back_) { if (response_call_back_) {
response_call_back_(var_h_, reply_); response_call_back_(*var_h_.get(), reply_);
} }
} }
...@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> ...@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class GetProcessor : public BaseProcessor { class GetProcessor : public BaseProcessor {
public: public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {} : BaseProcessor(), stub_g_(ch) {}
virtual ~GetProcessor() {} virtual ~GetProcessor() {}
virtual void Process() { void ProcessImpl() override {
if (response_call_back_) { if (response_call_back_) {
response_call_back_(var_h_, reply_); response_call_back_(*var_h_.get(), reply_);
} }
} }
...@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor { ...@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor {
class BatchBarrierProcessor : public BaseProcessor { class BatchBarrierProcessor : public BaseProcessor {
public: public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~BatchBarrierProcessor() {} virtual ~BatchBarrierProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor { ...@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor {
class FetchBarrierProcessor : public BaseProcessor { class FetchBarrierProcessor : public BaseProcessor {
public: public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~FetchBarrierProcessor() {} virtual ~FetchBarrierProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VariableMessage reply_; sendrecv::VariableMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor {
class CheckpointNotifyProcessor : public BaseProcessor { class CheckpointNotifyProcessor : public BaseProcessor {
public: public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch) explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~CheckpointNotifyProcessor() {} virtual ~CheckpointNotifyProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient { ...@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient {
GRPCClient() : ok_(true), completed_(false), stopped_(false) {} GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
virtual ~GRPCClient(); virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, VarHandlePtr AsyncSendVar(const std::string& ep,
const framework::Scope& scope, const std::string& var_name, const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, VarHandlePtr AsyncGetVar(const std::string& ep,
const framework::Scope& scope, const std::string& var_name, const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep, VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr AsyncSendBatchBarrier(
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep, VarHandlePtr AsyncSendFetchBarrier(
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendComplete(const std::string& ep, VarHandlePtr AsyncSendComplete(
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override; bool Wait() override;
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; ...@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class RPCServer; class RPCServer;
struct VarHandle { class VarHandle {
// RPC endpoint. public:
std::string ep; VarHandle(const std::string ep, const std::string& method,
const platform::DeviceContext* ctx; const std::string& name,
const framework::Scope* scope; const platform::DeviceContext* p_ctx = nullptr,
// Variable name. const framework::Scope* p_scope = nullptr)
std::string name; : ok_(kVarHandleDefaultState) {
// RPC method name. ep_ = ep;
std::string method; ctx_ = p_ctx;
scope_ = p_scope;
name_ = name;
method_ = method;
}
virtual ~VarHandle() {}
public:
bool Wait() {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; });
}
VLOG(7) << "VarHandle wait:" << ok_;
return ok_ != 0;
}
void Finish(bool ok) {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
ok_ = ok;
}
VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all();
}
std::string String() const { std::string String() const {
std::ostringstream s; std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]"; s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_
<< "]";
return s.str(); return s.str();
} }
std::string ep() const { return ep_; }
const platform::DeviceContext* ctx() const { return ctx_; }
const framework::Scope* scope() const { return scope_; }
std::string name() const { return name_; }
std::string method() const { return method_; }
protected:
// RPC endpoint.
std::string ep_;
const platform::DeviceContext* ctx_;
const framework::Scope* scope_;
// Variable name.
std::string name_;
// RPC method name.
std::string method_;
protected:
std::mutex sync_mutex_;
std::condition_variable wait_cond_;
int ok_;
static const int kVarHandleDefaultState = -1;
private:
DISABLE_COPY_AND_ASSIGN(VarHandle);
}; };
typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler { class RequestHandler {
public: public:
explicit RequestHandler(bool sync_mode) explicit RequestHandler(bool sync_mode)
......
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
#pragma once #pragma once
#include <condition_variable> // NOLINT
#include <string> #include <string>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline); DECLARE_int32(rpc_deadline);
...@@ -31,37 +33,36 @@ class RPCClient { ...@@ -31,37 +33,36 @@ class RPCClient {
public: public:
RPCClient() {} RPCClient() {}
virtual ~RPCClient() {} virtual ~RPCClient() {}
virtual bool AsyncSendVar(const std::string& ep, virtual VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep, virtual VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep, virtual VarHandlePtr AsyncPrefetchVar(
const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope, const std::string& in_var_name,
const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep, virtual VarHandlePtr AsyncSendBatchBarrier(
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep, virtual VarHandlePtr AsyncSendFetchBarrier(
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncCheckpointNotify(const std::string& ep, virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& dir, const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendComplete(const std::string& ep, virtual VarHandlePtr AsyncSendComplete(
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
// Complete tells all the pserver instances that finishe the training, // Complete tells all the pserver instances that finishe the training,
// the pserver can reduce it's barrier count, and continue to train // the pserver can reduce it's barrier count, and continue to train
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
using paddle::operators::distributed::VarHandlePtr;
using paddle::operators::distributed::VarHandle;
void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); }
void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); }
TEST(VarHandle, Run) {
std::vector<VarHandlePtr> a;
for (int i = 0; i < 12; i++) {
VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr));
a.push_back(s);
}
std::vector<std::unique_ptr<std::thread>> t;
for (int i = 0; i < 6; i++) {
t.emplace_back(new std::thread(WaitFalse, a[i]));
}
for (int i = 0; i < 6; i++) {
a[i]->Finish(false);
t[i]->join();
}
for (int i = 6; i < 12; i++) {
t.emplace_back(new std::thread(WaitTrue, a[i]));
}
for (int i = 6; i < 12; i++) {
a[i]->Finish(true);
t[i]->join();
}
}
...@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back"; << outs[i] << " back";
rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope,
ins[i], outs[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
}; };
......
...@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase { ...@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
} }
if (sync_mode) { if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <future> // NOLINT #include <future> // NOLINT
#include <ostream> #include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase { ...@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
if (sync_send) { if (sync_send) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
......
...@@ -192,7 +192,8 @@ class MKLDNNHandler { ...@@ -192,7 +192,8 @@ class MKLDNNHandler {
mkldnn::memory::primitive_desc& user_mpd, // NOLINT mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix, const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p"; auto key_reorder_p = key_ + suffix + "reorder_p";
...@@ -213,7 +214,7 @@ class MKLDNNHandler { ...@@ -213,7 +214,7 @@ class MKLDNNHandler {
pipeline.push_back(*reorder_p); pipeline.push_back(*reorder_p);
} }
dev_ctx_.SetBlob(local_key, target_memory_p); dev_ctx_.SetBlob(local_key, target_memory_p);
} else { } else if (!is_persistent) {
// Make reorder if needed // Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p)); dev_ctx_.GetBlob(key_reorder_p));
......
...@@ -128,6 +128,13 @@ class ParallelExecutor(object): ...@@ -128,6 +128,13 @@ class ParallelExecutor(object):
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exec_strategy.num_threads = cpu_num * 2 exec_strategy.num_threads = cpu_num * 2
# Set 1 thread num under nccl2 distribute
# env to make sure all gpus run ops in same order.
if num_trainers > 1:
assert (use_cuda)
# FIXME(gongwb): avoid this set.
exec_strategy.num_threads = 1
if build_strategy is None: if build_strategy is None:
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
......
...@@ -60,12 +60,46 @@ class InferenceTranspiler(object): ...@@ -60,12 +60,46 @@ class InferenceTranspiler(object):
if not isinstance(scope, core.Scope): if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None") raise TypeError("scope should be as Scope type or None")
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False)) use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
self._fuse_batch_norm(program, place, scope) self._fuse_batch_norm(program, place, scope)
if use_mkldnn: if use_mkldnn:
self._fuse_relu_mkldnn(program)
self._fuse_conv_bias_mkldnn(program) self._fuse_conv_bias_mkldnn(program)
self._fuse_conv_relu_mkldnn(program)
self._fuse_bn_relu_mkldnn(program)
def _fuse_conv_relu_mkldnn(self, program):
'''
Transpile the program by fused relu activation for MKLDNN program.
Relu activation following convolution OP can be fused by adding
'fuse_relu' attribute to convolution OP.
The result of fuse is:
- before:
- conv->relu->any_other_op
- after:
- conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'relu':
# modify conv OP to include relu
current_op.set_attr("fuse_relu", True)
# remove conv OP
self.block._remove_op(i + 1)
i = i + 1
def _fuse_relu_mkldnn(self, program): # TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_bn_relu_mkldnn(self, program):
''' '''
Transpile the program by fused relu activation for MKLDNN program. Transpile the program by fused relu activation for MKLDNN program.
...@@ -160,7 +194,6 @@ class InferenceTranspiler(object): ...@@ -160,7 +194,6 @@ class InferenceTranspiler(object):
self.block._remove_op(i + 1) # Remove old conv self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1 i = i + 1
i = i + 1
self._remove_unused_var() self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force, # TODO(luotao): use clone() method to flush the program.desc in force,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册