diff --git a/README.md b/README.md index 60ffbe728178705b1734e682868614025214c2a4..45186ec4ef48dc305b2616dbf4966f01c3609962 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. -### Latest PaddlePaddle Release: [Fluid 0.14.0](https://github.com/PaddlePaddle/Paddle/tree/v0.14.0) +### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0) ### Install Latest Stable Release: ``` # Linux CPU @@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85 ## Installation -It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/install/install_doc.html) on our website. +It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website. ## Documentation -We provide [English](http://paddlepaddle.org/documentation/docs/en/0.14.0/getstarted/index_en.html) and -[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/index.html) documentation. +We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and +[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation. - [Deep Learning 101](https://github.com/PaddlePaddle/book) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/user_guides/howto/training/cluster_howto.html) +- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html) You can run distributed training jobs on MPI clusters. -- [Python API](http://paddlepaddle.org/documentation/api/zh/0.14.0/fluid.html) +- [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html) Our new API enables much shorter programs. -- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/advanced_usage/development/contribute_to_paddle.html) +- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html) We appreciate your contributions! diff --git a/benchmark/fluid/models/resnet.py b/benchmark/fluid/models/resnet.py index ae1baa48e17e40448e457052fd1464b9604a2128..d71b855612ae32083b2b2e3448db3749c340633b 100644 --- a/benchmark/fluid/models/resnet.py +++ b/benchmark/fluid/models/resnet.py @@ -20,6 +20,7 @@ import functools import numpy as np import time import os +import math import cProfile, pstats, StringIO @@ -27,128 +28,120 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.profiler as profiler -# from recordio_converter import imagenet_train, imagenet_test from imagenet_reader import train, val +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class ResNet(): + def __init__(self, layers=50, is_train=True): + self.params = train_parameters + self.layers = layers + self.is_train = is_train + + def net(self, input, class_dim=1000): + layers = self.layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [64, 128, 256, 512] + + conv = self.conv_bn_layer( + input=input, num_filters=64, filter_size=7, stride=2, act='relu') + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + for block in range(len(depth)): + for i in range(depth[block]): + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + out = fluid.layers.fc(input=pool, + size=class_dim, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, + stdv))) + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + return fluid.layers.batch_norm( + input=conv, act=act, is_test=not self.is_train) + + def shortcut(self, input, ch_out, stride): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return self.conv_bn_layer(input, ch_out, 1, stride) + else: + return input -def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='relu', - is_train=True): - conv1 = fluid.layers.conv2d( - input=input, - filter_size=filter_size, - num_filters=ch_out, - stride=stride, - padding=padding, - act=None, - bias_attr=False) - return fluid.layers.batch_norm(input=conv1, act=act, is_test=not is_train) - - -def shortcut(input, ch_out, stride, is_train=True): - ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1] - if ch_in != ch_out: - return conv_bn_layer( - input, ch_out, 1, stride, 0, None, is_train=is_train) - else: - return input - - -def basicblock(input, ch_out, stride, is_train=True): - short = shortcut(input, ch_out, stride, is_train=is_train) - conv1 = conv_bn_layer(input, ch_out, 3, stride, 1, is_train=is_train) - conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None, is_train=is_train) - return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') - - -def bottleneck(input, ch_out, stride, is_train=True): - short = shortcut(input, ch_out * 4, stride, is_train=is_train) - conv1 = conv_bn_layer(input, ch_out, 1, stride, 0, is_train=is_train) - conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, is_train=is_train) - conv3 = conv_bn_layer( - conv2, ch_out * 4, 1, 1, 0, act=None, is_train=is_train) - return fluid.layers.elementwise_add(x=short, y=conv3, act='relu') - - -def layer_warp(block_func, input, ch_out, count, stride): - res_out = block_func(input, ch_out, stride) - for i in range(1, count): - res_out = block_func(res_out, ch_out, 1) - return res_out - + def bottleneck_block(self, input, num_filters, stride): + conv0 = self.conv_bn_layer( + input=input, num_filters=num_filters, filter_size=1, act='relu') + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu') + conv2 = self.conv_bn_layer( + input=conv1, num_filters=num_filters * 4, filter_size=1, act=None) -def resnet_imagenet(input, - class_dim, - depth=50, - data_format='NCHW', - is_train=True): + short = self.shortcut(input, num_filters * 4, stride) - cfg = { - 18: ([2, 2, 2, 1], basicblock), - 34: ([3, 4, 6, 3], basicblock), - 50: ([3, 4, 6, 3], bottleneck), - 101: ([3, 4, 23, 3], bottleneck), - 152: ([3, 8, 36, 3], bottleneck) - } - stages, block_func = cfg[depth] - conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3) - pool1 = fluid.layers.pool2d( - input=conv1, pool_type='avg', pool_size=3, pool_stride=2) - res1 = layer_warp(block_func, pool1, 64, stages[0], 1) - res2 = layer_warp(block_func, res1, 128, stages[1], 2) - res3 = layer_warp(block_func, res2, 256, stages[2], 2) - res4 = layer_warp(block_func, res3, 512, stages[3], 2) - pool2 = fluid.layers.pool2d( - input=res4, - pool_size=7, - pool_type='avg', - pool_stride=1, - global_pooling=True) - out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax') - return out - - -def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'): - assert (depth - 2) % 6 == 0 - - n = (depth - 2) // 6 - - conv1 = conv_bn_layer( - input=input, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, n, 1) - res2 = layer_warp(basicblock, res1, 32, n, 2) - res3 = layer_warp(basicblock, res2, 64, n, 2) - pool = fluid.layers.pool2d( - input=res3, pool_size=8, pool_type='avg', pool_stride=1) - out = fluid.layers.fc(input=pool, size=class_dim, act='softmax') - return out + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') def _model_reader_dshape_classdim(args, is_train): - model = resnet_cifar10 + model = None reader = None - if args.data_set == "cifar10": - class_dim = 10 - if args.data_format == 'NCHW': - dshape = [3, 32, 32] - else: - dshape = [32, 32, 3] - model = resnet_cifar10 - if is_train: - reader = paddle.dataset.cifar.train10() - else: - reader = paddle.dataset.cifar.test10() - elif args.data_set == "flowers": + if args.data_set == "flowers": class_dim = 102 if args.data_format == 'NCHW': dshape = [3, 224, 224] else: dshape = [224, 224, 3] - model = resnet_imagenet if is_train: reader = paddle.dataset.flowers.train() else: @@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train): dshape = [3, 224, 224] else: dshape = [224, 224, 3] - model = resnet_imagenet if not args.data_path: raise Exception( "Must specify --data_path when training with imagenet") @@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train): reader = train(xmap=False) else: reader = val(xmap=False) - return model, reader, dshape, class_dim + return reader, dshape, class_dim def get_model(args, is_train, main_prog, startup_prog): - model, reader, dshape, class_dim = _model_reader_dshape_classdim(args, - is_train) + reader, dshape, class_dim = _model_reader_dshape_classdim(args, is_train) pyreader = None trainer_count = int(os.getenv("PADDLE_TRAINERS")) @@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog): label = fluid.layers.data( name='label', shape=[1], dtype='int64') - predict = model(input, class_dim, is_train=is_train) + model = ResNet(is_train=is_train) + predict = model.net(input, class_dim=class_dim) cost = fluid.layers.cross_entropy(input=predict, label=label) avg_cost = fluid.layers.mean(x=cost) @@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog): total_images = 1281167 / trainer_count step = int(total_images / args.batch_size + 1) - epochs = [30, 60, 80, 90] + epochs = [30, 60, 90] bd = [step * e for e in epochs] base_lr = args.learning_rate lr = [] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] optimizer = fluid.optimizer.Momentum( - learning_rate=base_lr, - #learning_rate=fluid.layers.piecewise_decay( - # boundaries=bd, values=lr), + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), momentum=0.9, regularization=fluid.regularizer.L2Decay(1e-4)) optimizer.minimize(avg_cost) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 7a99169849debcbc57d6f197b36c5045b211f3ef..d44ebbae4d4be9c79e629303805c94030b8879db 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -442,8 +442,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( use_gpu = nccl_ctxs_ != nullptr; #endif - if (use_gpu || - strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { + if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { // Insert BCast Ops for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { auto &to_bcast_set = bcast_var_name_set[dev_id]; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ce3ebed00b4db81b6ba1bab566a56207341a67c0..7004f484a9975124750fad4cb8f773342082b514 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -28,6 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) pass_library(fc_fuse_pass inference) +if(WITH_MKLDNN) + pass_library(conv_relu_mkldnn_fuse_pass inference) +endif() pass_library(attention_lstm_fuse_pass inference) pass_library(infer_clean_graph_pass inference) pass_library(fc_lstm_fuse_pass inference) @@ -42,3 +45,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) +if(WITH_MKLDNN) + cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) +endif() diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..4408cb45acb3d46e1addf5c25c238af50e5f5e5f --- /dev/null +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr ConvReLUFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get()); + + std::unordered_set nodes2delete; + + GraphPatternDetector gpd; + auto* conv_input = gpd.mutable_pattern() + ->NewNode("conv_relu_mkldnn_fuse/conv_input") + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(), + "conv_relu_mkldnn_fuse"); + conv_relu_pattern(conv_input); + + int found_conv_relu_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvReLU fuse"; + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, + conv_relu_pattern); // Filter + GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op + GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out + GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op + + // Create an ConvReLU Node. + OpDesc desc; + std::string conv_relu_i_in = subgraph.at(conv_input)->Name(); + std::string conv_relu_w_in = conv_weight->Name(); + std::string conv_relu_b_in = conv_bias->Name(); + std::string conv_relu_out = relu_out->Name(); + desc.SetInput("Input", std::vector({conv_relu_i_in})); + desc.SetInput("Filter", std::vector({conv_relu_w_in})); + desc.SetInput("Bias", std::vector({conv_relu_b_in})); + desc.SetOutput("Out", std::vector({conv_relu_out})); + desc.SetType("conv2d"); + for (auto& attr : conv->Op()->GetAttrMap()) { + desc.SetAttr(attr.first, attr.second); + } + desc.SetAttr("fuse_relu", true); + auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied. + GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node); + IR_NODE_LINK_TO(conv_weight, conv_relu_node); + IR_NODE_LINK_TO(conv_bias, conv_relu_node); + IR_NODE_LINK_TO(conv_relu_node, relu_out); + + found_conv_relu_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_relu_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_relu_mkldnn_fuse_pass, + paddle::framework::ir::ConvReLUFusePass); diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..b5de0d548713772e7ad41cfb6d8b3e9460683efb --- /dev/null +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the CONV and ReLU to a ConvReLUOp. + */ +class ConvReLUFusePass : public FusePassBase { + public: + virtual ~ConvReLUFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..82b5fa1886098ca3b19c147c307d3f2fc3ba03d6 --- /dev/null +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" + +#include + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, + const std::vector& inputs, + const std::vector& outputs) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + if (type == "conv2d") { + op->SetAttr("use_mkldnn", true); + op->SetInput("Input", {inputs[0]}); + op->SetInput("Filter", {inputs[1]}); + op->SetInput("Bias", {inputs[2]}); + } else if (type == "relu") { + op->SetInput("X", inputs); + } + op->SetOutput("Out", outputs); +} + +// a->OP0->b +// b->OP1->c +// (c, weights, bias)->conv->f +// (f)->relu->g +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "weights", "bias", "f", "g"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::SELECTED_ROWS); + if (v == "weights" || v == "bias") { + var->SetPersistable(true); + } + } + + SetOp(&prog, "OP0", std::vector({"a"}), + std::vector({"b"})); + SetOp(&prog, "OP1", std::vector({"b"}), + std::vector({"c"})); + SetOp(&prog, "conv2d", std::vector({"c", "weights", "bias"}), + std::vector({"f"})); + SetOp(&prog, "relu", std::vector({"f"}), + std::vector({"g"})); + + return prog; +} + +TEST(ConvReLUFusePass, basic) { + auto prog = BuildProgramDesc(); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass"); + + int original_nodes_num = graph->Nodes().size(); + + graph = pass->Apply(std::move(graph)); + + int current_nodes_num = graph->Nodes().size(); + + // Remove 3 Nodes: CONV, RELU, conv_out + // Add 1 Node: ConvReLU + EXPECT_EQ(original_nodes_num - 2, current_nodes_num); + + // Assert conv_relu op in newly generated graph + int conv_relu_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "conv2d") { + if (node->Op()->HasAttr("use_mkldnn")) { + bool use_mkldnn = boost::get(node->Op()->GetAttr("use_mkldnn")); + if (use_mkldnn) { + if (node->Op()->HasAttr("fuse_relu")) { + bool fuse_relu = boost::get(node->Op()->GetAttr("fuse_relu")); + if (fuse_relu) { + ++conv_relu_count; + } + } + } + } + } + } + EXPECT_EQ(conv_relu_count, 1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(conv_relu_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index f7fda873574a0f8b10251d4fa6b604a9312ad7f9..aa95d3e9f6c8221f6e48d192b73ad5135539dc75 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, if (with_fc_bias) { // Add FC-bias with LSTM-bias and create a new weight PADDLE_ENFORCE(scope); - const std::string& new_bias_var = name_scope + "_bias.new"; + const std::string& new_bias_var = patterns::UniqueKey("NewBias"); auto* bias_var = scope->Var(new_bias_var); PADDLE_ENFORCE(bias_var); auto* bias_tensor = bias_var->GetMutable(); @@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); @@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, fc_bias); // Remove unneeded nodes. std::unordered_set marked_nodes( - {mul, lstm, elementwise_add}); + {mul, lstm, elementwise_add, fc_bias}); GraphSafeRemoveNodes(graph, marked_nodes); } else { GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5825a129b731afe9de468b5a611c25ac2753aa3f..11d5998aafe1f325b94ef1a5ea1c13c72c13f5c9 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -522,6 +522,39 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { return false; } +PDNode* patterns::ConvReLU::operator()( + paddle::framework::ir::PDNode* conv_input) { + // Create Operators + conv_input->assert_is_op_input("conv2d", "Input"); + auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); + // Create variables + // Filter + auto* conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Filter"); + // Bias + auto* conv_bias_var = pattern->NewNode(conv_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Bias"); + // intermediate variable, will be removed in the IR after fuse. + auto* conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("conv2d") + ->assert_is_op_input("relu"); + // output + auto* relu_out_var = pattern->NewNode(relu_out_repr()) + ->AsOutput() + ->assert_is_op_output("relu"); + + conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var}) + .LinksTo({conv_out_var}); + relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var}); + return relu_out_var; +} + PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, bool with_bias) { // Create shared nodes. diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 57482a07b607ba1d9fa06a5f325f60ba58dce307..371384dc56eec91db1f621c0ebb65113e7a5a5cc 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -360,6 +360,28 @@ struct PatternBase { size_t id_; }; +// CONV with ReLU +// op: conv + relu +// named nodes: +// conv_input, conv_weight, +// conv_bias, conv_out, conv, +// relu_out, relu +struct ConvReLU : public PatternBase { + ConvReLU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_relu") {} + + PDNode* operator()(PDNode* conv_input); + + // declare operator node's name + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(relu); + // declare variable node's name + PATTERN_DECL_NODE(conv_weight); + PATTERN_DECL_NODE(conv_bias); + PATTERN_DECL_NODE(conv_out); + PATTERN_DECL_NODE(relu_out); +}; + // FC with bias // op: mul + elementwise_add // named nodes: diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d58d6e4f3e684b97fcc1121e51355bdf3aae3fce..b7fae7171a57666a8fb4613a7cbe3aa15997b638 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext { : op_(op), scope_(scope) {} bool HasInput(const std::string& name) const override { - if (!op_.HasInputs(name)) { + // has only one input + const auto& ins = op_.Inputs(); + auto it = ins.find(name); + if (it == ins.end()) { return false; } - auto& ins = Inputs(name); - size_t length = ins.size(); - if (length == 0) { + const auto& in = it->second; + if (in.size() == 0 || in[0] == kEmptyVarName) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, + PADDLE_ENFORCE_EQ(in.size(), 1UL, "Input %s should not have more than one inputs", name); - auto ipt = ins[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + return scope_.FindVar(in[0]) != nullptr; } bool HasOutput(const std::string& name) const override { - if (!op_.HasOutputs(name)) { + // has only one output + const auto& outs = op_.Outputs(); + auto it = outs.find(name); + if (it == outs.end()) { return false; } - auto& outs = Outputs(name); - size_t length = outs.size(); - if (length == 0) { + const auto& out = it->second; + if (out.size() == 0 || out[0] == kEmptyVarName) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, - "Output %s should not have more than one inputs", name); - auto ipt = outs[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + PADDLE_ENFORCE_EQ(out.size(), 1UL, + "Output %s should not have more than one outputs", name); + return scope_.FindVar(out[0]) != nullptr; } bool HasInputs(const std::string& name) const override { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 81cb24bdda6b87a3d708cf5047dce05d5020a0d5..5b8c75a93de2ddd8f7260d2191c22a5945b3d2d9 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -352,7 +352,10 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ParallelExecutor::~ParallelExecutor() { if (member_->own_local_scope_) { for (size_t i = 1; i < member_->local_scopes_.size(); ++i) { - member_->global_scope_->DeleteScope(member_->local_scopes_[i]); + Scope *local_scope = member_->local_scopes_[i]; + if (member_->global_scope_->HasKid(local_scope)) { + member_->global_scope_->DeleteScope(local_scope); + } } } } diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 925ea98dbe62e4da91689f6e56c135e51c24a8a3..7e689a37da8a16bd9b1ac6650b9322d2eb5a2c85 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -87,8 +87,17 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs()); ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs()); - ASSERT_EQ(op_copy->Proto()->SerializeAsString(), - op_origin->Proto()->SerializeAsString()); + ASSERT_EQ(op_origin->Proto()->attrs().size(), + op_copy->Proto()->attrs().size()); + for (auto it = op_origin->Proto()->attrs().begin(); + it != op_origin->Proto()->attrs().end(); ++it) { + for (auto it_2 = op_copy->Proto()->attrs().begin(); + it_2 != op_copy->Proto()->attrs().end(); ++it_2) { + if (it->name() == it_2->name()) { + ASSERT_TRUE(it_2->SerializeAsString() == it->SerializeAsString()); + } + } + } if (op->Type() == "op_with_subblock") { ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 50f374e3703a97f6c1fdb4b14fdeb0b603f9ac86..2be655b89a4caf2bf9874dcab6bc0bdb2856a026 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -72,6 +72,12 @@ void Scope::DropKids() { kids_.clear(); } +bool Scope::HasKid(const Scope* scope) const { + std::unique_lock lock(mutex_); + auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); + return it != this->kids_.end(); +} + std::vector Scope::LocalVarNames() const { std::unique_lock lock(mutex_); std::vector known_vars; diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index e246241c0abfbc7bdcaf38d073cc58fc36a4f737..b6165a595d537c314a95685e8b1edbc42e387ab7 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -71,6 +71,9 @@ class Scope { /// Drop all kids scopes belonged to this scope. void DropKids(); + /// Find if a scope exists in the kid scopes + bool HasKid(const Scope* scope) const; + // enumerate all the variables current contains. std::vector LocalVarNames() const; diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 30c1e8e93d2513dca4531d224fb939ec47a7739d..e76708baf4b39afb0febbcf3ff71281dfbfc8627 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" @@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program, void IRPassManager::Apply(const std::vector &passes) { // Apply all the passes std::string pre_pass; + int pass_num = 0; for (const std::string &pass_name : passes) { PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); if (pass_name == "graph_viz_pass") { - std::string dot_file_path = - "ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot"; + std::string dot_file_path = std::to_string(pass_num) + "_ir_" + + (pre_pass.empty() ? "origin" : pre_pass) + + ".dot"; pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); + pass_num++; } graph_ = pass->Apply(std::move(graph_)); pre_pass = pass_name; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index bd9b4b1a814f995e3979105f5b9830b95fd8ea7d..6fe13ed027de403bdc21882c26225bcd4cc7e49a 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch, if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) { buffer.Resize(sizeof(T) * data.size()); } - std::memcpy(buffer.data(), data.data(), buffer.length()); + std::memcpy(buffer.data(), data.data(), sizeof(T) * data.size()); // copy LoD for (const auto &level : fetch.lod()) { output->lod.emplace_back(level); diff --git a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc index 522d870db8583aac4006e8cdb7909625c3feb34b..7e00cb20ad0ce052a84d5491b0cdf167f0768081 100644 --- a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc @@ -117,34 +117,6 @@ void GetOneBatch(std::vector *input_slots, DataRecord *data, input_slots->assign({input_tensor}); } -void BenchAllData(const std::string &model_path, const std::string &data_file, - const int batch_size, const int repeat) { - NativeConfig config; - config.model_dir = model_path; - config.use_gpu = false; - config.device = 0; - config.specify_input_name = true; - std::vector input_slots, outputs_slots; - DataRecord data(data_file, batch_size); - auto predictor = - CreatePaddlePredictor(config); - GetOneBatch(&input_slots, &data, batch_size); - for (int i = 0; i < FLAGS_burning; i++) { - predictor->Run(input_slots, &outputs_slots); - } - Timer timer; - double sum = 0; - for (int i = 0; i < repeat; i++) { - for (size_t bid = 0; bid < data.batched_datas.size(); ++bid) { - GetOneBatch(&input_slots, &data, batch_size); - timer.tic(); - predictor->Run(input_slots, &outputs_slots); - sum += timer.toc(); - } - } - PrintTime(batch_size, repeat, 1, 0, sum / repeat); -} - const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, 25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43, 44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39, diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc index 661b047ed7cb70545267e468d8c2c48596a2994c..6e8e43add7d3383fa79efea91c23750be9c8956f 100644 --- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc @@ -144,8 +144,9 @@ void TestChineseNERPrediction(bool use_analysis) { size_t num_samples; for (int i = 0; i < FLAGS_repeat; i++) { DataRecord data(FLAGS_infer_data, FLAGS_batch_size); + // Just one batch, the num_samples remains the same. num_samples = data.num_samples; - for (size_t bid = 0; bid < num_samples; ++bid) { + for (size_t bid = 0; bid < num_samples / FLAGS_batch_size; ++bid) { PrepareInputs(&input_slots, &data, FLAGS_batch_size); timer.tic(); predictor->Run(input_slots, &outputs); diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 39b0c856996c11c6efdb530f1396afd5731c778d..9b943440a869e213db4ed761cfe7c508bc5e94ae 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -24,28 +24,28 @@ namespace operators { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of AttentionLSTM should not be null."); + "Assert only one Input(X) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(C0) of AttentionLSTM should not be null."); + "Assert only one Input(C0) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), - "Input(LSTMWeight) of AttentionLSTM should not be null."); + "Assert only one Input(LSTMWeight) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), - "Input(LSTMBias) of AttentionLSTM should not be null."); + "Assert only one Input(LSTMBias) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), - "Input(AttentionWeight) of AttentionLSTM should not be null."); + "Assert only one Input(AttentionWeight) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of AttentionLSTM should not be null."); + "Assert only one Output(Hidden) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of AttentionLSTM should not be null."); + "Assert only one Output(Cell) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), - "Output(AttentionedX) of AttentionLSTM should not be null."); + "Assert only one Output(AttentionedX) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), - "Output(AttentionFCOut) of AttentionLSTM should not be null."); + "Assert only one Output(AttentionFCOut) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), - "Output(LSTMX) of AttentionLSTM should not be null."); + "Assert only one Output(LSTMX) of AttentionLSTM."); PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), - "Output(LSTMOUT) of AttentionLSTM should not be null."); + "Assert only one Output(LSTMOUT) of AttentionLSTM."); auto x_dims = ctx->GetInputDim("X"); const int M = x_dims[1]; diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index c5cbadc892904dc064b49ebc461944c4671a69da..3eb02c6b61ce61140bd777647a12477dd9c3c803 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr AcquireWeightsMemoryFromPrimitive( const std::shared_ptr user_weights_memory_p, - std::vector& pipeline) { // NOLINT + std::vector& pipeline, // NOLINT + bool is_persistent = false) { auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc(); return this->AcquireMemory(weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p", - pipeline); + pipeline, is_persistent); } std::shared_ptr AcquireBiasMemoryFromPrimitive( @@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); + const bool is_test = ctx.Attr("is_test"); + auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); @@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); + bool fuse_relu = ctx.Attr("fuse_relu"); int groups = ctx.Attr("groups"); // TODO(pzelazko-intel) add support for group convolution and dilation @@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_md = platform::MKLDNNMemDesc( bias_tz, platform::MKLDNNGetDataType(), memory::format::x); - conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, - strides, paddings, mkldnn_engine); + conv_pd = + ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, + paddings, mkldnn_engine, fuse_relu); } else { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, - paddings, mkldnn_engine); + paddings, mkldnn_engine, fuse_relu); } // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); @@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_memory_p = handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( - user_weights_memory_p, pipeline); + user_weights_memory_p, pipeline, is_test); auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); @@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } private: + mkldnn::primitive_attr AddRelu() const { + // Fusion with ReLU layer is executed through the PostOps feature. Create a + // PostOps object and configure it to execute an eltwise relu operation. + mkldnn::primitive_attr conv_attr; + constexpr float scale = 1.0f; + constexpr float negative_slope = 0.0f; + constexpr float placeholder = 0.0f; + mkldnn::post_ops post_operations; + post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, + negative_slope, placeholder); + conv_attr.set_post_ops(post_operations); + return conv_attr; + } + std::unique_ptr ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine) const { + const mkldnn::engine& engine, + const bool fuse_relu) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; @@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - auto p_conv_pd = - new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); + mkldnn::primitive_attr conv_attr; + if (fuse_relu) { + conv_attr = AddRelu(); + } + + auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( + conv_desc, conv_attr, engine); return std::unique_ptr( p_conv_pd); @@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const memory::desc& bias, const memory::desc& dst, const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine) const { + const mkldnn::engine& engine, + const bool fuse_relu) const { memory::dims stride_dims = {strides[0], strides[1]}; memory::dims padding_dims = {paddings[0], paddings[1]}; @@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - auto p_conv_pd = - new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); + mkldnn::primitive_attr conv_attr; + if (fuse_relu) { + conv_attr = AddRelu(); + } + + auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( + conv_desc, conv_attr, engine); return std::unique_ptr( p_conv_pd); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 61ca80877a6dfcdf30a0ff346342116e36eec6f2..41d4fcf6de7c8fcb3cfbb2063b0a2ac1a2356168 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } void Conv2DOpMaker::Make() { + AddAttr("is_test", "").SetDefault(false); AddInput( "Input", "(Tensor) The input tensor of convolution operator. " @@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddAttr( "data_format", "(string, default NCHW) Only used in " diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index da5d20505e9b06c0717af8d79d5456a9ade1e89c..56734b81e8716a0c0c37a11e35c9118ee7b55020 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -20,6 +20,7 @@ if(WITH_GRPC) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) + cc_test(varhandle_test SRCS varhandle_test.cc) return() endif() diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index b4f60c9ff9a41d5cb7dbe4e7a7694a84bab8e940..07ac20797ddab54296a45e99915588a40cc6f3c7 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() { } channels_.clear(); } - client_thread_->join(); } -bool GRPCClient::AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); + SendProcessor* s = new SendProcessor(ch); + VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, - this] { + framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); - // varhandle - VarHandle var_h; - var_h.ep = ep_val; - var_h.scope = p_scope; - var_h.name = var_name_val; - var_h.ctx = p_ctx; - var_h.method = "Send"; - - VLOG(3) << var_h.String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; // stub context - SendProcessor* s = new SendProcessor(ch); - s->Prepare(var_h, time_out); s->response_call_back_ = nullptr; auto call = s->stub_g_.PrepareUnaryCall( @@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep, }); req_count_++; - return true; + return h; } void ProcGetResponse(const VarHandle& var_h, const ::grpc::ByteBuffer& ret_msg) { framework::Variable* outvar = nullptr; - DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); + DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar); } template @@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { result->Swap(&tmp); } -bool GRPCClient::AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, int64_t time_out) { +VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); + GetProcessor* s = new GetProcessor(ch); + VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, - this] { + framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); ::grpc::ByteBuffer buf; RequestToByteBuffer(req, &buf); - // var handle - VarHandle var_h; - var_h.ep = ep_val; - var_h.scope = p_scope; - var_h.name = var_name_val; - var_h.ctx = p_ctx; - var_h.method = "Get"; - - VLOG(3) << var_h.String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; // stub context - GetProcessor* s = new GetProcessor(ch); - s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; auto call = s->stub_g_.PrepareUnaryCall( @@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep, req_count_++; - return true; + return h; } -bool GRPCClient::AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; const std::string out_var_name_val = out_var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); + GetProcessor* s = new GetProcessor(ch); + VarHandlePtr h( + new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + time_out, s, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); - // var handle - VarHandle var_h; - var_h.ep = ep_val; - var_h.scope = p_scope; - var_h.name = out_var_name_val; - var_h.ctx = p_ctx; - var_h.method = "Prefetch"; - - VLOG(3) << var_h.String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; // stub context - GetProcessor* s = new GetProcessor(ch); - s->Prepare(var_h, time_out); s->response_call_back_ = ProcGetResponse; auto call = s->stub_g_.PrepareUnaryCall( @@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, }); req_count_++; - return true; + return h; } -void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE, + nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(BATCH_BARRIER_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } -void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE, + nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(FETCH_BARRIER_MESSAGE); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } -void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { +VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h( + new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(COMPLETE_MESSAGE); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } -void GRPCClient::AsyncCheckpointNotify(const std::string& ep, - const std::string& dir, - int64_t time_out) { +VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, + const std::string& dir, + int64_t time_out) { const auto ch = GetChannel(ep); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); - s->Prepare(time_out); + VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE, + nullptr, nullptr)); + s->Prepare(h, time_out); sendrecv::VariableMessage req; req.set_varname(CHECKPOINT_SAVE_MESSAGE); @@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); req_count_++; + return h; } bool GRPCClient::Wait() { @@ -276,25 +268,28 @@ void GRPCClient::Proceed() { void* tag = nullptr; bool ok = false; + VLOG(3) << "GRPCClient Proceed begin"; while (!stopped_ && cq_.Next(&tag, &ok)) { BaseProcessor* c = static_cast(tag); GPR_ASSERT(ok); PADDLE_ENFORCE(c); if (c->status_.ok()) { - VLOG(3) << c->var_h_.String() << " process"; + VLOG(3) << c->GetVarHandlePtr()->String() << " process"; c->Process(); } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { - LOG(ERROR) << c->var_h_.String() + LOG(ERROR) << c->GetVarHandlePtr()->String() << " meets grpc error:" << c->status_.error_message(); { std::lock_guard lk(sync_mutex_); ok_ = false; } - sync_cond_.notify_all(); + c->Finish(false); } else { - LOG(FATAL) << c->var_h_.String() + LOG(FATAL) << c->GetVarHandlePtr()->String() << " meets grpc error:" << c->status_.error_message(); + c->Finish(false); } + delete c; { std::lock_guard lk(sync_mutex_); @@ -302,6 +297,7 @@ void GRPCClient::Proceed() { } sync_cond_.notify_all(); } + VLOG(3) << "GRPCClient Proceed end"; } std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 0c95ffeb5ce7e1586c5968fb122acd12c0c0196e..75a3662316462a222760bfbb7d7906c70f46d143 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { - context_ = nullptr; - } + BaseProcessor() { context_ = nullptr; } virtual ~BaseProcessor() {} - virtual void Prepare(const VarHandle& var_info, int64_t time_out) { + virtual void Prepare(VarHandlePtr h, int64_t time_out) { + var_h_ = h; + context_.reset(new grpc::ClientContext()); - var_h_ = var_info; context_->set_wait_for_ready(true); if (time_out) { std::chrono::system_clock::time_point deadline = @@ -71,21 +70,21 @@ class BaseProcessor { } } - virtual void Prepare(int64_t time_out) { - context_.reset(new grpc::ClientContext()); - context_->set_wait_for_ready(true); - - std::chrono::system_clock::time_point deadline = - std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); - - context_->set_deadline(deadline); + void Process() { + ProcessImpl(); + var_h_->Finish(true); } - virtual void Process() = 0; + VarHandlePtr GetVarHandlePtr() { return var_h_; } + bool Wait() { return var_h_->Wait(); } + void Finish(bool ok) { return var_h_->Finish(ok); } + virtual void ProcessImpl() = 0; std::unique_ptr context_; grpc::Status status_; - VarHandle var_h_; + + protected: + VarHandlePtr var_h_; }; typedef std::function @@ -94,13 +93,13 @@ typedef std::function class SendProcessor : public BaseProcessor { public: explicit SendProcessor(std::shared_ptr ch) - : BaseProcessor(ch), stub_g_(ch) {} + : BaseProcessor(), stub_g_(ch) {} virtual ~SendProcessor() {} - virtual void Process() { + void ProcessImpl() override { if (response_call_back_) { - response_call_back_(var_h_, reply_); + response_call_back_(*var_h_.get(), reply_); } } @@ -115,13 +114,13 @@ typedef std::function class GetProcessor : public BaseProcessor { public: explicit GetProcessor(std::shared_ptr ch) - : BaseProcessor(ch), stub_g_(ch) {} + : BaseProcessor(), stub_g_(ch) {} virtual ~GetProcessor() {} - virtual void Process() { + void ProcessImpl() override { if (response_call_back_) { - response_call_back_(var_h_, reply_); + response_call_back_(*var_h_.get(), reply_); } } @@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor { class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) { + : BaseProcessor() { stub_ = sendrecv::SendRecvService::NewStub(ch); } virtual ~BatchBarrierProcessor() {} - virtual void Process() {} + void ProcessImpl() override {} sendrecv::VoidMessage reply_; std::unique_ptr stub_; }; @@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor { class FetchBarrierProcessor : public BaseProcessor { public: explicit FetchBarrierProcessor(std::shared_ptr ch) - : BaseProcessor(ch) { + : BaseProcessor() { stub_ = sendrecv::SendRecvService::NewStub(ch); } virtual ~FetchBarrierProcessor() {} - virtual void Process() {} + void ProcessImpl() override {} sendrecv::VariableMessage reply_; std::unique_ptr stub_; }; @@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor { class CheckpointNotifyProcessor : public BaseProcessor { public: explicit CheckpointNotifyProcessor(std::shared_ptr ch) - : BaseProcessor(ch) { + : BaseProcessor() { stub_ = sendrecv::SendRecvService::NewStub(ch); } virtual ~CheckpointNotifyProcessor() {} - virtual void Process() {} + void ProcessImpl() override {} sendrecv::VoidMessage reply_; std::unique_ptr stub_; }; @@ -177,32 +176,37 @@ class GRPCClient : public RPCClient { GRPCClient() : ok_(true), completed_(false), stopped_(false) {} virtual ~GRPCClient(); - bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - bool AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, - int64_t time_out = FLAGS_rpc_deadline) override; - - void AsyncSendComplete(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncSendBatchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncSendFetchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_rpc_deadline) override; + + VarHandlePtr AsyncSendComplete( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; bool Wait() override; diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 64ac7281848f91302bc0aa3cb81dd198e56fb653..3c3f9d17c871ac1cb4df83db17cf489d5b9e0563 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -28,6 +28,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/macros.h" namespace paddle { namespace operators { @@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; class RPCServer; -struct VarHandle { - // RPC endpoint. - std::string ep; - const platform::DeviceContext* ctx; - const framework::Scope* scope; - // Variable name. - std::string name; - // RPC method name. - std::string method; +class VarHandle { + public: + VarHandle(const std::string ep, const std::string& method, + const std::string& name, + const platform::DeviceContext* p_ctx = nullptr, + const framework::Scope* p_scope = nullptr) + : ok_(kVarHandleDefaultState) { + ep_ = ep; + ctx_ = p_ctx; + scope_ = p_scope; + name_ = name; + method_ = method; + } + + virtual ~VarHandle() {} + + public: + bool Wait() { + { + std::unique_lock lk(sync_mutex_); + wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; }); + } + VLOG(7) << "VarHandle wait:" << ok_; + return ok_ != 0; + } + + void Finish(bool ok) { + { + std::unique_lock lk(sync_mutex_); + ok_ = ok; + } + VLOG(7) << "VarHandle finish:" << ok; + wait_cond_.notify_all(); + } std::string String() const { std::ostringstream s; - s << method << " name:[" << name << "], ep:[" << ep << "]"; + s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_ + << "]"; return s.str(); } + + std::string ep() const { return ep_; } + const platform::DeviceContext* ctx() const { return ctx_; } + const framework::Scope* scope() const { return scope_; } + std::string name() const { return name_; } + std::string method() const { return method_; } + + protected: + // RPC endpoint. + std::string ep_; + const platform::DeviceContext* ctx_; + const framework::Scope* scope_; + // Variable name. + std::string name_; + // RPC method name. + std::string method_; + + protected: + std::mutex sync_mutex_; + std::condition_variable wait_cond_; + int ok_; + + static const int kVarHandleDefaultState = -1; + + private: + DISABLE_COPY_AND_ASSIGN(VarHandle); }; +typedef std::shared_ptr VarHandlePtr; + class RequestHandler { public: explicit RequestHandler(bool sync_mode) diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 22a022a5d25e5c6628b80294494b87ca105a04c7..3539ee5e459d6dfe0b6510806464bcc6817910bb 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -14,12 +14,14 @@ #pragma once +#include // NOLINT #include #include "gflags/gflags.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/operators/distributed/request_handler.h" DECLARE_int32(rpc_deadline); @@ -31,37 +33,36 @@ class RPCClient { public: RPCClient() {} virtual ~RPCClient() {} - virtual bool AsyncSendVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual bool AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual bool AsyncPrefetchVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncCheckpointNotify(const std::string& ep, - const std::string& dir, - int64_t time_out = FLAGS_rpc_deadline) = 0; - - virtual void AsyncSendComplete(const std::string& ep, - int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncPrefetchVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncSendBatchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncSendFetchBarrier( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_rpc_deadline) = 0; + + virtual VarHandlePtr AsyncSendComplete( + const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; // Complete tells all the pserver instances that finishe the training, // the pserver can reduce it's barrier count, and continue to train diff --git a/paddle/fluid/operators/distributed/varhandle_test.cc b/paddle/fluid/operators/distributed/varhandle_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a0fcaf886475c5e03d959ffd6af22b2123526b9f --- /dev/null +++ b/paddle/fluid/operators/distributed/varhandle_test.cc @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/operators/distributed/request_handler.h" + +using paddle::operators::distributed::VarHandlePtr; +using paddle::operators::distributed::VarHandle; + +void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); } + +void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); } + +TEST(VarHandle, Run) { + std::vector a; + for (int i = 0; i < 12; i++) { + VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr)); + a.push_back(s); + } + + std::vector> t; + for (int i = 0; i < 6; i++) { + t.emplace_back(new std::thread(WaitFalse, a[i])); + } + + for (int i = 0; i < 6; i++) { + a[i]->Finish(false); + t[i]->join(); + } + + for (int i = 6; i < 12; i++) { + t.emplace_back(new std::thread(WaitTrue, a[i])); + } + + for (int i = 6; i < 12; i++) { + a[i]->Finish(true); + t[i]->join(); + } +} diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index 916f84cb4a78c3721cb67bd3cf8d3759a8eaf1bf..31e87d9113118ebe7a4b25ffee5ba55e2714fb66 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -25,14 +25,14 @@ namespace paddle { namespace operators { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU."); PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Input(WeightX) of GRU should not be null."); + "Assert only one Input(WeightX) of GRU."); PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Input(WeightH) of GRU should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); + "Assert only one Input(WeightH) of GRU."); + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of GRU should not be null."); + "Assert only one Output(Hidden) of GRU."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); @@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of GRU should not be null."); + "Assert only one Output(ReorderedH0) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of GRU should not be null."); + "Assert only one Output(BatchedInput) of GRU."); PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), - "Output(BatchedOut) of GRU should not be null."); + "Assert only one Output(BatchedOut) of GRU."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedOut", out_dims); } diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index ef23ab3f981786d33567619ad0194d21f31bdc8e..55e465e3af08c012b8cff7714452ed32b32a5556 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -24,20 +24,17 @@ namespace paddle { namespace operators { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM."); PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Input(WeightX) of LSTM should not be null."); + "Assert only one Input(WeightX) of LSTM."); PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Input(WeightH) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasOutput("XX"), - "Output(XX) of LSTM should not be null."); + "Assert only one Input(WeightH) of LSTM."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); + PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Output(Hidden) of LSTM should not be null."); + "Assert only one Output(Hidden) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Output(Cell) of LSTM should not be null."); + "Assert only one Output(Cell) of LSTM."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); @@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Output(BatchedInput) of LSTM should not be null."); + "Assert only one Output(BatchedInput) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), - "Output(BatchedHidden) of LSTM should not be null."); + "Assert only one Output(BatchedHidden) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), - "Output(BatchedCell) of LSTM should not be null."); + "Assert only one Output(BatchedCell) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Output(ReorderedH0) of LSTM should not be null."); + "Assert only one Output(ReorderedH0) of LSTM"); PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), - "Output(ReorderedC0) of LSTM should not be null."); + "Assert only one Output(ReorderedC0) of LSTM."); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedCell", out_dims); diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 4b804740a06f9e29704f2b3f58a90191e3559347..0519c15e13aac99802ff0f95b975712b36b44246 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); + std::vector rets; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " << outs[i] << " back"; - rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); + rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, + ins[i], outs[i])); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } - PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } } }; diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index a1f368e8690512cec2db7593aabc0279bbe174eb..4d34b8a1686efb1fc30020f0d27e9a3c3a6c0866 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); + std::vector rets; for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); + rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i])); } if (sync_mode) { - PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } } } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 82a70e4bf13247d784371ffdf419c9f792d7f721..48322ac7fd54a2e4cc3405a2c4dcddfc273f5a66 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include // NOLINT #include +#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); + std::vector rets; for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - // TODO(Yancey1989): we need to use an IO threadpool which has - // a larger number of threads than the computing threadpool. - rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); + rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } if (sync_send) { - PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } } } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 148faec4af50c4fe3a8e9d1f22e0da70c8ddcb44..a07c17348ebb3f768d1c8be65c2d31e3c130bd23 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -31,7 +31,8 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) { int idx = i * class_num + labels[i]; - logit_grad[idx] -= static_cast(1.); + logit_grad[idx] -= + ignore_index == labels[i] ? static_cast(0.) : static_cast(1.); } } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index f6e9a52b275353c03c1f350719766922a97f6cb3..c0a2543ba5d8ff8f34cb6231c51cb5053a6a9481 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -192,7 +192,8 @@ class MKLDNNHandler { mkldnn::memory::primitive_desc& user_mpd, // NOLINT const std::shared_ptr user_memory_p, const std::string& suffix, - std::vector& pipeline) { // NOLINT + std::vector& pipeline, // NOLINT + bool is_persistent = false) { // create reorder primitive if the input format is not the preferred one auto local_key = key_ + suffix; auto key_reorder_p = key_ + suffix + "reorder_p"; @@ -213,7 +214,7 @@ class MKLDNNHandler { pipeline.push_back(*reorder_p); } dev_ctx_.SetBlob(local_key, target_memory_p); - } else { + } else if (!is_persistent) { // Make reorder if needed auto reorder_p = std::static_pointer_cast( dev_ctx_.GetBlob(key_reorder_p)); diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 4790e0f6119e96b11b049bfdd3b46d40a382683b..bd9f8b3c356ca1e43923d42beebaa5ab98158084 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -128,6 +128,13 @@ class ParallelExecutor(object): os.environ.get('CPU_NUM', multiprocessing.cpu_count())) exec_strategy.num_threads = cpu_num * 2 + # Set 1 thread num under nccl2 distribute + # env to make sure all gpus run ops in same order. + if num_trainers > 1: + assert (use_cuda) + # FIXME(gongwb): avoid this set. + exec_strategy.num_threads = 1 + if build_strategy is None: build_strategy = BuildStrategy() diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py index e5ae95e2d943917b9bc10f0d4c4bdc5f8fb07fdb..de276755bb1eb2746cc780575a40357255223809 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_resnet.py @@ -178,7 +178,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py index ff91be72c918f8dac65b7030e45c4a00deb965ac..dd547f3448ae55c07b6c09f9de4ac08d8ec5ee88 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py @@ -152,7 +152,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py index fa72c939e57356f26d60032dd0a91c894b28c505..973308498bec3cddde2ef651751ad5d0c9f84503 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_conv.py @@ -155,7 +155,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py index 440d2a30835cb89089709f024a4dcc6e4113efa8..cb4aeb430e1a9662a183084c0cdacc41c5a8ec11 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py @@ -137,7 +137,4 @@ if __name__ == '__main__': for parallel in (False, True): if use_cuda and not core.is_compiled_with_cuda(): continue - # TODO(minqiyang): remove this line after fixing the deletion - # order problem of Scope in ParallelExecutor in manylinux - if six.PY2: - main(use_cuda=use_cuda, parallel=parallel) + main(use_cuda=use_cuda, parallel=parallel) diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py index e39eedd282daf6bbe0603a22c357e06c95c086b6..4bd24510bc8ac7f0fbaad3fd1919ab589cd21c4b 100644 --- a/python/paddle/fluid/tests/unittests/test_data_balance.py +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -84,7 +84,7 @@ class TestDataBalance(unittest.TestCase): self.data_file_name = './data_balance_test.recordio' self.lod_data_file_name = './data_balance_with_lod_test.recordio' self.total_ins_num = 50 - self.batch_size = 10 + self.batch_size = 12 self.prepare_data() self.prepare_lod_data() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py index 5ad922725a0b692e28552737a99b745ed09ddbd5..a55b2002ed989d4588716202a37aa6f4139825ea 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py @@ -20,6 +20,7 @@ import numpy as np from parallel_executor_test_base import TestParallelExecutorBase import unittest import paddle +import paddle.fluid.core as core import paddle.dataset.wmt16 as wmt16 import os @@ -170,7 +171,8 @@ class TestTransformer(TestParallelExecutorBase): writer.complete_append_tensor() def test_main(self): - self.check_network_convergence(transformer, use_cuda=True) + if core.is_compiled_with_cuda(): + self.check_network_convergence(transformer, use_cuda=True) self.check_network_convergence(transformer, use_cuda=False, iter=5) diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py index 931cac409f26fce4ecca18c4b0cfcca2e675046f..b7fad9b3a60632adb564e1d155a3d935706b467f 100644 --- a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py +++ b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py @@ -96,7 +96,8 @@ class TestPyReaderUsingExecutor(unittest.TestCase): self.queue_capacity = 50 def test(self): - for use_cuda in [False, True]: + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): for use_parallel_executor in [False, True]: for use_double_buffer in [False, True]: print('Test Parameters:'), diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 02fefe32df14dfca39de1da24147c284d6d82230..adad2428f7fdc554cf4efd652f52b5c5de0ab527 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -60,12 +60,46 @@ class InferenceTranspiler(object): if not isinstance(scope, core.Scope): raise TypeError("scope should be as Scope type or None") use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False)) + self._fuse_batch_norm(program, place, scope) if use_mkldnn: - self._fuse_relu_mkldnn(program) self._fuse_conv_bias_mkldnn(program) + self._fuse_conv_relu_mkldnn(program) + self._fuse_bn_relu_mkldnn(program) + + def _fuse_conv_relu_mkldnn(self, program): + ''' + Transpile the program by fused relu activation for MKLDNN program. + Relu activation following convolution OP can be fused by adding + 'fuse_relu' attribute to convolution OP. + The result of fuse is: + - before: + - conv->relu->any_other_op + - after: + - conv->any_other_op + :param program: program to transpile + :type program: Program + ''' + self.block = program.block(0) + + i = 0 + while i < len(self.block.ops): + current_op = self.block.ops[i] + if current_op.type in ['conv2d']: + next_op = self.block.ops[i + 1] + if next_op.type == 'relu': + # modify conv OP to include relu + current_op.set_attr("fuse_relu", True) + # remove conv OP + self.block._remove_op(i + 1) + i = i + 1 + + # TODO(luotao): use clone() method to flush the program.desc in force, + # since some large program.desc will not be flushed immediately. + # And a better solution will be considered later. + program = program.clone() - def _fuse_relu_mkldnn(self, program): + def _fuse_bn_relu_mkldnn(self, program): ''' Transpile the program by fused relu activation for MKLDNN program. @@ -159,7 +193,6 @@ class InferenceTranspiler(object): self._fuse_conv_bias(i, current_op, next_op) self.block._remove_op(i + 1) # Remove old conv self.block._remove_op(i + 1) # Remove elementwise_add - i = i + 1 i = i + 1 self._remove_unused_var()