diff --git a/benchmark/fluid/models/resnet.py b/benchmark/fluid/models/resnet.py index ae1baa48e17e40448e457052fd1464b9604a2128..d71b855612ae32083b2b2e3448db3749c340633b 100644 --- a/benchmark/fluid/models/resnet.py +++ b/benchmark/fluid/models/resnet.py @@ -20,6 +20,7 @@ import functools import numpy as np import time import os +import math import cProfile, pstats, StringIO @@ -27,128 +28,120 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.profiler as profiler -# from recordio_converter import imagenet_train, imagenet_test from imagenet_reader import train, val +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class ResNet(): + def __init__(self, layers=50, is_train=True): + self.params = train_parameters + self.layers = layers + self.is_train = is_train + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [64, 128, 256, 512] + + conv = self.conv_bn_layer( + input=input, num_filters=64, filter_size=7, stride=2, act='relu') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, + size=class_dim, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, + stdv))) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + return fluid.layers.batch_norm( + input=conv, act=act, is_test=not self.is_train) + + def shortcut(self, input, ch_out, stride): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return self.conv_bn_layer(input, ch_out, 1, stride) + else: + return input -def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='relu', - is_train=True): - conv1 = fluid.layers.conv2d( - input=input, - filter_size=filter_size, - num_filters=ch_out, - stride=stride, - padding=padding, - act=None, - bias_attr=False) - return fluid.layers.batch_norm(input=conv1, act=act, is_test=not is_train) - - -def shortcut(input, ch_out, stride, is_train=True): - ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1] - if ch_in != ch_out: - return conv_bn_layer( - input, ch_out, 1, stride, 0, None, is_train=is_train) - else: - return input - - -def basicblock(input, ch_out, stride, is_train=True): - short = shortcut(input, ch_out, stride, is_train=is_train) - conv1 = conv_bn_layer(input, ch_out, 3, stride, 1, is_train=is_train) - conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None, is_train=is_train) - return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') - - -def bottleneck(input, ch_out, stride, is_train=True): - short = shortcut(input, ch_out * 4, stride, is_train=is_train) - conv1 = conv_bn_layer(input, ch_out, 1, stride, 0, is_train=is_train) - conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, is_train=is_train) - conv3 = conv_bn_layer( - conv2, ch_out * 4, 1, 1, 0, act=None, is_train=is_train) - return fluid.layers.elementwise_add(x=short, y=conv3, act='relu') - - -def layer_warp(block_func, input, ch_out, count, stride): - res_out = block_func(input, ch_out, stride) - for i in range(1, count): - res_out = block_func(res_out, ch_out, 1) - return res_out - + def bottleneck_block(self, input, num_filters, stride): + conv0 = self.conv_bn_layer( + input=input, num_filters=num_filters, filter_size=1, act='relu') + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu') + conv2 = self.conv_bn_layer( + input=conv1, num_filters=num_filters * 4, filter_size=1, act=None) -def resnet_imagenet(input, - class_dim, - depth=50, - data_format='NCHW', - is_train=True): + short = self.shortcut(input, num_filters * 4, stride) - cfg = { - 18: ([2, 2, 2, 1], basicblock), - 34: ([3, 4, 6, 3], basicblock), - 50: ([3, 4, 6, 3], bottleneck), - 101: ([3, 4, 23, 3], bottleneck), - 152: ([3, 8, 36, 3], bottleneck) - } - stages, block_func = cfg[depth] - conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3) - pool1 = fluid.layers.pool2d( - input=conv1, pool_type='avg', pool_size=3, pool_stride=2) - res1 = layer_warp(block_func, pool1, 64, stages[0], 1) - res2 = layer_warp(block_func, res1, 128, stages[1], 2) - res3 = layer_warp(block_func, res2, 256, stages[2], 2) - res4 = layer_warp(block_func, res3, 512, stages[3], 2) - pool2 = fluid.layers.pool2d( - input=res4, - pool_size=7, - pool_type='avg', - pool_stride=1, - global_pooling=True) - out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax') - return out - - -def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'): - assert (depth - 2) % 6 == 0 - - n = (depth - 2) // 6 - - conv1 = conv_bn_layer( - input=input, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, n, 1) - res2 = layer_warp(basicblock, res1, 32, n, 2) - res3 = layer_warp(basicblock, res2, 64, n, 2) - pool = fluid.layers.pool2d( - input=res3, pool_size=8, pool_type='avg', pool_stride=1) - out = fluid.layers.fc(input=pool, size=class_dim, act='softmax') - return out + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') def _model_reader_dshape_classdim(args, is_train): - model = resnet_cifar10 + model = None reader = None - if args.data_set == "cifar10": - class_dim = 10 - if args.data_format == 'NCHW': - dshape = [3, 32, 32] - else: - dshape = [32, 32, 3] - model = resnet_cifar10 - if is_train: - reader = paddle.dataset.cifar.train10() - else: - reader = paddle.dataset.cifar.test10() - elif args.data_set == "flowers": + if args.data_set == "flowers": class_dim = 102 if args.data_format == 'NCHW': dshape = [3, 224, 224] else: dshape = [224, 224, 3] - model = resnet_imagenet if is_train: reader = paddle.dataset.flowers.train() else: @@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train): dshape = [3, 224, 224] else: dshape = [224, 224, 3] - model = resnet_imagenet if not args.data_path: raise Exception( "Must specify --data_path when training with imagenet") @@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train): reader = train(xmap=False) else: reader = val(xmap=False) - return model, reader, dshape, class_dim + return reader, dshape, class_dim def get_model(args, is_train, main_prog, startup_prog): - model, reader, dshape, class_dim = _model_reader_dshape_classdim(args, - is_train) + reader, dshape, class_dim = _model_reader_dshape_classdim(args, is_train) pyreader = None trainer_count = int(os.getenv("PADDLE_TRAINERS")) @@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog): label = fluid.layers.data( name='label', shape=[1], dtype='int64') - predict = model(input, class_dim, is_train=is_train) + model = ResNet(is_train=is_train) + predict = model.net(input, class_dim=class_dim) cost = fluid.layers.cross_entropy(input=predict, label=label) avg_cost = fluid.layers.mean(x=cost) @@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog): total_images = 1281167 / trainer_count step = int(total_images / args.batch_size + 1) - epochs = [30, 60, 80, 90] + epochs = [30, 60, 90] bd = [step * e for e in epochs] base_lr = args.learning_rate lr = [] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] optimizer = fluid.optimizer.Momentum( - learning_rate=base_lr, - #learning_rate=fluid.layers.piecewise_decay( - # boundaries=bd, values=lr), + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), momentum=0.9, regularization=fluid.regularizer.L2Decay(1e-4)) optimizer.minimize(avg_cost) diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index bd9b4b1a814f995e3979105f5b9830b95fd8ea7d..6fe13ed027de403bdc21882c26225bcd4cc7e49a 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch, if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) { buffer.Resize(sizeof(T) * data.size()); } - std::memcpy(buffer.data(), data.data(), buffer.length()); + std::memcpy(buffer.data(), data.data(), sizeof(T) * data.size()); // copy LoD for (const auto &level : fetch.lod()) { output->lod.emplace_back(level); diff --git a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc index 522d870db8583aac4006e8cdb7909625c3feb34b..7e00cb20ad0ce052a84d5491b0cdf167f0768081 100644 --- a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc @@ -117,34 +117,6 @@ void GetOneBatch(std::vector *input_slots, DataRecord *data, input_slots->assign({input_tensor}); } -void BenchAllData(const std::string &model_path, const std::string &data_file, - const int batch_size, const int repeat) { - NativeConfig config; - config.model_dir = model_path; - config.use_gpu = false; - config.device = 0; - config.specify_input_name = true; - std::vector input_slots, outputs_slots; - DataRecord data(data_file, batch_size); - auto predictor = - CreatePaddlePredictor(config); - GetOneBatch(&input_slots, &data, batch_size); - for (int i = 0; i < FLAGS_burning; i++) { - predictor->Run(input_slots, &outputs_slots); - } - Timer timer; - double sum = 0; - for (int i = 0; i < repeat; i++) { - for (size_t bid = 0; bid < data.batched_datas.size(); ++bid) { - GetOneBatch(&input_slots, &data, batch_size); - timer.tic(); - predictor->Run(input_slots, &outputs_slots); - sum += timer.toc(); - } - } - PrintTime(batch_size, repeat, 1, 0, sum / repeat); -} - const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, 25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43, 44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39, diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index c5cbadc892904dc064b49ebc461944c4671a69da..3eb02c6b61ce61140bd777647a12477dd9c3c803 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr AcquireWeightsMemoryFromPrimitive( const std::shared_ptr user_weights_memory_p, - std::vector& pipeline) { // NOLINT + std::vector& pipeline, // NOLINT + bool is_persistent = false) { auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc(); return this->AcquireMemory(weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p", - pipeline); + pipeline, is_persistent); } std::shared_ptr AcquireBiasMemoryFromPrimitive( @@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); + const bool is_test = ctx.Attr("is_test"); + auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); @@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); + bool fuse_relu = ctx.Attr("fuse_relu"); int groups = ctx.Attr("groups"); // TODO(pzelazko-intel) add support for group convolution and dilation @@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_md = platform::MKLDNNMemDesc( bias_tz, platform::MKLDNNGetDataType(), memory::format::x); - conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, - strides, paddings, mkldnn_engine); + conv_pd = + ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, + paddings, mkldnn_engine, fuse_relu); } else { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, - paddings, mkldnn_engine); + paddings, mkldnn_engine, fuse_relu); } // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); @@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_memory_p = handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( - user_weights_memory_p, pipeline); + user_weights_memory_p, pipeline, is_test); auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); @@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } private: + mkldnn::primitive_attr AddRelu() const { + // Fusion with ReLU layer is executed through the PostOps feature. Create a + // PostOps object and configure it to execute an eltwise relu operation. + mkldnn::primitive_attr conv_attr; + constexpr float scale = 1.0f; + constexpr float negative_slope = 0.0f; + constexpr float placeholder = 0.0f; + mkldnn::post_ops post_operations; + post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, + negative_slope, placeholder); + conv_attr.set_post_ops(post_operations); + return conv_attr; + } + std::unique_ptr ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine) const { + const mkldnn::engine& engine, + const bool fuse_relu) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; @@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - auto p_conv_pd = - new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); + mkldnn::primitive_attr conv_attr; + if (fuse_relu) { + conv_attr = AddRelu(); + } + + auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( + conv_desc, conv_attr, engine); return std::unique_ptr( p_conv_pd); @@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const memory::desc& bias, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine) const { + const mkldnn::engine& engine, + const bool fuse_relu) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; @@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - auto p_conv_pd = - new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); + mkldnn::primitive_attr conv_attr; + if (fuse_relu) { + conv_attr = AddRelu(); + } + + auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( + conv_desc, conv_attr, engine); return std::unique_ptr( p_conv_pd); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 61ca80877a6dfcdf30a0ff346342116e36eec6f2..41d4fcf6de7c8fcb3cfbb2063b0a2ac1a2356168 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } void Conv2DOpMaker::Make() { + AddAttr("is_test", "").SetDefault(false); AddInput( "Input", "(Tensor) The input tensor of convolution operator. " @@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddAttr( "data_format", "(string, default NCHW) Only used in " diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index da5d20505e9b06c0717af8d79d5456a9ade1e89c..56734b81e8716a0c0c37a11e35c9118ee7b55020 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -20,6 +20,7 @@ if(WITH_GRPC) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) + cc_test(varhandle_test SRCS varhandle_test.cc) return() endif() diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index b4f60c9ff9a41d5cb7dbe4e7a7694a84bab8e940..07ac20797ddab54296a45e99915588a40cc6f3c7 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() { } channels_.clear(); } - client_thread_->join(); } -bool GRPCClient::AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); + SendProcessor* s = new SendProcessor(ch); + VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, - this] { + framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); - // varhandle - VarHandle var_h; - var_h.ep = ep_val; - var_h.scope = p_scope; - var_h.name = var_name_val; - var_h.ctx = p_ctx; - var_h.method = "Send"; - - VLOG(3) << var_h.String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; // stub context - SendProcessor* s = new SendProcessor(ch); - s->Prepare(var_h, time_out); s->response_call_back_ = nullptr; auto call = s->stub_g_.PrepareUnaryCall( @@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep, }); req_count_++; - return true; + return h; } void ProcGetResponse(const VarHandle& var_h, const ::grpc::ByteBuffer& ret_msg) { framework::Variable* outvar = nullptr; - DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); + DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar); } template @@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { result->Swap(&tmp); } -bool GRPCClient::AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, int64_t time_out) { +VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); + GetProcessor* s = new GetProcessor(ch); + VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, - this] { + framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); ::grpc::ByteBuffer buf; RequestToByteBuffer(req, &buf); - // var handle - VarHandle var_h; - var_h.ep = ep_val; - var_h.scope = p_scope; - var_h.name = var_name_val; - var_h.ctx = p_ctx; - var_h.method = "Get"; - - VLOG(3) << var_h.String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; // stub context - GetProcessor* s = new GetProcessor(ch); - s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; auto call = s->stub_g_.PrepareUnaryCall( @@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep, req_count_++; - return true; + return h; } -bool GRPCClient::AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); + GetProcessor* s = new GetProcessor(ch); + VarHandlePtr h( + new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + time_out, s, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); - // var handle - VarHandle var_h; - var_h.ep = ep_val; - var_h.scope = p_scope; - var_h.name = out_var_name_val; - var_h.ctx = p_ctx; - var_h.method = "Prefetch"; - - VLOG(3) << var_h.String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; // stub context - GetProcessor* s = new GetProcessor(ch); - s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; auto call = s->stub_g_.PrepareUnaryCall( @@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, }); req_count_++; - return true; + return h; } -void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE, + nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(BATCH_BARRIER_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } -void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE, + nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(FETCH_BARRIER_MESSAGE); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } -void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h( + new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(COMPLETE_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } -void GRPCClient::AsyncCheckpointNotify(const std::string& ep, - const std::string& dir, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, + const std::string& dir, + int64_t time_out) { const auto ch = GetChannel(ep); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE, + nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(CHECKPOINT_SAVE_MESSAGE); @@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } bool GRPCClient::Wait() { @@ -276,25 +268,28 @@ void GRPCClient::Proceed() { void* tag = nullptr; bool ok = false; + VLOG(3) << "GRPCClient Proceed begin"; while (!stopped_ && cq_.Next(&tag, &ok)) { BaseProcessor* c = static_cast(tag); GPR_ASSERT(ok); PADDLE_ENFORCE(c); if (c->status_.ok()) { - VLOG(3) << c->var_h_.String() << " process"; + VLOG(3) << c->GetVarHandlePtr()->String() << " process"; c->Process(); } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { - LOG(ERROR) << c->var_h_.String() + LOG(ERROR) << c->GetVarHandlePtr()->String() << " meets grpc error:" << c->status_.error_message(); { std::lock_guard lk(sync_mutex_); ok_ = false; } - sync_cond_.notify_all(); + c->Finish(false); } else { - LOG(FATAL) << c->var_h_.String() + LOG(FATAL) << c->GetVarHandlePtr()->String() << " meets grpc error:" << c->status_.error_message(); + c->Finish(false); } + delete c; { std::lock_guard lk(sync_mutex_); @@ -302,6 +297,7 @@ void GRPCClient::Proceed() { } sync_cond_.notify_all(); } + VLOG(3) << "GRPCClient Proceed end"; } std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 0c95ffeb5ce7e1586c5968fb122acd12c0c0196e..75a3662316462a222760bfbb7d7906c70f46d143 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { - context_ = nullptr; - } + BaseProcessor() { context_ = nullptr; } virtual ~BaseProcessor() {} - virtual void Prepare(const VarHandle& var_info, int64_t time_out) { + virtual void Prepare(VarHandlePtr h, int64_t time_out) { + var_h_ = h; + context_.reset(new grpc::ClientContext()); - var_h_ = var_info; context_->set_wait_for_ready(true); if (time_out) { std::chrono::system_clock::time_point deadline = @@ -71,21 +70,21 @@ class BaseProcessor { } } - virtual void Prepare(int64_t time_out) { - context_.reset(new grpc::ClientContext()); - context_->set_wait_for_ready(true); - - std::chrono::system_clock::time_point deadline = - std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); - - context_->set_deadline(deadline); + void Process() { + ProcessImpl(); + var_h_->Finish(true); } - virtual void Process() = 0; + VarHandlePtr GetVarHandlePtr() { return var_h_; } + bool Wait() { return var_h_->Wait(); } + void Finish(bool ok) { return var_h_->Finish(ok); } + virtual void ProcessImpl() = 0; std::unique_ptr context_; grpc::Status status_; - VarHandle var_h_; + + protected: + VarHandlePtr var_h_; }; typedef std::function @@ -94,13 +93,13 @@ typedef std::function class SendProcessor : public BaseProcessor { public: explicit SendProcessor(std::shared_ptr ch) - : BaseProcessor(ch), stub_g_(ch) {} + : BaseProcessor(), stub_g_(ch) {} virtual ~SendProcessor() {} - virtual void Process() { + void ProcessImpl() override { if (response_call_back_) { - response_call_back_(var_h_, reply_); + response_call_back_(*var_h_.get(), reply_); } } @@ -115,13 +114,13 @@ typedef std::function class GetProcessor : public BaseProcessor { public: explicit GetProcessor(std::shared_ptr ch) - : BaseProcessor(ch), stub_g_(ch) {} + : BaseProcessor(), stub_g_(ch) {} virtual ~GetProcessor() {} - virtual void Process() { + void ProcessImpl() override { if (response_call_back_) { - response_call_back_(var_h_, reply_); + response_call_back_(*var_h_.get(), reply_); } } @@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor { class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) { + : BaseProcessor() { stub_ = sendrecv::SendRecvService::NewStub(ch); } virtual ~BatchBarrierProcessor() {} - virtual void Process() {} + void ProcessImpl() override {} sendrecv::VoidMessage reply_; std::unique_ptr stub_; }; @@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor { class FetchBarrierProcessor : public BaseProcessor { public: explicit FetchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) { + : BaseProcessor() { stub_ = sendrecv::SendRecvService::NewStub(ch); } virtual ~FetchBarrierProcessor() {} - virtual void Process() {} + void ProcessImpl() override {} sendrecv::VariableMessage reply_; std::unique_ptr stub_; }; @@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor { class CheckpointNotifyProcessor : public BaseProcessor { public: explicit CheckpointNotifyProcessor(std::shared_ptr ch) - : BaseProcessor(ch) { + : BaseProcessor() { stub_ = sendrecv::SendRecvService::NewStub(ch); } virtual ~CheckpointNotifyProcessor() {} - virtual void Process() {} + void ProcessImpl() override {} sendrecv::VoidMessage reply_; std::unique_ptr stub_; }; @@ -177,32 +176,37 @@ class GRPCClient : public RPCClient { GRPCClient() : ok_(true), completed_(false), stopped_(false) {} virtual ~GRPCClient(); - bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - bool AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendComplete(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncSendBatchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncSendFetchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncSendComplete( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 64ac7281848f91302bc0aa3cb81dd198e56fb653..3c3f9d17c871ac1cb4df83db17cf489d5b9e0563 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -28,6 +28,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/macros.h" namespace paddle { namespace operators { @@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; class RPCServer; -struct VarHandle { - // RPC endpoint. - std::string ep; - const platform::DeviceContext* ctx; - const framework::Scope* scope; - // Variable name. - std::string name; - // RPC method name. - std::string method; +class VarHandle { + public: + VarHandle(const std::string ep, const std::string& method, + const std::string& name, + const platform::DeviceContext* p_ctx = nullptr, + const framework::Scope* p_scope = nullptr) + : ok_(kVarHandleDefaultState) { + ep_ = ep; + ctx_ = p_ctx; + scope_ = p_scope; + name_ = name; + method_ = method; + } + + virtual ~VarHandle() {} + + public: + bool Wait() { + { + std::unique_lock lk(sync_mutex_); + wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; }); + } + VLOG(7) << "VarHandle wait:" << ok_; + return ok_ != 0; + } + + void Finish(bool ok) { + { + std::unique_lock lk(sync_mutex_); + ok_ = ok; + } + VLOG(7) << "VarHandle finish:" << ok; + wait_cond_.notify_all(); + } std::string String() const { std::ostringstream s; - s << method << " name:[" << name << "], ep:[" << ep << "]"; + s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_ + << "]"; return s.str(); } + + std::string ep() const { return ep_; } + const platform::DeviceContext* ctx() const { return ctx_; } + const framework::Scope* scope() const { return scope_; } + std::string name() const { return name_; } + std::string method() const { return method_; } + + protected: + // RPC endpoint. + std::string ep_; + const platform::DeviceContext* ctx_; + const framework::Scope* scope_; + // Variable name. + std::string name_; + // RPC method name. + std::string method_; + + protected: + std::mutex sync_mutex_; + std::condition_variable wait_cond_; + int ok_; + + static const int kVarHandleDefaultState = -1; + + private: + DISABLE_COPY_AND_ASSIGN(VarHandle); }; +typedef std::shared_ptr VarHandlePtr; + class RequestHandler { public: explicit RequestHandler(bool sync_mode) diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 22a022a5d25e5c6628b80294494b87ca105a04c7..3539ee5e459d6dfe0b6510806464bcc6817910bb 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -14,12 +14,14 @@ #pragma once +#include // NOLINT #include #include "gflags/gflags.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/operators/distributed/request_handler.h" DECLARE_int32(rpc_deadline); @@ -31,37 +33,36 @@ class RPCClient { public: RPCClient() {} virtual ~RPCClient() {} - virtual bool AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual bool AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual bool AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncCheckpointNotify(const std::string& ep, - const std::string& dir, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncSendComplete(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncPrefetchVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncSendBatchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncSendFetchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncSendComplete( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; // Complete tells all the pserver instances that finishe the training, // the pserver can reduce it's barrier count, and continue to train diff --git a/paddle/fluid/operators/distributed/varhandle_test.cc b/paddle/fluid/operators/distributed/varhandle_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a0fcaf886475c5e03d959ffd6af22b2123526b9f --- /dev/null +++ b/paddle/fluid/operators/distributed/varhandle_test.cc @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/operators/distributed/request_handler.h" + +using paddle::operators::distributed::VarHandlePtr; +using paddle::operators::distributed::VarHandle; + +void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); } + +void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); } + +TEST(VarHandle, Run) { + std::vector a; + for (int i = 0; i < 12; i++) { + VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr)); + a.push_back(s); + } + + std::vector> t; + for (int i = 0; i < 6; i++) { + t.emplace_back(new std::thread(WaitFalse, a[i])); + } + + for (int i = 0; i < 6; i++) { + a[i]->Finish(false); + t[i]->join(); + } + + for (int i = 6; i < 12; i++) { + t.emplace_back(new std::thread(WaitTrue, a[i])); + } + + for (int i = 6; i < 12; i++) { + a[i]->Finish(true); + t[i]->join(); + } +} diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 4b804740a06f9e29704f2b3f58a90191e3559347..0519c15e13aac99802ff0f95b975712b36b44246 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); + std::vector rets; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " << outs[i] << " back"; - rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); + rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, + ins[i], outs[i])); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } - PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } } }; diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index a1f368e8690512cec2db7593aabc0279bbe174eb..4d34b8a1686efb1fc30020f0d27e9a3c3a6c0866 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); + std::vector rets; for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); + rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i])); } if (sync_mode) { - PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } } } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 82a70e4bf13247d784371ffdf419c9f792d7f721..48322ac7fd54a2e4cc3405a2c4dcddfc273f5a66 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include // NOLINT #include +#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); + std::vector rets; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - // TODO(Yancey1989): we need to use an IO threadpool which has - // a larger number of threads than the computing threadpool. - rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); + rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } if (sync_send) { - PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } } } }; diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index f6e9a52b275353c03c1f350719766922a97f6cb3..c0a2543ba5d8ff8f34cb6231c51cb5053a6a9481 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -192,7 +192,8 @@ class MKLDNNHandler { mkldnn::memory::primitive_desc& user_mpd, // NOLINT const std::shared_ptr user_memory_p, const std::string& suffix, - std::vector& pipeline) { // NOLINT + std::vector& pipeline, // NOLINT + bool is_persistent = false) { // create reorder primitive if the input format is not the preferred one auto local_key = key_ + suffix; auto key_reorder_p = key_ + suffix + "reorder_p"; @@ -213,7 +214,7 @@ class MKLDNNHandler { pipeline.push_back(*reorder_p); } dev_ctx_.SetBlob(local_key, target_memory_p); - } else { + } else if (!is_persistent) { // Make reorder if needed auto reorder_p = std::static_pointer_cast( dev_ctx_.GetBlob(key_reorder_p)); diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 058f414e9b7130bc6ff68c4b9d7d65ca7db300cc..a6395d9e207cbb2a030f7a1f76240736a4d66998 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -128,6 +128,13 @@ class ParallelExecutor(object): os.environ.get('CPU_NUM', multiprocessing.cpu_count())) exec_strategy.num_threads = cpu_num * 2 + # Set 1 thread num under nccl2 distribute + # env to make sure all gpus run ops in same order. + if num_trainers > 1: + assert (use_cuda) + # FIXME(gongwb): avoid this set. + exec_strategy.num_threads = 1 + if build_strategy is None: build_strategy = BuildStrategy() diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 02fefe32df14dfca39de1da24147c284d6d82230..adad2428f7fdc554cf4efd652f52b5c5de0ab527 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -60,12 +60,46 @@ class InferenceTranspiler(object): if not isinstance(scope, core.Scope): raise TypeError("scope should be as Scope type or None") use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False)) + self._fuse_batch_norm(program, place, scope) if use_mkldnn: - self._fuse_relu_mkldnn(program) self._fuse_conv_bias_mkldnn(program) + self._fuse_conv_relu_mkldnn(program) + self._fuse_bn_relu_mkldnn(program) + + def _fuse_conv_relu_mkldnn(self, program): + ''' + Transpile the program by fused relu activation for MKLDNN program. + Relu activation following convolution OP can be fused by adding + 'fuse_relu' attribute to convolution OP. + The result of fuse is: + - before: + - conv->relu->any_other_op + - after: + - conv->any_other_op + :param program: program to transpile + :type program: Program + ''' + self.block = program.block(0) + + i = 0 + while i < len(self.block.ops): + current_op = self.block.ops[i] + if current_op.type in ['conv2d']: + next_op = self.block.ops[i + 1] + if next_op.type == 'relu': + # modify conv OP to include relu + current_op.set_attr("fuse_relu", True) + # remove conv OP + self.block._remove_op(i + 1) + i = i + 1 + + # TODO(luotao): use clone() method to flush the program.desc in force, + # since some large program.desc will not be flushed immediately. + # And a better solution will be considered later. + program = program.clone() - def _fuse_relu_mkldnn(self, program): + def _fuse_bn_relu_mkldnn(self, program): ''' Transpile the program by fused relu activation for MKLDNN program. @@ -159,7 +193,6 @@ class InferenceTranspiler(object): self._fuse_conv_bias(i, current_op, next_op) self.block._remove_op(i + 1) # Remove old conv self.block._remove_op(i + 1) # Remove elementwise_add - i = i + 1 i = i + 1 self._remove_unused_var()