提交 329a8c52 编写于 作者: N nhzlx

merge develop

...@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle. ...@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 0.14.0](https://github.com/PaddlePaddle/Paddle/tree/v0.14.0) ### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0)
### Install Latest Stable Release: ### Install Latest Stable Release:
``` ```
# Linux CPU # Linux CPU
...@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85 ...@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.14.0.post85
## Installation ## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/install/install_doc.html) on our website. It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website.
## Documentation ## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.14.0/getstarted/index_en.html) and We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/beginners_guide/index.html) documentation. [Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book) - [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook. You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/user_guides/howto/training/cluster_howto.html) - [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters. You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.14.0/fluid.html) - [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html)
Our new API enables much shorter programs. Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.14.0/new_docs/advanced_usage/development/contribute_to_paddle.html) - [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html)
We appreciate your contributions! We appreciate your contributions!
......
...@@ -20,6 +20,7 @@ import functools ...@@ -20,6 +20,7 @@ import functools
import numpy as np import numpy as np
import time import time
import os import os
import math
import cProfile, pstats, StringIO import cProfile, pstats, StringIO
...@@ -27,128 +28,120 @@ import paddle ...@@ -27,128 +28,120 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
# from recordio_converter import imagenet_train, imagenet_test
from imagenet_reader import train, val from imagenet_reader import train, val
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, is_train=True):
self.params = train_parameters
self.layers = layers
self.is_train = is_train
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(
input=conv, act=act, is_test=not self.is_train)
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def conv_bn_layer(input, def bottleneck_block(self, input, num_filters, stride):
ch_out, conv0 = self.conv_bn_layer(
filter_size, input=input, num_filters=num_filters, filter_size=1, act='relu')
stride, conv1 = self.conv_bn_layer(
padding, input=conv0,
act='relu', num_filters=num_filters,
is_train=True): filter_size=3,
conv1 = fluid.layers.conv2d( stride=stride,
input=input, act='relu')
filter_size=filter_size, conv2 = self.conv_bn_layer(
num_filters=ch_out, input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
stride=stride,
padding=padding,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv1, act=act, is_test=not is_train)
def shortcut(input, ch_out, stride, is_train=True):
ch_in = input.shape[1] # if args.data_format == 'NCHW' else input.shape[-1]
if ch_in != ch_out:
return conv_bn_layer(
input, ch_out, 1, stride, 0, None, is_train=is_train)
else:
return input
def basicblock(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 3, stride, 1, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def bottleneck(input, ch_out, stride, is_train=True):
short = shortcut(input, ch_out * 4, stride, is_train=is_train)
conv1 = conv_bn_layer(input, ch_out, 1, stride, 0, is_train=is_train)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, is_train=is_train)
conv3 = conv_bn_layer(
conv2, ch_out * 4, 1, 1, 0, act=None, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv3, act='relu')
def layer_warp(block_func, input, ch_out, count, stride):
res_out = block_func(input, ch_out, stride)
for i in range(1, count):
res_out = block_func(res_out, ch_out, 1)
return res_out
def resnet_imagenet(input, short = self.shortcut(input, num_filters * 4, stride)
class_dim,
depth=50,
data_format='NCHW',
is_train=True):
cfg = { return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}
stages, block_func = cfg[depth]
conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3)
pool1 = fluid.layers.pool2d(
input=conv1, pool_type='avg', pool_size=3, pool_stride=2)
res1 = layer_warp(block_func, pool1, 64, stages[0], 1)
res2 = layer_warp(block_func, res1, 128, stages[1], 2)
res3 = layer_warp(block_func, res2, 256, stages[2], 2)
res4 = layer_warp(block_func, res3, 512, stages[3], 2)
pool2 = fluid.layers.pool2d(
input=res4,
pool_size=7,
pool_type='avg',
pool_stride=1,
global_pooling=True)
out = fluid.layers.fc(input=pool2, size=class_dim, act='softmax')
return out
def resnet_cifar10(input, class_dim, depth=32, data_format='NCHW'):
assert (depth - 2) % 6 == 0
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, n, 1)
res2 = layer_warp(basicblock, res1, 32, n, 2)
res3 = layer_warp(basicblock, res2, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
out = fluid.layers.fc(input=pool, size=class_dim, act='softmax')
return out
def _model_reader_dshape_classdim(args, is_train): def _model_reader_dshape_classdim(args, is_train):
model = resnet_cifar10 model = None
reader = None reader = None
if args.data_set == "cifar10": if args.data_set == "flowers":
class_dim = 10
if args.data_format == 'NCHW':
dshape = [3, 32, 32]
else:
dshape = [32, 32, 3]
model = resnet_cifar10
if is_train:
reader = paddle.dataset.cifar.train10()
else:
reader = paddle.dataset.cifar.test10()
elif args.data_set == "flowers":
class_dim = 102 class_dim = 102
if args.data_format == 'NCHW': if args.data_format == 'NCHW':
dshape = [3, 224, 224] dshape = [3, 224, 224]
else: else:
dshape = [224, 224, 3] dshape = [224, 224, 3]
model = resnet_imagenet
if is_train: if is_train:
reader = paddle.dataset.flowers.train() reader = paddle.dataset.flowers.train()
else: else:
...@@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train): ...@@ -159,7 +152,6 @@ def _model_reader_dshape_classdim(args, is_train):
dshape = [3, 224, 224] dshape = [3, 224, 224]
else: else:
dshape = [224, 224, 3] dshape = [224, 224, 3]
model = resnet_imagenet
if not args.data_path: if not args.data_path:
raise Exception( raise Exception(
"Must specify --data_path when training with imagenet") "Must specify --data_path when training with imagenet")
...@@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train): ...@@ -173,12 +165,11 @@ def _model_reader_dshape_classdim(args, is_train):
reader = train(xmap=False) reader = train(xmap=False)
else: else:
reader = val(xmap=False) reader = val(xmap=False)
return model, reader, dshape, class_dim return reader, dshape, class_dim
def get_model(args, is_train, main_prog, startup_prog): def get_model(args, is_train, main_prog, startup_prog):
model, reader, dshape, class_dim = _model_reader_dshape_classdim(args, reader, dshape, class_dim = _model_reader_dshape_classdim(args, is_train)
is_train)
pyreader = None pyreader = None
trainer_count = int(os.getenv("PADDLE_TRAINERS")) trainer_count = int(os.getenv("PADDLE_TRAINERS"))
...@@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -198,7 +189,8 @@ def get_model(args, is_train, main_prog, startup_prog):
label = fluid.layers.data( label = fluid.layers.data(
name='label', shape=[1], dtype='int64') name='label', shape=[1], dtype='int64')
predict = model(input, class_dim, is_train=is_train) model = ResNet(is_train=is_train)
predict = model.net(input, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
...@@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -216,15 +208,14 @@ def get_model(args, is_train, main_prog, startup_prog):
total_images = 1281167 / trainer_count total_images = 1281167 / trainer_count
step = int(total_images / args.batch_size + 1) step = int(total_images / args.batch_size + 1)
epochs = [30, 60, 80, 90] epochs = [30, 60, 90]
bd = [step * e for e in epochs] bd = [step * e for e in epochs]
base_lr = args.learning_rate base_lr = args.learning_rate
lr = [] lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=base_lr, learning_rate=fluid.layers.piecewise_decay(
#learning_rate=fluid.layers.piecewise_decay( boundaries=bd, values=lr),
# boundaries=bd, values=lr),
momentum=0.9, momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4)) regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
......
...@@ -442,8 +442,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -442,8 +442,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
use_gpu = nccl_ctxs_ != nullptr; use_gpu = nccl_ctxs_ != nullptr;
#endif #endif
if (use_gpu || if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
// Insert BCast Ops // Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id]; auto &to_bcast_set = bcast_var_name_set[dev_id];
......
...@@ -28,6 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap ...@@ -28,6 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(conv_relu_mkldnn_fuse_pass inference)
endif()
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference) pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
...@@ -42,3 +45,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r ...@@ -42,3 +45,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if(WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
endif()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
std::unordered_set<Node*> nodes2delete;
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_relu_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
"conv_relu_mkldnn_fuse");
conv_relu_pattern(conv_input);
int found_conv_relu_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvReLU fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_relu_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
// Create an ConvReLU Node.
OpDesc desc;
std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
std::string conv_relu_w_in = conv_weight->Name();
std::string conv_relu_b_in = conv_bias->Name();
std::string conv_relu_out = relu_out->Name();
desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
desc.SetOutput("Out", std::vector<std::string>({conv_relu_out}));
desc.SetType("conv2d");
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
desc.SetAttr("fuse_relu", true);
auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
IR_NODE_LINK_TO(conv_weight, conv_relu_node);
IR_NODE_LINK_TO(conv_bias, conv_relu_node);
IR_NODE_LINK_TO(conv_relu_node, relu_out);
found_conv_relu_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_relu_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
paddle::framework::ir::ConvReLUFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the CONV and ReLU to a ConvReLUOp.
*/
class ConvReLUFusePass : public FusePassBase {
public:
virtual ~ConvReLUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", true);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") {
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
}
// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->relu->g
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}));
SetOp(&prog, "relu", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}));
return prog;
}
TEST(ConvReLUFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
// Remove 3 Nodes: CONV, RELU, conv_out
// Add 1 Node: ConvReLU
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_relu op in newly generated graph
int conv_relu_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
if (node->Op()->HasAttr("use_mkldnn")) {
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
if (use_mkldnn) {
if (node->Op()->HasAttr("fuse_relu")) {
bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
}
}
}
}
}
}
EXPECT_EQ(conv_relu_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_relu_mkldnn_fuse_pass);
...@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if (with_fc_bias) { if (with_fc_bias) {
// Add FC-bias with LSTM-bias and create a new weight // Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
const std::string& new_bias_var = name_scope + "_bias.new"; const std::string& new_bias_var = patterns::UniqueKey("NewBias");
auto* bias_var = scope->Var(new_bias_var); auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var); PADDLE_ENFORCE(bias_var);
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>(); auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
...@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
...@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
fc_bias); fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul, lstm, elementwise_add}); {mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
......
...@@ -522,6 +522,39 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { ...@@ -522,6 +522,39 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
return false; return false;
} }
PDNode* patterns::ConvReLU::operator()(
paddle::framework::ir::PDNode* conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
// Create variables
// Filter
auto* conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// Bias
auto* conv_bias_var = pattern->NewNode(conv_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Bias");
// intermediate variable, will be removed in the IR after fuse.
auto* conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu");
// output
auto* relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var})
.LinksTo({conv_out_var});
relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
bool with_bias) { bool with_bias) {
// Create shared nodes. // Create shared nodes.
......
...@@ -360,6 +360,28 @@ struct PatternBase { ...@@ -360,6 +360,28 @@ struct PatternBase {
size_t id_; size_t id_;
}; };
// CONV with ReLU
// op: conv + relu
// named nodes:
// conv_input, conv_weight,
// conv_bias, conv_out, conv,
// relu_out, relu
struct ConvReLU : public PatternBase {
ConvReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_relu") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_bias);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(relu_out);
};
// FC with bias // FC with bias
// op: mul + elementwise_add // op: mul + elementwise_add
// named nodes: // named nodes:
......
...@@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -464,35 +464,35 @@ class RuntimeInferShapeContext : public InferShapeContext {
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
if (!op_.HasInputs(name)) { // has only one input
const auto& ins = op_.Inputs();
auto it = ins.find(name);
if (it == ins.end()) {
return false; return false;
} }
auto& ins = Inputs(name); const auto& in = it->second;
size_t length = ins.size(); if (in.size() == 0 || in[0] == kEmptyVarName) {
if (length == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(in.size(), 1UL,
"Input %s should not have more than one inputs", name); "Input %s should not have more than one inputs", name);
auto ipt = ins[0]; return scope_.FindVar(in[0]) != nullptr;
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
if (!op_.HasOutputs(name)) { // has only one output
const auto& outs = op_.Outputs();
auto it = outs.find(name);
if (it == outs.end()) {
return false; return false;
} }
auto& outs = Outputs(name); const auto& out = it->second;
size_t length = outs.size(); if (out.size() == 0 || out[0] == kEmptyVarName) {
if (length == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(length, 1UL, PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one inputs", name); "Output %s should not have more than one outputs", name);
auto ipt = outs[0]; return scope_.FindVar(out[0]) != nullptr;
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
......
...@@ -352,7 +352,10 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -352,7 +352,10 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
ParallelExecutor::~ParallelExecutor() { ParallelExecutor::~ParallelExecutor() {
if (member_->own_local_scope_) { if (member_->own_local_scope_) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) { for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
member_->global_scope_->DeleteScope(member_->local_scopes_[i]); Scope *local_scope = member_->local_scopes_[i];
if (member_->global_scope_->HasKid(local_scope)) {
member_->global_scope_->DeleteScope(local_scope);
}
} }
} }
} }
......
...@@ -87,8 +87,17 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -87,8 +87,17 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs()); ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs());
ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs()); ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs());
ASSERT_EQ(op_copy->Proto()->SerializeAsString(), ASSERT_EQ(op_origin->Proto()->attrs().size(),
op_origin->Proto()->SerializeAsString()); op_copy->Proto()->attrs().size());
for (auto it = op_origin->Proto()->attrs().begin();
it != op_origin->Proto()->attrs().end(); ++it) {
for (auto it_2 = op_copy->Proto()->attrs().begin();
it_2 != op_copy->Proto()->attrs().end(); ++it_2) {
if (it->name() == it_2->name()) {
ASSERT_TRUE(it_2->SerializeAsString() == it->SerializeAsString());
}
}
}
if (op->Type() == "op_with_subblock") { if (op->Type() == "op_with_subblock") {
ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); ASSERT_EQ(1, op->GetBlockAttrId("sub_block"));
......
...@@ -72,6 +72,12 @@ void Scope::DropKids() { ...@@ -72,6 +72,12 @@ void Scope::DropKids() {
kids_.clear(); kids_.clear();
} }
bool Scope::HasKid(const Scope* scope) const {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end();
}
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
......
...@@ -71,6 +71,9 @@ class Scope { ...@@ -71,6 +71,9 @@ class Scope {
/// Drop all kids scopes belonged to this scope. /// Drop all kids scopes belonged to this scope.
void DropKids(); void DropKids();
/// Find if a scope exists in the kid scopes
bool HasKid(const Scope* scope) const;
// enumerate all the variables current contains. // enumerate all the variables current contains.
std::vector<std::string> LocalVarNames() const; std::vector<std::string> LocalVarNames() const;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program, ...@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
void IRPassManager::Apply(const std::vector<std::string> &passes) { void IRPassManager::Apply(const std::vector<std::string> &passes) {
// Apply all the passes // Apply all the passes
std::string pre_pass; std::string pre_pass;
int pass_num = 0;
for (const std::string &pass_name : passes) { for (const std::string &pass_name : passes) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name); PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
std::string dot_file_path = std::string dot_file_path = std::to_string(pass_num) + "_ir_" +
"ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot"; (pre_pass.empty() ? "origin" : pre_pass) +
".dot";
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
pass_num++;
} }
graph_ = pass->Apply(std::move(graph_)); graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name; pre_pass = pass_name;
......
...@@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch, ...@@ -262,7 +262,7 @@ void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch,
if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) { if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) {
buffer.Resize(sizeof(T) * data.size()); buffer.Resize(sizeof(T) * data.size());
} }
std::memcpy(buffer.data(), data.data(), buffer.length()); std::memcpy(buffer.data(), data.data(), sizeof(T) * data.size());
// copy LoD // copy LoD
for (const auto &level : fetch.lod()) { for (const auto &level : fetch.lod()) {
output->lod.emplace_back(level); output->lod.emplace_back(level);
......
...@@ -117,34 +117,6 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data, ...@@ -117,34 +117,6 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
input_slots->assign({input_tensor}); input_slots->assign({input_tensor});
} }
void BenchAllData(const std::string &model_path, const std::string &data_file,
const int batch_size, const int repeat) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
config.device = 0;
config.specify_input_name = true;
std::vector<PaddleTensor> input_slots, outputs_slots;
DataRecord data(data_file, batch_size);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
GetOneBatch(&input_slots, &data, batch_size);
for (int i = 0; i < FLAGS_burning; i++) {
predictor->Run(input_slots, &outputs_slots);
}
Timer timer;
double sum = 0;
for (int i = 0; i < repeat; i++) {
for (size_t bid = 0; bid < data.batched_datas.size(); ++bid) {
GetOneBatch(&input_slots, &data, batch_size);
timer.tic();
predictor->Run(input_slots, &outputs_slots);
sum += timer.toc();
}
}
PrintTime(batch_size, repeat, 1, 0, sum / repeat);
}
const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, const int64_t lac_ref_data[] = {24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25,
25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43, 25, 25, 25, 25, 44, 24, 25, 25, 25, 36, 42, 43,
44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39, 44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39,
......
...@@ -144,8 +144,9 @@ void TestChineseNERPrediction(bool use_analysis) { ...@@ -144,8 +144,9 @@ void TestChineseNERPrediction(bool use_analysis) {
size_t num_samples; size_t num_samples;
for (int i = 0; i < FLAGS_repeat; i++) { for (int i = 0; i < FLAGS_repeat; i++) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size); DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
// Just one batch, the num_samples remains the same.
num_samples = data.num_samples; num_samples = data.num_samples;
for (size_t bid = 0; bid < num_samples; ++bid) { for (size_t bid = 0; bid < num_samples / FLAGS_batch_size; ++bid) {
PrepareInputs(&input_slots, &data, FLAGS_batch_size); PrepareInputs(&input_slots, &data, FLAGS_batch_size);
timer.tic(); timer.tic();
predictor->Run(input_slots, &outputs); predictor->Run(input_slots, &outputs);
......
...@@ -24,28 +24,28 @@ namespace operators { ...@@ -24,28 +24,28 @@ namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null."); "Assert only one Input(X) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null."); "Assert only one Input(C0) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"), PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null."); "Assert only one Input(LSTMWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"), PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null."); "Assert only one Input(LSTMBias) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"), PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null."); "Assert only one Input(AttentionWeight) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of AttentionLSTM should not be null."); "Assert only one Output(Hidden) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of AttentionLSTM should not be null."); "Assert only one Output(Cell) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"), PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null."); "Assert only one Output(AttentionedX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"), PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null."); "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"), PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null."); "Assert only one Output(LSTMX) of AttentionLSTM.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"), PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null."); "Assert only one Output(LSTMOUT) of AttentionLSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1]; const int M = x_dims[1];
......
...@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p, const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd, return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p", user_weights_memory_p, "@weights_mem_p",
pipeline); pipeline, is_persistent);
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
...@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -296,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
// TODO(pzelazko-intel) add support for group convolution and dilation // TODO(pzelazko-intel) add support for group convolution and dilation
...@@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -348,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd =
strides, paddings, mkldnn_engine); ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu);
} else { } else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine); paddings, mkldnn_engine, fuse_relu);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -371,7 +376,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline); user_weights_memory_p, pipeline, is_test);
auto dst_memory_p = auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
...@@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -402,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
private: private:
mkldnn::primitive_attr AddRelu() const {
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
mkldnn::primitive_attr conv_attr;
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
mkldnn::post_ops post_operations;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc> std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine) const { const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -415,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
auto p_conv_pd = mkldnn::primitive_attr conv_attr;
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>( return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd); p_conv_pd);
...@@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -427,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& bias, const memory::desc& dst, const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine) const { const mkldnn::engine& engine,
const bool fuse_relu) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]}; memory::dims padding_dims = {paddings[0], paddings[1]};
...@@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -436,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims, bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
auto p_conv_pd = mkldnn::primitive_attr conv_attr;
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine); if (fuse_relu) {
conv_attr = AddRelu();
}
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>( return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd); p_conv_pd);
......
...@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
void Conv2DOpMaker::Make() { void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", "").SetDefault(false);
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
...@@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() { ...@@ -161,6 +162,8 @@ void Conv2DOpMaker::Make() {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
......
...@@ -20,6 +20,7 @@ if(WITH_GRPC) ...@@ -20,6 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc)
return() return()
endif() endif()
......
...@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() { ...@@ -59,40 +59,32 @@ GRPCClient::~GRPCClient() {
} }
channels_.clear(); channels_.clear();
} }
client_thread_->join(); client_thread_->join();
} }
bool GRPCClient::AsyncSendVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, int64_t time_out) { const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
this] {
auto* var = p_scope->FindVar(var_name_val); auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req); SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
// varhandle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Send";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = nullptr; s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -102,13 +94,13 @@ bool GRPCClient::AsyncSendVar(const std::string& ep,
}); });
req_count_++; req_count_++;
return true; return h;
} }
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) { const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar); DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
} }
template <typename T> template <typename T>
...@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { ...@@ -119,37 +111,30 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result->Swap(&tmp); result->Swap(&tmp);
} }
bool GRPCClient::AsyncGetVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, int64_t time_out) { const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
this] {
// prepare input // prepare input
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(var_name_val); req.set_varname(var_name_val);
::grpc::ByteBuffer buf; ::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf); RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
// var handle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Get";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -160,42 +145,36 @@ bool GRPCClient::AsyncGetVar(const std::string& ep,
req_count_++; req_count_++;
return true; return h;
} }
bool GRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out) { int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name; const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name; const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] { time_out, s, this] {
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val); SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
// var handle VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = out_var_name_val;
var_h.ctx = p_ctx;
var_h.method = "Prefetch";
VLOG(3) << var_h.String() << " begin";
// stub context // stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
...@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -206,56 +185,68 @@ bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
}); });
req_count_++; req_count_++;
return true; return h;
} }
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
void GRPCClient::AsyncCheckpointNotify(const std::string& ep, VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir, const std::string& dir,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out); VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE); req.set_varname(CHECKPOINT_SAVE_MESSAGE);
...@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -264,6 +255,7 @@ void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
return h;
} }
bool GRPCClient::Wait() { bool GRPCClient::Wait() {
...@@ -276,25 +268,28 @@ void GRPCClient::Proceed() { ...@@ -276,25 +268,28 @@ void GRPCClient::Proceed() {
void* tag = nullptr; void* tag = nullptr;
bool ok = false; bool ok = false;
VLOG(3) << "GRPCClient Proceed begin";
while (!stopped_ && cq_.Next(&tag, &ok)) { while (!stopped_ && cq_.Next(&tag, &ok)) {
BaseProcessor* c = static_cast<BaseProcessor*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(c); PADDLE_ENFORCE(c);
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->var_h_.String() << " process"; VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process(); c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
LOG(ERROR) << c->var_h_.String() LOG(ERROR) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false; ok_ = false;
} }
sync_cond_.notify_all(); c->Finish(false);
} else { } else {
LOG(FATAL) << c->var_h_.String() LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error:" << c->status_.error_message(); << " meets grpc error:" << c->status_.error_message();
c->Finish(false);
} }
delete c; delete c;
{ {
std::lock_guard<std::mutex> lk(sync_mutex_); std::lock_guard<std::mutex> lk(sync_mutex_);
...@@ -302,6 +297,7 @@ void GRPCClient::Proceed() { ...@@ -302,6 +297,7 @@ void GRPCClient::Proceed() {
} }
sync_cond_.notify_all(); sync_cond_.notify_all();
} }
VLOG(3) << "GRPCClient Proceed end";
} }
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
......
...@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); ...@@ -53,15 +53,14 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
public: public:
explicit BaseProcessor(std::shared_ptr<grpc::Channel> ch) { BaseProcessor() { context_ = nullptr; }
context_ = nullptr;
}
virtual ~BaseProcessor() {} virtual ~BaseProcessor() {}
virtual void Prepare(const VarHandle& var_info, int64_t time_out) { virtual void Prepare(VarHandlePtr h, int64_t time_out) {
var_h_ = h;
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
var_h_ = var_info;
context_->set_wait_for_ready(true); context_->set_wait_for_ready(true);
if (time_out) { if (time_out) {
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::time_point deadline =
...@@ -71,21 +70,21 @@ class BaseProcessor { ...@@ -71,21 +70,21 @@ class BaseProcessor {
} }
} }
virtual void Prepare(int64_t time_out) { void Process() {
context_.reset(new grpc::ClientContext()); ProcessImpl();
context_->set_wait_for_ready(true); var_h_->Finish(true);
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
context_->set_deadline(deadline);
} }
virtual void Process() = 0; VarHandlePtr GetVarHandlePtr() { return var_h_; }
bool Wait() { return var_h_->Wait(); }
void Finish(bool ok) { return var_h_->Finish(ok); }
virtual void ProcessImpl() = 0;
std::unique_ptr<grpc::ClientContext> context_; std::unique_ptr<grpc::ClientContext> context_;
grpc::Status status_; grpc::Status status_;
VarHandle var_h_;
protected:
VarHandlePtr var_h_;
}; };
typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
...@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> ...@@ -94,13 +93,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class SendProcessor : public BaseProcessor { class SendProcessor : public BaseProcessor {
public: public:
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) explicit SendProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {} : BaseProcessor(), stub_g_(ch) {}
virtual ~SendProcessor() {} virtual ~SendProcessor() {}
virtual void Process() { void ProcessImpl() override {
if (response_call_back_) { if (response_call_back_) {
response_call_back_(var_h_, reply_); response_call_back_(*var_h_.get(), reply_);
} }
} }
...@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)> ...@@ -115,13 +114,13 @@ typedef std::function<void(const VarHandle&, const ::grpc::ByteBuffer&)>
class GetProcessor : public BaseProcessor { class GetProcessor : public BaseProcessor {
public: public:
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) explicit GetProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch), stub_g_(ch) {} : BaseProcessor(), stub_g_(ch) {}
virtual ~GetProcessor() {} virtual ~GetProcessor() {}
virtual void Process() { void ProcessImpl() override {
if (response_call_back_) { if (response_call_back_) {
response_call_back_(var_h_, reply_); response_call_back_(*var_h_.get(), reply_);
} }
} }
...@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor { ...@@ -133,13 +132,13 @@ class GetProcessor : public BaseProcessor {
class BatchBarrierProcessor : public BaseProcessor { class BatchBarrierProcessor : public BaseProcessor {
public: public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~BatchBarrierProcessor() {} virtual ~BatchBarrierProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor { ...@@ -147,13 +146,13 @@ class BatchBarrierProcessor : public BaseProcessor {
class FetchBarrierProcessor : public BaseProcessor { class FetchBarrierProcessor : public BaseProcessor {
public: public:
explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch) explicit FetchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~FetchBarrierProcessor() {} virtual ~FetchBarrierProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VariableMessage reply_; sendrecv::VariableMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -161,13 +160,13 @@ class FetchBarrierProcessor : public BaseProcessor {
class CheckpointNotifyProcessor : public BaseProcessor { class CheckpointNotifyProcessor : public BaseProcessor {
public: public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch) explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) { : BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch); stub_ = sendrecv::SendRecvService::NewStub(ch);
} }
virtual ~CheckpointNotifyProcessor() {} virtual ~CheckpointNotifyProcessor() {}
virtual void Process() {} void ProcessImpl() override {}
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
...@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient { ...@@ -177,32 +176,37 @@ class GRPCClient : public RPCClient {
GRPCClient() : ok_(true), completed_(false), stopped_(false) {} GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
virtual ~GRPCClient(); virtual ~GRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, VarHandlePtr AsyncSendVar(const std::string& ep,
const framework::Scope& scope, const std::string& var_name, const platform::DeviceContext& ctx,
int64_t time_out = FLAGS_rpc_deadline) override; const framework::Scope& scope,
const std::string& var_name,
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, int64_t time_out = FLAGS_rpc_deadline) override;
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
bool AsyncPrefetchVar(const std::string& ep, const framework::Scope& scope,
const platform::DeviceContext& ctx, const std::string& var_name,
const framework::Scope& scope, int64_t time_out = FLAGS_rpc_deadline) override;
const std::string& in_var_name,
const std::string& out_var_name, VarHandlePtr AsyncPrefetchVar(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override; const platform::DeviceContext& ctx,
const framework::Scope& scope,
void AsyncSendBatchBarrier(const std::string& ep, const std::string& in_var_name,
int64_t time_out = FLAGS_rpc_deadline) override; const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override; VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override; bool Wait() override;
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; ...@@ -49,23 +50,77 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class RPCServer; class RPCServer;
struct VarHandle { class VarHandle {
// RPC endpoint. public:
std::string ep; VarHandle(const std::string ep, const std::string& method,
const platform::DeviceContext* ctx; const std::string& name,
const framework::Scope* scope; const platform::DeviceContext* p_ctx = nullptr,
// Variable name. const framework::Scope* p_scope = nullptr)
std::string name; : ok_(kVarHandleDefaultState) {
// RPC method name. ep_ = ep;
std::string method; ctx_ = p_ctx;
scope_ = p_scope;
name_ = name;
method_ = method;
}
virtual ~VarHandle() {}
public:
bool Wait() {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
wait_cond_.wait(lk, [this] { return ok_ != kVarHandleDefaultState; });
}
VLOG(7) << "VarHandle wait:" << ok_;
return ok_ != 0;
}
void Finish(bool ok) {
{
std::unique_lock<std::mutex> lk(sync_mutex_);
ok_ = ok;
}
VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all();
}
std::string String() const { std::string String() const {
std::ostringstream s; std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]"; s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], ok:[" << ok_
<< "]";
return s.str(); return s.str();
} }
std::string ep() const { return ep_; }
const platform::DeviceContext* ctx() const { return ctx_; }
const framework::Scope* scope() const { return scope_; }
std::string name() const { return name_; }
std::string method() const { return method_; }
protected:
// RPC endpoint.
std::string ep_;
const platform::DeviceContext* ctx_;
const framework::Scope* scope_;
// Variable name.
std::string name_;
// RPC method name.
std::string method_;
protected:
std::mutex sync_mutex_;
std::condition_variable wait_cond_;
int ok_;
static const int kVarHandleDefaultState = -1;
private:
DISABLE_COPY_AND_ASSIGN(VarHandle);
}; };
typedef std::shared_ptr<VarHandle> VarHandlePtr;
class RequestHandler { class RequestHandler {
public: public:
explicit RequestHandler(bool sync_mode) explicit RequestHandler(bool sync_mode)
......
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
#pragma once #pragma once
#include <condition_variable> // NOLINT
#include <string> #include <string>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline); DECLARE_int32(rpc_deadline);
...@@ -31,37 +33,36 @@ class RPCClient { ...@@ -31,37 +33,36 @@ class RPCClient {
public: public:
RPCClient() {} RPCClient() {}
virtual ~RPCClient() {} virtual ~RPCClient() {}
virtual bool AsyncSendVar(const std::string& ep, virtual VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep, virtual VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep, virtual VarHandlePtr AsyncPrefetchVar(
const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope, const std::string& in_var_name,
const std::string& in_var_name, const std::string& out_var_name,
const std::string& out_var_name, int64_t time_out = FLAGS_rpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendBatchBarrier(
virtual void AsyncSendBatchBarrier(const std::string& ep, const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendFetchBarrier(
virtual void AsyncSendFetchBarrier(const std::string& ep, const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncCheckpointNotify(
virtual void AsyncCheckpointNotify(const std::string& ep, const std::string& ep, const std::string& dir,
const std::string& dir, int64_t time_out = FLAGS_rpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
virtual void AsyncSendComplete(const std::string& ep, const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
int64_t time_out = FLAGS_rpc_deadline) = 0;
// Complete tells all the pserver instances that finishe the training, // Complete tells all the pserver instances that finishe the training,
// the pserver can reduce it's barrier count, and continue to train // the pserver can reduce it's barrier count, and continue to train
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
using paddle::operators::distributed::VarHandlePtr;
using paddle::operators::distributed::VarHandle;
void WaitTrue(VarHandlePtr s) { EXPECT_TRUE(s->Wait()); }
void WaitFalse(VarHandlePtr s) { EXPECT_FALSE(s->Wait()); }
TEST(VarHandle, Run) {
std::vector<VarHandlePtr> a;
for (int i = 0; i < 12; i++) {
VarHandlePtr s(new VarHandle("", "", "", nullptr, nullptr));
a.push_back(s);
}
std::vector<std::unique_ptr<std::thread>> t;
for (int i = 0; i < 6; i++) {
t.emplace_back(new std::thread(WaitFalse, a[i]));
}
for (int i = 0; i < 6; i++) {
a[i]->Finish(false);
t[i]->join();
}
for (int i = 6; i < 12; i++) {
t.emplace_back(new std::thread(WaitTrue, a[i]));
}
for (int i = 6; i < 12; i++) {
a[i]->Finish(true);
t[i]->join();
}
}
...@@ -25,14 +25,14 @@ namespace paddle { ...@@ -25,14 +25,14 @@ namespace paddle {
namespace operators { namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"), PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null."); "Assert only one Input(WeightX) of GRU.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null."); "Assert only one Input(WeightH) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null."); "Assert only one Output(Hidden) of GRU.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
...@@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -80,11 +80,11 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of GRU should not be null."); "Assert only one Output(ReorderedH0) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of GRU should not be null."); "Assert only one Output(BatchedInput) of GRU.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"),
"Output(BatchedOut) of GRU should not be null."); "Assert only one Output(BatchedOut) of GRU.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims); ctx->SetOutputDim("BatchedOut", out_dims);
} }
......
...@@ -24,20 +24,17 @@ namespace paddle { ...@@ -24,20 +24,17 @@ namespace paddle {
namespace operators { namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"), PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of LSTM should not be null."); "Assert only one Input(WeightX) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of LSTM should not be null."); "Assert only one Input(WeightH) of LSTM.");
PADDLE_ENFORCE(ctx->HasInput("Bias"), PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM.");
"Input(Bias) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("XX"),
"Output(XX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null."); "Assert only one Output(Hidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null."); "Assert only one Output(Cell) of LSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
...@@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -96,15 +93,15 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Output(BatchedInput) of LSTM should not be null."); "Assert only one Output(BatchedInput) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of LSTM should not be null."); "Assert only one Output(BatchedHidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
"Output(BatchedCell) of LSTM should not be null."); "Assert only one Output(BatchedCell) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
"Output(ReorderedH0) of LSTM should not be null."); "Assert only one Output(ReorderedH0) of LSTM");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
"Output(ReorderedC0) of LSTM should not be null."); "Assert only one Output(ReorderedC0) of LSTM.");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims); ctx->SetOutputDim("BatchedCell", out_dims);
......
...@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -44,16 +44,20 @@ class PrefetchOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back"; << outs[i] << " back";
rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); rets.push_back(rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope,
ins[i], outs[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
}; };
......
...@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase { ...@@ -44,12 +44,15 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
} }
if (sync_mode) { if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <future> // NOLINT #include <future> // NOLINT
#include <ostream> #include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase { ...@@ -45,18 +46,19 @@ class SendOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
if (sync_send) { if (sync_send) {
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
......
...@@ -31,7 +31,8 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, ...@@ -31,7 +31,8 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
int idx = i * class_num + labels[i]; int idx = i * class_num + labels[i];
logit_grad[idx] -= static_cast<T>(1.); logit_grad[idx] -=
ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.);
} }
} }
......
...@@ -192,7 +192,8 @@ class MKLDNNHandler { ...@@ -192,7 +192,8 @@ class MKLDNNHandler {
mkldnn::memory::primitive_desc& user_mpd, // NOLINT mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix, const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false) {
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p"; auto key_reorder_p = key_ + suffix + "reorder_p";
...@@ -213,7 +214,7 @@ class MKLDNNHandler { ...@@ -213,7 +214,7 @@ class MKLDNNHandler {
pipeline.push_back(*reorder_p); pipeline.push_back(*reorder_p);
} }
dev_ctx_.SetBlob(local_key, target_memory_p); dev_ctx_.SetBlob(local_key, target_memory_p);
} else { } else if (!is_persistent) {
// Make reorder if needed // Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p)); dev_ctx_.GetBlob(key_reorder_p));
......
...@@ -128,6 +128,13 @@ class ParallelExecutor(object): ...@@ -128,6 +128,13 @@ class ParallelExecutor(object):
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exec_strategy.num_threads = cpu_num * 2 exec_strategy.num_threads = cpu_num * 2
# Set 1 thread num under nccl2 distribute
# env to make sure all gpus run ops in same order.
if num_trainers > 1:
assert (use_cuda)
# FIXME(gongwb): avoid this set.
exec_strategy.num_threads = 1
if build_strategy is None: if build_strategy is None:
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
......
...@@ -178,7 +178,4 @@ if __name__ == '__main__': ...@@ -178,7 +178,4 @@ if __name__ == '__main__':
for parallel in (False, True): for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
continue continue
# TODO(minqiyang): remove this line after fixing the deletion main(use_cuda=use_cuda, parallel=parallel)
# order problem of Scope in ParallelExecutor in manylinux
if six.PY2:
main(use_cuda=use_cuda, parallel=parallel)
...@@ -152,7 +152,4 @@ if __name__ == '__main__': ...@@ -152,7 +152,4 @@ if __name__ == '__main__':
for parallel in (False, True): for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
continue continue
# TODO(minqiyang): remove this line after fixing the deletion main(use_cuda=use_cuda, parallel=parallel)
# order problem of Scope in ParallelExecutor in manylinux
if six.PY2:
main(use_cuda=use_cuda, parallel=parallel)
...@@ -155,7 +155,4 @@ if __name__ == '__main__': ...@@ -155,7 +155,4 @@ if __name__ == '__main__':
for parallel in (False, True): for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
continue continue
# TODO(minqiyang): remove this line after fixing the deletion main(use_cuda=use_cuda, parallel=parallel)
# order problem of Scope in ParallelExecutor in manylinux
if six.PY2:
main(use_cuda=use_cuda, parallel=parallel)
...@@ -137,7 +137,4 @@ if __name__ == '__main__': ...@@ -137,7 +137,4 @@ if __name__ == '__main__':
for parallel in (False, True): for parallel in (False, True):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
continue continue
# TODO(minqiyang): remove this line after fixing the deletion main(use_cuda=use_cuda, parallel=parallel)
# order problem of Scope in ParallelExecutor in manylinux
if six.PY2:
main(use_cuda=use_cuda, parallel=parallel)
...@@ -84,7 +84,7 @@ class TestDataBalance(unittest.TestCase): ...@@ -84,7 +84,7 @@ class TestDataBalance(unittest.TestCase):
self.data_file_name = './data_balance_test.recordio' self.data_file_name = './data_balance_test.recordio'
self.lod_data_file_name = './data_balance_with_lod_test.recordio' self.lod_data_file_name = './data_balance_with_lod_test.recordio'
self.total_ins_num = 50 self.total_ins_num = 50
self.batch_size = 10 self.batch_size = 12
self.prepare_data() self.prepare_data()
self.prepare_lod_data() self.prepare_lod_data()
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from parallel_executor_test_base import TestParallelExecutorBase from parallel_executor_test_base import TestParallelExecutorBase
import unittest import unittest
import paddle import paddle
import paddle.fluid.core as core
import paddle.dataset.wmt16 as wmt16 import paddle.dataset.wmt16 as wmt16
import os import os
...@@ -170,7 +171,8 @@ class TestTransformer(TestParallelExecutorBase): ...@@ -170,7 +171,8 @@ class TestTransformer(TestParallelExecutorBase):
writer.complete_append_tensor() writer.complete_append_tensor()
def test_main(self): def test_main(self):
self.check_network_convergence(transformer, use_cuda=True) if core.is_compiled_with_cuda():
self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5) self.check_network_convergence(transformer, use_cuda=False, iter=5)
......
...@@ -96,7 +96,8 @@ class TestPyReaderUsingExecutor(unittest.TestCase): ...@@ -96,7 +96,8 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
self.queue_capacity = 50 self.queue_capacity = 50
def test(self): def test(self):
for use_cuda in [False, True]: for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
for use_parallel_executor in [False, True]: for use_parallel_executor in [False, True]:
for use_double_buffer in [False, True]: for use_double_buffer in [False, True]:
print('Test Parameters:'), print('Test Parameters:'),
......
...@@ -60,12 +60,46 @@ class InferenceTranspiler(object): ...@@ -60,12 +60,46 @@ class InferenceTranspiler(object):
if not isinstance(scope, core.Scope): if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None") raise TypeError("scope should be as Scope type or None")
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False)) use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
self._fuse_batch_norm(program, place, scope) self._fuse_batch_norm(program, place, scope)
if use_mkldnn: if use_mkldnn:
self._fuse_relu_mkldnn(program)
self._fuse_conv_bias_mkldnn(program) self._fuse_conv_bias_mkldnn(program)
self._fuse_conv_relu_mkldnn(program)
self._fuse_bn_relu_mkldnn(program)
def _fuse_conv_relu_mkldnn(self, program):
'''
Transpile the program by fused relu activation for MKLDNN program.
Relu activation following convolution OP can be fused by adding
'fuse_relu' attribute to convolution OP.
The result of fuse is:
- before:
- conv->relu->any_other_op
- after:
- conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops):
current_op = self.block.ops[i]
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'relu':
# modify conv OP to include relu
current_op.set_attr("fuse_relu", True)
# remove conv OP
self.block._remove_op(i + 1)
i = i + 1
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_relu_mkldnn(self, program): def _fuse_bn_relu_mkldnn(self, program):
''' '''
Transpile the program by fused relu activation for MKLDNN program. Transpile the program by fused relu activation for MKLDNN program.
...@@ -159,7 +193,6 @@ class InferenceTranspiler(object): ...@@ -159,7 +193,6 @@ class InferenceTranspiler(object):
self._fuse_conv_bias(i, current_op, next_op) self._fuse_conv_bias(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
i = i + 1 i = i + 1
self._remove_unused_var() self._remove_unused_var()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册