提交 238b24bf 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize-adam

......@@ -19,6 +19,15 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
欢迎来到 PaddlePaddle GitHub
PaddlePaddle (PArallel Distributed Deep LEarning) 是一个简单易用、高效灵活、可扩展的深度学习平台,最初由百度科学家和工程师共同开发,目的是将深度学习技术应用到百度的众多产品中。
我们的愿景是让每个人都能通过PaddlePaddle接触深度学习
跟进PaddlePaddle最新特性请参考我们的[版本说明](https://github.com/PaddlePaddle/Paddle/releases)
### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### Install Latest Stable Release:
```
......@@ -34,6 +43,23 @@ pip install paddlepaddle-gpu==1.2.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
### PaddlePaddle最新版本: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### 安装最新稳定版本:
```
# Linux CPU
pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==1.2.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==1.2.0.post85
# 其他平台上的安装指引请参考 http://paddlepaddle.org/
```
## Features
- **Flexibility**
......@@ -74,10 +100,38 @@ pip install paddlepaddle-gpu==1.2.0.post85
Baidu and it has achieved a significant impact. We hope you can also explore
the capability of PaddlePaddle to make an impact on your product.
## 特点
- **灵活性**
PaddlePaddle支持丰富的神经网络架构和优化算法。易于配置复杂模型,例如带有注意力机制或复杂记忆连接的神经网络机器翻译模型。
- **高效性**
为了高效使用异步计算资源,PaddlePaddle对框架的不同层进行优化,包括计算、存储、架构和通信。下面是一些样例:
- 通过SSE/AVX 内置函数、BLAS库(例如MKL、OpenBLAS、cuBLAS)或定制的CPU/GPU内核优化数学操作。
- 通过MKL-DNN库优化CNN网络
- 高度优化循环网络,无需执行 `padding` 操作即可处理 **变长** 序列
- 针对高维稀疏数据模型,优化了局部和分布式训练。
- **稳定性**
有了 PaddlePaddle,使得利用各种CPU/GPU和机器来加速训练变得简单。PaddlePaddle 通过优化通信可以实现巨大吞吐量和快速执行。
- **连接产品**
另外,PaddlePaddle 的设计也易于部署。在百度,PaddlePaddle 已经部署到含有巨大用户量的产品和服务上,包括广告点击率(CTR)预测、大规模图像分类、光学字符识别(OCR)、搜索排序,计算机病毒检测、推荐系统等等。PaddlePaddle广泛应用于百度产品中,产生了非常重要的影响。我们希望您也能探索 PaddlePaddle 的能力,为您的产品创造新的影响力和效果。
## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
## 安装
推荐阅读官网上的[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html)
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
......@@ -99,10 +153,37 @@ We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarte
We appreciate your contributions!
## 文档
我们提供[英文](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)
[中文](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) 文档
- [深度学习101](https://github.com/PaddlePaddle/book)
或许您想从这个在线交互式书籍开始,可以在Jupyter Notebook中运行
- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
可以在MPI集群上运行分布式训练任务
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
新的API支持代码更少更简洁的程序
- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
欢迎您的贡献!
## Ask Questions
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
## 答疑
欢迎您将问题和bug报告以[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)的形式提交
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
## 版权和许可证
PaddlePaddle由[Apache-2.0 license](LICENSE)提供
......@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
config = distribute_transpiler.DistributeTranspilerConfig()
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = not args.no_split_var
config.min_block_size = 1048576
t = distribute_transpiler.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
# NOTE: *MUST* use train_prog, for we are using with guard to
......
......@@ -14,14 +14,16 @@
INCLUDE(ExternalProject)
find_library(SSL_LIBRARY NAMES ssl)
find_package(OpenSSL REQUIRED)
message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
find_library(CRYPTO_LIBRARY NAMES crypto)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
......@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to de newst repo when they changed.
GIT_REPOSITORY "https://github.com/gongweibao/brpc"
GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......@@ -50,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DBRPC_WITH_GLOG=ON
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
${EXTERNAL_OPTIONAL_ARGS}
......@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
add_definitions(-DBRPC_WITH_GLOG)
LIST(APPEND external_project_dependencies brpc)
......@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
IF(WITH_TESTING)
ENABLE_TESTING()
#FIXME:(gongwb) Move brpc's gtest dependency.
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WITH_TESTING)
ENABLE_TESTING()
ENDIF(WITH_TESTING)
INCLUDE(ExternalProject)
SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
......@@ -76,4 +80,4 @@ IF(WITH_TESTING)
ADD_DEPENDENCIES(gtest_main extern_gtest)
LIST(APPEND external_project_dependencies gtest gtest_main)
ENDIF(WITH_TESTING)
ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
......@@ -24,8 +24,8 @@ ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
GIT_REPOSITORY "https://github.com/google/leveldb"
GIT_TAG v1.18
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
......
......@@ -77,6 +77,8 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name']
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
......
......@@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
if(WITH_NGRAPH)
if(NOT WIN32)
......
......@@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE)
if(NOT WITH_GRPC)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
else()
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor)
......@@ -30,7 +37,7 @@ else()
variable_visitor)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_grpc)
ddim selected_rows_functor sendrecvop_rpc)
else()
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor)
......
......@@ -157,9 +157,9 @@ void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>(0)
->SendComplete();
auto client =
paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
client->SendComplete();
#endif
}
......
......@@ -42,6 +42,8 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(is_test_pass base)
pass_library(conv_elementwise_add_act_fuse_pass inference)
pass_library(conv_elementwise_add2_act_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(depthwise_conv_mkldnn_pass base)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(elementwise_add_op_1); \
GET_IR_NODE(elementwise_add_in_y_1); \
GET_IR_NODE(elementwise_add_out_1); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherient the basic infomation from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& bias1, const std::string& activation,
const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {bias1});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string bias1_name = elementwise_add_in_y_1->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
auto new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, new_conv_op); // ResidualData
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(elementwise_add_op_1); \
GET_IR_NODE(elementwise_add_in_y_1); \
GET_IR_NODE(elementwise_add_out_1); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherient the basic infomation from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& bias1, const std::string& activation,
const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {bias1});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string bias1_name = elementwise_add_in_y_1->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, conv_op); // Input
IR_NODE_LINK_TO(conv_filter, conv_op); // Filter
IR_NODE_LINK_TO(conv_op, conv_out); // Output
IR_NODE_LINK_TO(elementwise_add_in_y, conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, conv_op); // Bias
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAdd2ActFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAdd2ActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherient the basic infomation from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& activation, const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetType("conv2d_fusion");
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
desc.Flush();
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("conv2d", "Input")
->AsInput();
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto =
PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAddActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddActFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
......@@ -25,6 +26,7 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace framework {
namespace ir {
......@@ -104,7 +106,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
for (auto &node : GraphTraits::DFS(graph)) {
for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) {
VLOG(4) << "pdnode " << pdnode->name() << " marked";
VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
pdnodes2nodes_[pdnode.get()].insert(&node);
}
}
......@@ -1099,6 +1101,115 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
return out_var;
}
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput();
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
auto act_op = pattern->NewNode(act_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
auto op_type = node->Name();
return conv_act_set.count(op_type);
});
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_var()
// is activation op's output.
->assert_more([&](Node *node) {
for (auto *in_op : node->inputs) {
if (conv_act_set.count(in_op->Name())) {
return true;
}
}
return false;
})
->AsOutput();
conv_op->LinksFrom({conv_in, conv_filter});
conv_out->LinksFrom({conv_op});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
act_op->LinksFrom({elementwise_add_out}).LinksTo({act_out});
return act_out;
}
PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise_add_op_1 = pattern->NewNode(elementwise_add_op_1_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y_1 = pattern->NewNode(elementwise_add_in_y_1_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out_1 = pattern->NewNode(elementwise_add_out_1_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
auto act_op = pattern->NewNode(act_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
auto op_type = node->Name();
return conv_act_set.count(op_type);
});
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_var()
// is activation op's output.
->assert_more([&](Node *node) {
for (auto *in_op : node->inputs) {
if (conv_act_set.count(in_op->Name())) {
return true;
}
}
return false;
})
->AsOutput();
conv_op->LinksFrom({conv_in, conv_filter}).LinksTo({conv_out});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
elementwise_add_op_1->LinksFrom(
{elementwise_add_out, elementwise_add_in_y_1});
act_op->LinksFrom({elementwise_add_out_1}).LinksTo({act_out});
return act_out;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -671,6 +671,51 @@ struct ElementwiseAdd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
};
// Conv + ElementwiseAdd + an activation
// This pattern can futher fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_in_y); // input
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
};
// Conv + ElementwiseAdd + ElementwiseAdd + Activation
struct ConvElementwiseadd2Act : public PatternBase {
ConvElementwiseadd2Act(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope,
"conv_elementwiseadd2_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_in_y); // input
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(elementwise_add_op_1);
PATTERN_DECL_NODE(elementwise_add_in_y_1); // input
PATTERN_DECL_NODE(elementwise_add_out_1);
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
......@@ -55,7 +55,12 @@ TEST(AnalysisPredictor, analysis_off) {
}
TEST(AnalysisPredictor, analysis_on) {
AnalysisConfig config(false);
#ifdef PADDLE_WITH_CUDA
AnalysisConfig config(true);
config.fraction_of_gpu_memory = 0.15;
#else
AnalysisConfig config;
#endif
config.model_dir = FLAGS_dirname;
config.enable_ir_optim = true;
......
......@@ -118,7 +118,10 @@ class GpuPassStrategy : public PassStrategy {
public:
GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"infer_clean_graph_pass", "conv_bn_fuse_pass",
"infer_clean_graph_pass", //
"conv_bn_fuse_pass", //
"conv_elementwise_add_act_fuse_pass", //
"conv_elementwise_add2_act_fuse_pass", //
});
}
......
......@@ -79,7 +79,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
for (auto* var : global_block.AllVars()) {
if (IsPersistable(var)) {
VLOG(3) << "persistable variable's name: " << var->Name();
VLOG(4) << "persistable variable's name: " << var->Name();
framework::VarDesc* new_var = load_block->Var(var->Name());
new_var->SetShape(var->GetShape());
......
......@@ -78,6 +78,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
std::vector<PaddleTensor> outputs;
if (use_analysis || use_tensorrt) {
contrib::AnalysisConfig config(true);
config.pass_builder()->TurnOnDebug();
SetConfig<contrib::AnalysisConfig>(&config, model_dir, true, use_tensorrt,
FLAGS_batch_size);
TestPrediction(reinterpret_cast<PaddlePredictor::Config*>(&config),
......@@ -141,9 +142,31 @@ TEST(TensorRT_resnext50, profile) {
profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
}
TEST(resnext50, compare_analysis_native) {
std::string model_dir = FLAGS_infer_model + "/resnext50";
compare(model_dir, false /*use tensorrt*/);
}
TEST(TensorRT_mobilenet, analysis) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
compare(model_dir, /* use_tensorrt */ false);
compare(model_dir, false /* use_tensorrt */);
}
TEST(AnalysisPredictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
AnalysisConfig config(true);
config.model_dir = model_dir;
config.fraction_of_gpu_memory = 0.15;
config.pass_builder()->TurnOnDebug();
std::vector<std::vector<PaddleTensor>> inputs_all;
auto predictor = CreatePaddlePredictor(config);
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
std::vector<PaddleTensor> outputs;
for (auto& input : inputs_all) {
ASSERT_TRUE(predictor->Run(input, &outputs));
}
}
} // namespace inference
......
include(operators)
register_operators()
register_operators(DEPS naive_executor)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
......@@ -44,7 +44,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"Conv intput should be 4-D or 5-D tensor.");
"Conv intput should be 4-D or 5-D tensor, get %u",
in_dims.size());
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"Conv input dimension and filter dimension should be the same.");
......
......@@ -300,9 +300,11 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
}
CudnnRNNCache *cudnn_rnn_cache = nullptr;
if (cache_var->IsInitialized()) {
// const_cast is usually bad.
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
} else {
// const_cast is usually bad.
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>();
std::random_device rnd;
......
......@@ -12,7 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory)
......@@ -20,36 +20,43 @@ if(WITH_GRPC)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
selected_rows_functor scope math_function SERIAL)
endif()
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
else()
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
proto_desc lookup_sparse_table_op snappystream snappy zlib)
cc_test(brpc_server_test SRCS rpc_server_test.cc
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${brpc_test_depends} SERIAL)
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
DEPS ${brpc_test_depends} SERIAL)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
endif()
endif()
......@@ -14,135 +14,316 @@
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server");
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
BRPCClient::~BRPCClient() { Wait(); }
void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response) {
void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
// this channel can be used by other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
LOG(FATAL) << "Fail to send SendVar: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleSendResponse";
}
bool BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "SendRPC";
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
framework::AsyncIO(
[var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(var_name_val);
sendrecv::VariableMessage request;
distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
&cntl->request_attachment(), "", false,
trainer_id_);
google::protobuf::Closure* done =
brpc::NewCallback(&HandleSendResponse, cntl, response);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
sendrecv::VariableMessage request;
ch_ctx->stub->SendVariable(cntl, &request, response, done);
});
platform::RecordRPCEvent record_event(method, p_ctx);
ch_ctx->stub->SendVariable(cntl, &request, response, done);
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return true;
return var_h;
}
void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(FATAL) << "Fail to get HandleFetchBarrierResponse: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleFetchBarrierResponse";
}
void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response) {
sendrecv::VariableMessage* response, VarHandlePtr var_h,
ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
LOG(FATAL) << "Fail to GetVar: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
cls->DecreaseReqCount();
var_h->Finish(false);
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
// framework::Variable* outvar = nullptr;
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
framework::Variable* outvar = nullptr;
int trainer_id;
distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
*var_h->ctx(), var_h->scope(), &outvar,
&trainer_id);
VLOG(4) << "Finish HandleGetResponse";
cls->DecreaseReqCount();
var_h->Finish(true);
}
bool BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& method_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "GetRPC";
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
framework::AsyncIO(
[var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_trainer_id(trainer_id_);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
platform::RecordRPCEvent record_event(method, p_ctx);
if (method_name == "GetMonomerVariable") {
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
} else {
ch_ctx->stub->GetVariable(cntl, &req, response, done);
}
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return true;
return var_h;
}
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const std::string& var_name,
int64_t time_out) {
return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
}
bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
}
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "PrefetchRPC";
VarHandlePtr var_h(
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(in_var_name_val);
sendrecv::VariableMessage req;
distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
&cntl->request_attachment(), out_var_name_val,
false, 0, table_name_val);
platform::RecordRPCEvent record_event(method, p_ctx);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {});
ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return true;
return var_h;
}
void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
req_count_++;
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
time_out);
}
void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
const std::string method = "FetchBarrierRPC";
// var handle
VarHandlePtr var_h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
platform::RecordRPCEvent record_event(method, nullptr);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
ch_ctx->stub->GetVariable(cntl, &req, response, done);
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
void BRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
bool BRPCClient::Wait() {
VLOG(9) << "begin to brpcclient wait";
{
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
VLOG(9) << "end to brpcclient wait";
return true;
}
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
VLOG(4) << "begin to GetChannel:" << ep;
{
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
VLOG(4) << "end to GetChannel:" << ep;
return it->second;
}
}
......@@ -150,12 +331,20 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
brpc::ChannelOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.protocol = "baidu_std";
options.connection_type = "pooled";
options.connect_timeout_ms = 100;
// don't use pooled type. the server can't afford that.
options.connection_type = "single";
options.connect_timeout_ms = 1000;
options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
options.max_retry = FLAGS_max_retry;
for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
VLOG(1) << "create " << brpc_channel_num_per_server_
<< " brpc channels to pserver:" << ep;
for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
std::shared_ptr<ChannelContext> c(new ChannelContext());
if (c->channel.Init(ep.c_str(), &options) != 0) {
LOG(FATAL) << "Fail to initialize channel";
......@@ -172,9 +361,75 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
channels_[ep] = q;
}
VLOG(4) << "end to GetChannel:" << ep;
return q;
}
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
}
void BRPCClient::SendComplete() {
for (auto& kv : channels_) {
AsyncSendComplete(kv.first);
}
}
VarHandlePtr BRPCClient::AsyncSendVarMessage(
const std::string& ep, const std::string& method_name,
const sendrecv::VariableMessage& req, int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
platform::RecordRPCEvent record_event(method_name, nullptr);
VarHandlePtr var_h(
new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
if (method_name == "CheckPointNotifyRPC") {
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
} else if (method_name == "GetMonomerBarrier") {
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->SendVariable(cntl, &req, response, done);
}
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(message);
return AsyncSendVarMessage(ep, method_name, req, time_out);
}
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
......@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
BRPCClient() {}
virtual ~BRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerBarrier(
const std::string& ep, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendComplete() override;
private:
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& method_name,
int64_t time_out = FLAGS_rpc_deadline);
void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep);
VarHandlePtr AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message, int64_t time_out);
VarHandlePtr AsyncSendVarMessage(const std::string& ep,
const std::string& method_name,
const sendrecv::VariableMessage& req,
int64_t time_out);
friend void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h,
ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx,
BRPCClient* cls);
void DecreaseReqCount() {
if (--req_count_ <= 0) {
sync_cond_.notify_all();
}
}
private:
std::unordered_map<std::string, ChannelQueuePtr> channels_;
......@@ -88,6 +151,8 @@ class BRPCClient : public RPCClient {
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
static constexpr int brpc_channel_num_per_server_ = 4;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(BRPCClient);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
RdmaMemPool& RdmaMemPool::Instance() {
static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
return *g_rdma_mem_pool;
}
void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
pthread_rwlock_rdlock(&access_);
auto it = pool_.find(varname);
if (it == pool_.end()) {
pthread_rwlock_unlock(&access_);
return nullptr;
}
auto info = it->second;
if (info.data_size != size) {
pthread_rwlock_unlock(&access_);
PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size,
info.data_size);
return nullptr;
}
pthread_rwlock_unlock(&access_);
return info.data;
}
void RdmaMemPool::Register(const std::string& varname, void* data,
int64_t data_size) {
void* old = Find(varname, data_size);
if (old != nullptr) {
if (data != old) {
PADDLE_ENFORCE(false, "var:%s data:%ld != %ld", varname, data, old);
}
VLOG(7) << "Find on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
return;
}
VarInfo info;
info.data = data;
info.data_size = data_size;
pthread_rwlock_wrlock(&access_);
pool_[varname] = info;
pthread_rwlock_unlock(&access_);
if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
LOG(FATAL) << "register " << varname << " data:" << data
<< " data_size:" << data_size << " error";
}
VLOG(4) << "register on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace distributed {
/*
* This class is used to avoid duplicated registion of brpc::rdma.
*/
class RdmaMemPool {
public:
static RdmaMemPool& Instance();
RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}
virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }
void Register(const std::string& varname, void* data, int64_t size);
void* Find(const std::string& varname, int64_t size);
private:
struct VarInfo {
void* data;
int64_t data_size;
VarInfo() : data(nullptr), data_size(0) {}
};
private:
std::unordered_map<std::string, VarInfo> pool_;
pthread_rwlock_t access_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
class IOBufWriter {
public:
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen);
}
static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
int64_t vlen, bool in_cuda_pinned,
void (*destroy)(void*), void* user_data) {
VLOG(7) << "AppendTCPZeroCopy "
<< " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
// FIXME(gongwb): use append_zerocopy
/*
if (in_cuda_pinned) {
iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
} else {
iobuf->append_zerocopy(v, vlen, nullptr);
}
*/
iobuf->append(v, vlen);
destroy(user_data);
}
#ifdef PADDLE_WITH_BRPC_RDMA
static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
RdmaMemPool::Instance().Register(
varname, static_cast<void*>(const_cast<char*>(v)), vlen);
// FIXME(gongwb): use append_zerocopy
// iobuf->append_zerocopy(v, vlen, nullptr);
iobuf->append(v, vlen);
destroy(user_data);
return;
}
#endif
static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data);
#else
IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
user_data);
#endif
}
};
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, int trainer_id,
const std::string& table_name) {
std::unique_ptr<TensorPayload> payload;
request->set_varname(name);
request->set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request->set_profile(platform::kEnableProfiler);
} else {
request->set_profile(platform::kDisableProfiler);
}
}
if (!out_varname.empty()) {
request->set_out_varname(out_varname);
}
if (!table_name.empty()) {
request->set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request->set_type(::sendrecv::LOD_TENSOR);
payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
} else if (var->IsType<framework::SelectedRows>()) {
request->set_type(::sendrecv::SELECTED_ROWS);
payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES);
return;
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
PADDLE_ENFORCE_NOT_NULL(payload);
// FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) {
IOBufWriter::Append(
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size());
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
#endif
} else {
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
}
}
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size));
}
}
void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
operators::distributed::BRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!");
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 564 * 128;
// serialize var to IOBuf
{
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// desrialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
}
void RunTestLodTensor(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 512 * 8 * 4 * 2;
{
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// check sendrecv::VariableMessage meta data
{
EXPECT_EQ(msg.varname(), "myvar");
EXPECT_EQ(msg.type(), 0);
EXPECT_EQ(msg.dims()[0], 512);
EXPECT_EQ(msg.dims()[1], 8);
EXPECT_EQ(msg.dims()[2], 4);
EXPECT_EQ(msg.dims()[3], 2);
EXPECT_EQ(msg.lod_level(), 1);
EXPECT_EQ(msg.lod(0).lod_data(0), 1);
EXPECT_EQ(msg.lod(0).lod_data(1), 3);
EXPECT_EQ(msg.lod(0).lod_data(2), 8);
}
// deserialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
for (int i = 0; i < tensor_numel; ++i)
EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
}
TEST(LodTensor, Run) {
platform::CPUPlace place;
RunTestLodTensor(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
#endif
}
TEST(SelectedRows, Run) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
......@@ -13,84 +13,287 @@
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv {
typedef std::unordered_map<std::string,
paddle::operators::distributed::RequestHandler*>
namespace distributed = paddle::operators::distributed;
typedef std::unordered_map<std::string, distributed::RequestHandler*>
HandlerMap;
class BRPCServiceImpl : public SendRecvService {
public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map)
: request_send_h_(nullptr),
request_get_h_(nullptr),
request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
distributed::RPCServer* rpc_server)
: rpc_server_(rpc_server) {
VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
auto it = rpc_call_map.find(distributed::kRequestSend);
if (it != rpc_call_map.end()) {
request_send_h_ = it->second;
send_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestSend)));
}
it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
it = rpc_call_map.find(distributed::kRequestGet);
if (it != rpc_call_map.end()) {
request_get_h_ = it->second;
get_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestGet)));
}
it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
it = rpc_call_map.find(distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
prefetch_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestCheckpoint);
if (it != rpc_call_map.end()) {
request_checkpoint_h_ = it->second;
checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
if (it != rpc_call_map.end()) {
request_get_monomer_handler_h_ = it->second;
}
it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
if (it != rpc_call_map.end()) {
request_get_monomer_barrier_handler_h_ = it->second;
}
}
virtual ~BRPCServiceImpl() {}
void SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
send_threads_->Run(
[=] { _SendVariable(cntl_butil, request, response, done); });
}
void _SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_send_h_ != nullptr,
"RequestSend handler should be registed first!");
brpc::ClosureGuard done_guard(done);
paddle::framework::Scope* local_scope = request_send_h_->scope();
paddle::framework::Variable* outvar = nullptr;
paddle::framework::Variable* invar = nullptr;
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestSend var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
if (!request_send_h_->sync_mode()) {
local_scope = &request_send_h_->scope()->NewScope();
invar = local_scope->Var(varname);
} else {
invar = local_scope->FindVar(varname);
}
distributed::BRPCVariableResponse resp(request_send_h_->scope(),
request_send_h_->dev_ctx(),
!request_send_h_->sync_mode());
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
request_send_h_->Handle(varname, local_scope, invar, &outvar);
auto scope = resp.GetMutableLocalScope();
auto invar = resp.GetVar();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
if (!request_send_h_->sync_mode()) {
request_send_h_->scope()->DeleteScope(local_scope);
}
request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
}
void GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) override {
get_threads_->Run(
[=] { _GetVariable(cntl_butil, request, response, done); });
}
void _GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_get_h_ != nullptr,
"RequestGet handler should be registed first!");
}
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGet varname:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
auto scope = request_get_h_->scope();
auto invar = scope->FindVar(varname);
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(),
response, &cntl->response_attachment(), "",
false);
}
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
prefetch_threads_->Run(
[=] { _PrefetchVariable(cntl_butil, request, response, done); });
}
void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
"kRequestPrefetch handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// prefetch process...
std::string in_var_name = request->varname();
std::string out_var_name = request->out_varname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< ", out_var_name: " << out_var_name
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
distributed::BRPCVariableResponse resp(
request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
auto scope = resp.GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
std::string table_name = request->table_name();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = scope->Var(out_var_name);
request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
distributed::SerializeToIOBuf(out_var_name, outvar,
*request_prefetch_h_->dev_ctx(), response,
&cntl->response_attachment(), "", true);
}
void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
checkpoint_notify_threads_->Run(
[=] { _CheckpointNotify(cntl_butil, request, response, done); });
}
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(
request_checkpoint_h_ != nullptr,
"kRequestCheckpointNotify handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
request_checkpoint_h_->dev_ctx());
auto scope = resp.GetMutableLocalScope();
std::string checkpoint_notify = request->varname();
std::string checkpoint_dir = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
trainer_id, checkpoint_dir);
}
void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_handler_h_ != nullptr,
"kRequestGetMonomerVariable handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// proc request.
std::string varname = request->varname();
VLOG(3) << "GetMonomerVariable " << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
auto scope = h.scope_;
auto invar = scope->FindVar(varname);
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
request->trainer_id());
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
&cntl->response_attachment(), "", false);
}
}
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_barrier_handler_h_ != nullptr,
"RequestGetMonomerBarrier handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
paddle::framework::Scope* scope = nullptr;
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_barrier_handler_h_->Handle(
varname, scope, invar, &outvar, request->trainer_id());
}
private:
paddle::operators::distributed::RequestHandler* request_send_h_;
paddle::operators::distributed::RequestHandler* request_get_h_;
paddle::operators::distributed::RequestHandler* request_prefetch_h_;
distributed::RequestHandler* request_send_h_{nullptr};
distributed::RequestHandler* request_get_h_{nullptr};
distributed::RequestHandler* request_prefetch_h_{nullptr};
distributed::RequestHandler* request_checkpoint_h_{nullptr};
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
distributed::RPCServer* rpc_server_{nullptr};
// FIXME(gongwb): brpc should support process one rpce use one threadpool.
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
};
} // namespace sendrecv
......@@ -100,7 +303,7 @@ namespace distributed {
void AsyncBRPCServer::StartServer() {
// Instance of your service.
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_);
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
......@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() {
}
brpc::ServerOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.idle_timeout_sec = idle_timeout_s_;
options.max_concurrency = max_concurrency_;
if (server_.Start(bind_address_.c_str(), &options) != 0) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;
int BRPCVariableResponse::Parse(Source* source) {
pb::io::ZeroCopyInputStream* input_stream = source->contents();
pb::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (1) {
unsigned int tag = 0;
if (!input.ReadLittleEndian32(&tag)) {
break;
}
uint64_t num_bytes = 0;
if (!input.ReadLittleEndian64(&num_bytes)) {
break;
}
int field = static_cast<int>(tag);
int ret = field == 0 ? -1 : field;
switch (field) {
case vr::kSerializedFieldNumber: {
if (!ProcSerializedField(field, &input, num_bytes)) {
return ret;
}
break;
}
case vr::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return ret;
}
break;
}
default: {
PADDLE_ENFORCE(false, "not surpported %u fieldnumber", field);
return ret;
}
}
}
return 0;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class BRPCSourceWrapper : public Source {
public:
explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
::google::protobuf::io::ZeroCopyInputStream* contents() override {
return &source_;
}
private:
butil::IOBufAsZeroCopyInputStream source_;
};
class BRPCVariableResponse : public VariableResponse {
public:
BRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~BRPCVariableResponse() {}
// parse attachment from iobuf
int Parse(Source* source) override;
int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
BRPCSourceWrapper wrapper(iobuf);
return VariableResponse::Parse(&wrapper, meta);
}
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC";
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
......
......@@ -32,13 +32,6 @@ namespace paddle {
namespace operators {
namespace distributed {
static void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,
......
......@@ -75,6 +75,10 @@ class RPCServer {
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
int GetThreadNum(const std::string& rpc_name) {
return rpc_thread_num_[rpc_name];
}
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/port.h"
......@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor(
memory::Copy(cuda_pinned, result->ptr(),
boost::get<platform::CUDAPlace>(tensor.place()),
tensor.data<void>(), copy_size, gpu_dev_ctx.stream());
ctx.Wait();
return TensorPayload(result);
#else
......
......@@ -50,6 +50,13 @@ class TensorPayload final {
size_t memory_size_;
};
inline void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
TensorPayload GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);
......
......@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
......@@ -26,10 +26,11 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch");
DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 12, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle {
namespace operators {
......
......@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase {
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
}
......
......@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
bool exclusive, framework::Tensor* output) {
bool exclusive, bool adaptive, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -51,16 +51,28 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int hstart, hend;
int wstart, wend;
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
}
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
}
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
......@@ -68,8 +80,9 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
pool_process.compute(input_data[h * input_width + w], &ele);
}
}
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int pool_size = (exclusive || adaptive)
? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[ph * output_width + pw] = ele;
}
......@@ -94,7 +107,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& output, const framework::Tensor& output_grad,
const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_grad_process,
bool exclusive, framework::Tensor* input_grad) {
bool exclusive, bool adaptive, framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -115,18 +128,31 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int hstart, hend;
int wstart, wend;
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
}
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
}
int pool_size = (exclusive || adaptive)
? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
......@@ -251,7 +277,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
bool exclusive, framework::Tensor* output) {
bool exclusive, bool adaptive, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -276,20 +302,38 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int dstart, dend;
int hstart, hend;
int wstart, wend;
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
if (adaptive) {
dstart = AdaptStartIndex(pd, input_depth, output_depth);
dend = AdaptEndIndex(pd, input_depth, output_depth);
} else {
dstart = pd * stride_depth - padding_depth;
dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
}
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
}
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
}
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = pool_process.initial();
for (int d = dstart; d < dend; ++d) {
......@@ -302,7 +346,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
}
}
int pool_size =
exclusive
(exclusive || adaptive)
? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
......@@ -330,7 +374,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& output, const framework::Tensor& output_grad,
const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_grad_process,
bool exclusive, framework::Tensor* input_grad) {
bool exclusive, bool adaptive, framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -356,24 +400,41 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int dstart, dend;
int hstart, hend;
int wstart, wend;
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
if (adaptive) {
dstart = AdaptStartIndex(pd, input_depth, output_depth);
dend = AdaptEndIndex(pd, input_depth, output_depth);
} else {
dstart = pd * stride_depth - padding_depth;
dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
}
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
}
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
}
int pool_size =
exclusive
(exclusive || adaptive)
? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
float scale = 1.0 / pool_size;
......@@ -517,8 +578,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output,
framework::Tensor* mask) {
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* output, framework::Tensor* mask) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -538,16 +599,28 @@ class MaxPool2dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
T1* output_data = output->mutable_data<T1>(context.GetPlace());
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
int hstart, hend;
int wstart, wend;
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
}
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
}
T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1;
......@@ -584,7 +657,7 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
const framework::Tensor& output_grad,
const framework::Tensor& mask, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* input_grad) {
const int batch_size = input_grad->dims()[0];
const int input_height = input_grad->dims()[2];
......@@ -637,8 +710,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output,
framework::Tensor* mask) {
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* output, framework::Tensor* mask) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -663,20 +736,38 @@ class MaxPool3dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
T1* output_data = output->mutable_data<T1>(context.GetPlace());
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
int dstart, dend;
int hstart, hend;
int wstart, wend;
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
if (adaptive) {
dstart = AdaptStartIndex(pd, input_depth, output_depth);
dend = AdaptEndIndex(pd, input_depth, output_depth);
} else {
dstart = pd * stride_depth - padding_depth;
dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
}
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
}
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
}
int output_idx = (pd * output_height + ph) * output_width + pw;
T1 ele = static_cast<T1>(-FLT_MAX);
......@@ -718,7 +809,7 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
const framework::Tensor& output_grad,
const framework::Tensor& mask, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* input_grad) {
const int batch_size = input_grad->dims()[0];
const int input_depth = input_grad->dims()[2];
......
......@@ -68,6 +68,18 @@ class AvgPoolGrad {
}
};
/* used for adaptive pool to calculate start and end index of each divided grid
*/
HOSTDEVICE inline int AdaptStartIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
floor(static_cast<double>(ph * input_size) / output_size));
}
HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
ceil(static_cast<double>((ph + 1) * input_size) / output_size));
}
/*
* \brief Getting pooling results, and calculating gradient.
*
......@@ -102,7 +114,7 @@ class Pool2dFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, framework::Tensor* output);
bool exclusive, bool adaptive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -114,7 +126,7 @@ class Pool2dGradFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, framework::Tensor* input_grad);
bool exclusive, bool adaptive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
......@@ -136,7 +148,7 @@ class Pool3dFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, framework::Tensor* output);
bool exclusive, bool adaptive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -148,7 +160,7 @@ class Pool3dGradFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, framework::Tensor* input_grad);
bool exclusive, bool adaptive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
......@@ -176,8 +188,8 @@ class MaxPool2dWithIndexFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output,
framework::Tensor* mask);
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* output, framework::Tensor* mask);
};
template <typename DeviceContext, typename T1, typename T2>
......@@ -187,7 +199,7 @@ class MaxPool2dWithIndexGradFunctor {
const framework::Tensor& output_grad,
const framework::Tensor& mask, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* input_grad);
};
......@@ -197,8 +209,8 @@ class MaxPool3dWithIndexFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output,
framework::Tensor* mask);
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* output, framework::Tensor* mask);
};
template <typename DeviceContext, typename T1, typename T2>
......@@ -208,7 +220,7 @@ class MaxPool3dWithIndexGradFunctor {
const framework::Tensor& output_grad,
const framework::Tensor& mask, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& paddings, bool adaptive,
framework::Tensor* input_grad);
};
......
......@@ -52,6 +52,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
bool ceil_mode = ctx->Attrs().Get<bool>("ceil_mode");
bool adaptive = ctx->Attrs().Get<bool>("adaptive");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D tensor.");
......@@ -72,9 +73,13 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i], ceil_mode));
if (adaptive) {
output_shape.insert(output_shape.end(), ksize.begin(), ksize.end());
} else {
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize(
in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode));
}
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
......@@ -186,6 +191,14 @@ void Pool2dOpMaker::Make() {
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True.")
.SetDefault(true);
AddAttr<bool>(
"adaptive",
"(bool, default False) When true, will perform adaptive pooling instead, "
"output shape in H and W dimensions will be same as ksize, input data "
"will be divided into grids specify by ksize averagely and perform "
"pooling in each grid area to get output pooling value.")
.SetDefault(false);
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
......@@ -264,6 +277,14 @@ Example:
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
$$
For adaptive = true:
$$
hstart = floor(i * H_{in} / H_{out})
hend = ceil((i + 1) * H_{in} / H_{out})
wstart = floor(j * W_{in} / W_{out})
wend = ceil((j + 1) * W_{in} / W_{out})
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
$$
)DOC");
}
......@@ -325,6 +346,13 @@ void Pool3dOpMaker::Make() {
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True.")
.SetDefault(true);
AddAttr<bool>(
"adaptive",
"(bool, default False) When true, will perform adaptive pooling instead, "
"output shape in H and W dimensions will be same as ksize, input data "
"will be divided into grids specify by ksize averagely and perform "
"pooling in each grid area to get output pooling value.")
.SetDefault(false);
AddAttr<bool>(
"use_cudnn",
......@@ -376,6 +404,37 @@ Example:
H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1] + strides[1] -1)}{strides[1]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2] + strides[2] -1)}{strides[2]} + 1
$$
For exclusive = true:
$$
dstart = i * strides[0] - paddings[0]
dend = dstart + ksize[0]
hstart = j * strides[1] - paddings[1]
hend = hstart + ksize[1]
wstart = k * strides[2] - paddings[2]
wend = wstart + ksize[2]
Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{ksize[0] * ksize[1] * ksize[2]}
$$
For exclusive = false:
$$
dstart = max(0, i * strides[0] - paddings[0])
dend = min(D, dstart + ksize[0])
hstart = max(0, j * strides[1] - paddings[1])
hend = min(H, hstart + ksize[1])
wstart = max(0, k * strides[2] - paddings[2])
wend = min(W, wstart + ksize[2])
Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)}
$$
For adaptive = true:
$$
dstart = floor(i * D_{in} / D_{out})
dend = ceil((i + 1) * D_{in} / D_{out})
hstart = floor(j * H_{in} / H_{out})
hend = ceil((j + 1) * H_{in} / H_{out})
wstart = floor(k * W_{in} / W_{out})
wend = ceil((k + 1) * W_{in} / W_{out})
Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)}
$$
)DOC");
}
......
......@@ -70,6 +70,7 @@ class PoolKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
......@@ -85,7 +86,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
true, out);
true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
......@@ -93,7 +94,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
exclusive, out);
exclusive, adaptive, out);
}
} break;
case 3: {
......@@ -103,14 +104,14 @@ class PoolKernel : public framework::OpKernel<T> {
pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
true, out);
true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
exclusive, out);
exclusive, adaptive, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......@@ -133,6 +134,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
......@@ -159,7 +161,8 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, exclusive, in_x_grad);
paddings, pool_process, exclusive, adaptive,
in_x_grad);
}
} break;
case 3: {
......@@ -174,7 +177,8 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, exclusive, in_x_grad);
paddings, pool_process, exclusive, adaptive,
in_x_grad);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......
......@@ -40,6 +40,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
bool adaptive = ctx->Attrs().Get<bool>("adaptive");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D tensor.");
......@@ -60,9 +61,13 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(MaxPoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
if (adaptive) {
output_shape.insert(output_shape.end(), ksize.begin(), ksize.end());
} else {
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(MaxPoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
}
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
......@@ -133,6 +138,14 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default:false) Whether to use the global pooling. "
"If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault(false);
AddAttr<bool>(
"adaptive",
"(bool, default False) When true, will perform adaptive pooling "
"instead, "
"output shape in H and W dimensions will be same as ksize, input data "
"will be divided into grids specify by ksize averagely and perform "
"pooling in each grid area to get output pooling value.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"(vector<int>, default {1, 1}), strides(height, "
"width) of pooling operator.")
......@@ -169,6 +182,12 @@ Example:
H_{out} = \frac{(H_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1
$$
For adaptive = true:
$$
H_{out} = ksize[0] W_{out} = ksize[1]
$$
)DOC");
}
......@@ -209,6 +228,14 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Whether to use the global pooling. "
"If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault(false);
AddAttr<bool>(
"adaptive",
"(bool, default False) When true, will perform adaptive pooling "
"instead, "
"output shape in H and W dimensions will be same as ksize, input data "
"will be divided into grids specify by ksize averagely and perform "
"pooling in each grid area to get output pooling value.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"(vector<int>, default {1,1,1}), strides(depth, "
"height, width) of pooling operator.")
......@@ -246,6 +273,11 @@ Example:
H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1
$$
For adaptive = true:
$$
D_{out} = ksize[0] H_{out} = ksize[1] W_{out} = ksize[2]
$$
)DOC");
}
......
......@@ -36,6 +36,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T1> {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool adaptive = context.Attr<bool>("adaptive");
auto& dev_ctx = context.template device_context<DeviceContext>();
if (context.Attr<bool>("global_pooling")) {
......@@ -50,13 +51,15 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T1> {
paddle::operators::math::MaxPool2dWithIndexFunctor<DeviceContext, T1,
T2>
pool2d_forward;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, out, mask);
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, adaptive, out,
mask);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<DeviceContext, T1,
T2>
pool3d_forward;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, out, mask);
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, adaptive, out,
mask);
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
}
......@@ -75,6 +78,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T1> {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool adaptive = context.Attr<bool>("adaptive");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
......@@ -93,14 +97,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T1> {
T1, T2>
pool2d_backward;
pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad);
paddings, adaptive, in_x_grad);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<DeviceContext,
T1, T2>
pool3d_backward;
pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad);
paddings, adaptive, in_x_grad);
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
}
......
......@@ -56,13 +56,13 @@ class SppKernel : public framework::OpKernel<T> {
math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
math::MaxPool<T> max_process;
pool_forward(context.template device_context<DeviceContext>(), *in_x,
kernel_size, strides, paddings, max_process, true,
kernel_size, strides, paddings, max_process, true, false,
&out_level);
} else if (pooling_type == "avg") {
math::Pool2dFunctor<DeviceContext, math::AvgPool<T>, T> pool_forward;
math::AvgPool<T> avg_process;
pool_forward(context.template device_context<DeviceContext>(), *in_x,
kernel_size, strides, paddings, avg_process, true,
kernel_size, strides, paddings, avg_process, true, false,
&out_level);
}
// flatten pooling output shape
......@@ -156,7 +156,7 @@ class SppGradKernel : public framework::OpKernel<T> {
math::AvgPoolGrad<T> avg_process;
pool_backward(context.template device_context<DeviceContext>(), *in_x,
*&out_level, *&outgrad_level, kernel_size, strides,
paddings, avg_process, true, in_x_grad);
paddings, avg_process, true, false, in_x_grad);
}
}
}
......
......@@ -3,6 +3,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
......@@ -81,6 +81,14 @@ bool IsCompiledWithCUDA() {
#endif
}
bool IsCompiledWithBrpc() {
#if defined(PADDLE_WITH_BRPC) || defined(PADDLE_WITH_BRPC_RDMA)
return true;
#else
return false;
#endif
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DISTRIBUTE
return true;
......@@ -631,6 +639,7 @@ All parameter, weight, gradient are variables in Paddle.
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
......
......@@ -517,6 +517,18 @@ function assert_api_spec_approvals() {
fi
fi
done
HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true`
if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have at least 2 approvals for the const_cast"
exit 1
fi
fi
}
......
......@@ -152,6 +152,7 @@ def __bootstrap__():
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'selected_gpus'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
......
......@@ -52,6 +52,8 @@ __all__ = [
'softmax',
'pool2d',
'pool3d',
'adaptive_pool2d',
'adaptive_pool3d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
......@@ -2500,6 +2502,204 @@ def pool3d(input,
return pool_out
@templatedoc(op_type="pool2d")
def adaptive_pool2d(input,
pool_size,
pool_type="max",
require_index=False,
name=None):
"""
${comment}
Args:
input (Variable): The input tensor of pooling operator. The format of
input tensor is NCHW, where N is batch size, C is
the number of channels, H is the height of the
feature, and W is the width of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (pool_size_Height, pool_size_Width).
pool_type: ${pooling_type_comment}
require_index (bool): If true, the index of max pooling point along with outputs.
it cannot be set in average pooling type.
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
Returns:
Variable: The pooling result.
Raises:
ValueError: 'pool_type' is not 'max' nor 'avg'.
ValueError: invalid setting 'require_index' true when 'pool_type' is 'avg'.
ValueError: 'pool_size' should be a list or tuple with length as 2.
Examples:
.. code-block:: python
# suppose input data in shape of [N, C, H, W], `pool_size` is [m, n],
# output shape is [N, C, m, n], adaptive pool divide H and W dimentions
# of input data into m * n grids averagely and performs poolings in each
# grid to get output.
# adaptive average pool performs calculations as follow:
#
# for i in range(m):
# for j in range(n):
# hstart = floor(i * H / m)
# hend = ceil((i + 1) * H / m)
# wstart = floor(i * W / n)
# wend = ceil((i + 1) * W / n)
# output[:, :, i, j] = avg(input[:, :, hstart: hend, wstart: wend])
#
data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32')
pool_out = fluid.layers.adaptive_pool2d(
input=data,
pool_size=[3, 3],
pool_type='avg')
"""
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if pool_type == "avg" and require_index:
raise ValueError(
"invalid setting 'require_index' true when 'pool_type' is 'avg'.")
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(pool_size) or len(pool_size) != 2:
raise ValueError(
"'pool_size' should be a list or tuple with length as 2.")
if pool_type == "max":
l_type = 'max_pool2d_with_index'
else:
l_type = "pool2d"
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
outputs = {"Out": pool_out}
if pool_type == "max":
mask = helper.create_variable_for_type_inference(dtype)
outputs["Mask"] = mask
helper.append_op(
type=l_type,
inputs={"X": input},
outputs=outputs,
attrs={
"pooling_type": pool_type,
"ksize": pool_size,
"adaptive": True,
})
return (pool_out, mask) if require_index else pool_out
@templatedoc(op_type="pool3d")
def adaptive_pool3d(input,
pool_size,
pool_type="max",
require_index=False,
name=None):
"""
${comment}
Args:
input (Variable): The input tensor of pooling operator. The format of
input tensor is NCHW, where N is batch size, C is
the number of channels, H is the height of the
feature, and W is the width of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (Depth, Height, Width).
pool_type: ${pooling_type_comment}
require_index (bool): If true, the index of max pooling point along with outputs.
it cannot be set in average pooling type.
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
Returns:
Variable: The pooling result.
Raises:
ValueError: 'pool_type' is not 'max' nor 'avg'.
ValueError: invalid setting 'require_index' true when 'pool_type' is 'avg'.
ValueError: 'pool_size' should be a list or tuple with length as 2.
Examples:
.. code-block:: python
# suppose input data in shape of [N, C, D, H, W], `pool_size` is [l, m, n],
# output shape is [N, C, l, m, n], adaptive pool divide D, H and W dimentions
# of input data into l * m * n grids averagely and performs poolings in each
# grid to get output.
# adaptive average pool performs calculations as follow:
#
# for i in range(l):
# for j in range(m):
# for k in range(n):
# dstart = floor(i * D / l)
# dend = ceil((i + 1) * D / l)
# hstart = floor(j * H / m)
# hend = ceil((j + 1) * H / m)
# wstart = floor(k * W / n)
# wend = ceil((k + 1) * W / n)
# output[:, :, i, j, k] =
# avg(input[:, :, dstart:dend, hstart: hend, wstart: wend])
#
data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32')
pool_out, mask = fluid.layers.adaptive_pool3d(
input=data,
pool_size=[3, 3],
pool_type='avg')
"""
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if pool_type == "avg" and require_index:
raise ValueError(
"invalid setting 'require_index' true when 'pool_type' is 'avg'.")
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(pool_size) or len(pool_size) != 3:
raise ValueError(
"'pool_size' should be a list or tuple with length as 3.")
if pool_type == "max":
l_type = 'max_pool3d_with_index'
else:
l_type = "pool3d"
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
outputs = {"Out": pool_out}
if pool_type == "max":
mask = helper.create_variable_for_type_inference(dtype)
outputs["Mask"] = mask
helper.append_op(
type=l_type,
inputs={"X": input},
outputs=outputs,
attrs={
"pooling_type": pool_type,
"ksize": pool_size,
"adaptive": True,
})
return (pool_out, mask) if require_index else pool_out
def batch_norm(input,
act=None,
is_test=False,
......
......@@ -63,9 +63,9 @@ function(py_test_modules TARGET_NAME)
set(multiValueArgs MODULES DEPS ENVS)
cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (py_test_modules_SERIAL)
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
......@@ -111,3 +111,7 @@ py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executo
if(NOT APPLE)
py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL)
endif()
if (WITH_NGRAPH)
add_subdirectory(ngraph)
endif()
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_use_ngraph=true)
endforeach(TEST_OP)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -233,6 +233,29 @@ class TestBook(unittest.TestCase):
pool_stride=[1, 2],
pool_padding=(2, 1)))
def test_adaptive_pool2d(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 224, 224], dtype='float32')
self.assertIsNotNone(
layers.adaptive_pool2d(
x, [3, 3], pool_type='avg'))
pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
def test_adaptive_pool3d(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 244, 224, 224], dtype='float32')
self.assertIsNotNone(
layers.adaptive_pool3d(
x, [3, 3, 3], pool_type='avg'))
pool, mask = layers.adaptive_pool3d(
x, [3, 3, 3], require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
def test_lstm_unit(self):
program = Program()
with program_guard(program):
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
......@@ -21,29 +22,47 @@ import paddle.fluid.core as core
from op_test import OpTest
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def max_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True):
exclusive=True,
adaptive=False):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
......@@ -56,27 +75,37 @@ def avg_pool2D_forward_naive(x,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True):
exclusive=True,
adaptive=False):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
field_size = ((r_end - r_start) * (c_end - c_start)) if exclusive \
else (ksize[0] * ksize[1])
field_size = ((r_end - r_start) * (c_end - c_start)) \
if (exclusive or adaptive) else (ksize[0] * ksize[1])
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
return out
......@@ -93,12 +122,13 @@ class TestPool2D_Op(OpTest):
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
self.init_adaptive()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype)
output = self.pool2D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive).astype(self.dtype)
self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
......@@ -112,7 +142,8 @@ class TestPool2D_Op(OpTest):
'ceil_mode': self.ceil_mode,
'data_format':
'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive
'exclusive': self.exclusive,
'adaptive': self.adaptive
}
self.outputs = {'Out': output}
......@@ -159,6 +190,9 @@ class TestPool2D_Op(OpTest):
def init_exclusive(self):
self.exclusive = True
def init_adaptive(self):
self.adaptive = False
class TestCase1(TestPool2D_Op):
def init_test_case(self):
......@@ -315,5 +349,10 @@ class TestCUDNNAvgInclude(TestCase2):
self.exclusive = False
class TestAvgPoolAdaptive(TestCase1):
def init_adaptive(self):
self.adaptive = True
if __name__ == '__main__':
unittest.main()
......@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
......@@ -21,35 +22,59 @@ import paddle.fluid.core as core
from op_test import OpTest
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def max_pool3D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True):
exclusive=True,
adaptive=False):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
if adaptive:
D_out, H_out, W_out = ksize
else:
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
if adaptive:
d_start = adaptive_start_index(k, D, ksize[0])
d_end = adaptive_end_index(k, D, ksize[0])
else:
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in range(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
if adaptive:
h_start = adaptive_start_index(i, H, ksize[1])
h_end = adaptive_end_index(i, H, ksize[1])
else:
h_start = np.max((i * strides[1] - paddings[1], 0))
h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H))
for j in range(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
if adaptive:
w_start = adaptive_start_index(j, W, ksize[2])
w_end = adaptive_end_index(j, W, ksize[2])
else:
w_start = np.max((j * strides[2] - paddings[2], 0))
w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
......@@ -62,33 +87,49 @@ def avg_pool3D_forward_naive(x,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True):
exclusive=True,
adaptive=False):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
if adaptive:
D_out, H_out, W_out = ksize
else:
D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) // strides[2] + 1 if ceil_mode else (
W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
if adaptive:
d_start = adaptive_start_index(k, D, ksize[0])
d_end = adaptive_end_index(k, D, ksize[0])
else:
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in range(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
if adaptive:
h_start = adaptive_start_index(i, H, ksize[1])
h_end = adaptive_end_index(i, H, ksize[1])
else:
h_start = np.max((i * strides[1] - paddings[1], 0))
h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H))
for j in range(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
if adaptive:
w_start = adaptive_start_index(j, W, ksize[2])
w_end = adaptive_end_index(j, W, ksize[2])
else:
w_start = np.max((j * strides[2] - paddings[2], 0))
w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
field_size = (d_end - d_start) * (h_end - h_start) * (w_end - w_start) \
if exclusive else ksize[0] * ksize[1] * ksize[2]
if (exclusive or adaptive) else ksize[0] * ksize[1] * ksize[2]
out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3,
4)) / field_size
return out
......@@ -105,13 +146,14 @@ class TestPool3d_Op(OpTest):
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
self.init_adaptive()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype)
output = self.pool3D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive).astype(self.dtype)
self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
......@@ -124,7 +166,8 @@ class TestPool3d_Op(OpTest):
'ceil_mode': self.ceil_mode,
'data_format':
'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive
'exclusive': self.exclusive,
'adaptive': self.adaptive
}
self.outputs = {'Out': output}
......@@ -171,6 +214,9 @@ class TestPool3d_Op(OpTest):
def init_exclusive(self):
self.exclusive = True
def init_adaptive(self):
self.adaptive = False
class TestCase1(TestPool3d_Op):
def init_test_case(self):
......@@ -353,5 +399,10 @@ class TestCUDNNAvgInclude(TestCUDNNCase3):
self.exclusive = False
class TestAvgPoolAdaptive(TestCase1):
def init_adaptive(self):
self.adaptive = True
if __name__ == '__main__':
unittest.main()
......@@ -13,33 +13,62 @@
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False):
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def max_pool3D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=False,
adaptive=False):
N, C, D, H, W = x.shape
if global_pool:
ksize = [D, H, W]
paddings = [0, 0, 0]
D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
if adaptive:
D_out, H_out, W_out = ksize
else:
D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
mask = np.zeros((N, C, D_out, H_out, W_out))
for k in range(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
if adaptive:
d_start = adaptive_start_index(k, D, ksize[0])
d_end = adaptive_end_index(k, D, ksize[0])
else:
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in range(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
if adaptive:
h_start = adaptive_start_index(i, H, ksize[1])
h_end = adaptive_end_index(i, H, ksize[1])
else:
h_start = np.max((i * strides[1] - paddings[1], 0))
h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H))
for j in range(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
if adaptive:
w_start = adaptive_start_index(j, W, ksize[2])
w_end = adaptive_end_index(j, W, ksize[2])
else:
w_start = np.max((j * strides[2] - paddings[2], 0))
w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
......@@ -58,23 +87,37 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False):
return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=False):
def max_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=False,
adaptive=False):
N, C, H, W = x.shape
if global_pool:
ksize = [H, W]
paddings = [0, 0]
H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
mask = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
......@@ -95,10 +138,12 @@ class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self):
self.init_test_case()
self.init_global()
self.init_adaptive()
input = np.random.random(self.shape).astype("float32")
output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
self.paddings, self.global_pool,
self.adaptive)
output = output.astype("float32")
mask = mask.astype("int32")
......@@ -107,6 +152,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
'paddings': self.paddings,
'ksize': self.ksize,
'global_pooling': self.global_pool,
'adaptive': self.adaptive,
}
self.inputs = {'X': input}
......@@ -129,6 +175,9 @@ class TestMaxPoolWithIndex_Op(OpTest):
def init_global(self):
self.global_pool = False
def init_adaptive(self):
self.adaptive = False
class TestCase1(TestMaxPoolWithIndex_Op):
def init_global(self):
......@@ -190,5 +239,15 @@ class TestCase7(TestCase6):
self.global_pool = False
class TestCastAdaptive2d(TestCase6):
def init_adaptive(self):
self.adaptive = True
class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
def init_adaptive(self):
self.adaptive = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册