Commit 3bd54ed7 authored by Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into multithread-sparse-adam

@@ -19,6 +19,15 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
Welcome to the PaddlePaddle GitHub.
PaddlePaddle (PArallel Distributed Deep LEarning) is an easy-to-use, efficient, flexible, and scalable deep learning platform. It was originally developed by Baidu scientists and engineers to apply deep learning to Baidu's many products.
Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest features of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### Install Latest Stable Release:
```
@@ -34,6 +43,23 @@ pip install paddlepaddle-gpu==1.2.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
### Latest PaddlePaddle Version: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### Install the Latest Stable Release:
```
# Linux CPU
pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==1.2.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==1.2.0.post85
# For installation on other platforms, refer to http://paddlepaddle.org/
```
## Features
- **Flexibility**
@@ -74,10 +100,38 @@ pip install paddlepaddle-gpu==1.2.0.post85
Baidu and it has achieved a significant impact. We hope you can also explore
the capability of PaddlePaddle to make an impact on your product.
## Features
- **Flexibility**

  PaddlePaddle supports a wide range of neural network architectures and optimization algorithms. It is easy to configure complex models such as neural machine translation models with attention mechanisms or complex memory connections.

- **Efficiency**

  To make efficient use of heterogeneous computing resources, PaddlePaddle optimizes different layers of the framework, including computation, memory, architecture, and communication. A few examples:
  - Optimized math operations through SSE/AVX intrinsics, BLAS libraries (e.g. MKL, OpenBLAS, cuBLAS), or customized CPU/GPU kernels.
  - Optimized CNN networks through the MKL-DNN library.
  - Highly optimized recurrent networks that can process **variable-length** sequences without `padding`.
  - Optimized local and distributed training for models with high-dimensional sparse data.

- **Scalability**

  With PaddlePaddle, it is easy to use many CPUs/GPUs and machines to speed up training. PaddlePaddle achieves high throughput and fast execution via optimized communication.

- **Connected to Products**

  In addition, PaddlePaddle is also designed to be easy to deploy. At Baidu, PaddlePaddle has been deployed into products and services with a vast number of users, including ad click-through rate (CTR) prediction, large-scale image classification, optical character recognition (OCR), search ranking, computer virus detection, recommendation systems, and so on. PaddlePaddle is widely used in Baidu's products and has made a significant impact. We hope you can also explore its capability to make an impact on your own products.
## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
## Installation
It is recommended to read the [installation guide](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
@@ -99,10 +153,37 @@ We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarte
We appreciate your contributions!
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation.

- [Deep Learning 101](https://github.com/PaddlePaddle/book)

  You might want to start with this interactive online book, which can be run in a Jupyter Notebook.

- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)

  You can run distributed training jobs on MPI clusters.

- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)

  The new API enables much shorter and cleaner programs.

- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)

  We appreciate your contributions!
## Ask Questions
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
## Ask Questions
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
## Copyright and License
PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
     # the role, should be either PSERVER or TRAINER
     training_role = os.getenv("PADDLE_TRAINING_ROLE")
-    config = distribute_transpiler.DistributeTranspilerConfig()
+    config = fluid.DistributeTranspilerConfig()
     config.slice_var_up = not args.no_split_var
+    config.min_block_size = 1048576
     t = distribute_transpiler.DistributeTranspiler(config=config)
     t.transpile(
         trainer_id,
         # NOTE: *MUST* use train_prog, for we are using with guard to
......
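For context on the configuration fields touched above, here is a minimal, hedged sketch of how the Fluid 1.2 transpiler is typically wired up. It is not part of this commit; the helper name, the endpoint defaults, and the `PADDLE_PSERVER_ENDPOINTS`/`PADDLE_TRAINERS`/`PADDLE_CURRENT_ENDPOINT` environment variables are illustrative assumptions.

```python
import os

import paddle.fluid as fluid


def dist_transpile_sketch(trainer_id, train_prog, startup_prog):
    # Same knobs as the benchmark change above: control variable slicing
    # across parameter servers and the minimum block size per slice.
    config = fluid.DistributeTranspilerConfig()
    config.slice_var_up = True
    config.min_block_size = 1048576

    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id,
        program=train_prog,
        pservers=os.getenv("PADDLE_PSERVER_ENDPOINTS", "127.0.0.1:6174"),
        trainers=int(os.getenv("PADDLE_TRAINERS", "1")))

    if os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
        endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "127.0.0.1:6174")
        pserver_prog = t.get_pserver_program(endpoint)
        pserver_startup = t.get_startup_program(endpoint, pserver_prog)
        return pserver_prog, pserver_startup
    return t.get_trainer_program(), startup_prog
```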
@@ -14,14 +14,16 @@
INCLUDE(ExternalProject)
-find_library(SSL_LIBRARY NAMES ssl)
+find_package(OpenSSL REQUIRED)
+message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
+message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
+SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
-find_library(CRYPTO_LIBRARY NAMES crypto)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
+SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
-set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
+set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
    extern_brpc
    ${EXTERNAL_PROJECT_LOG_ARGS}
+   # TODO(gongwb): change to the newest repo when they change.
    GIT_REPOSITORY "https://github.com/gongweibao/brpc"
-   GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
+   GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
    PREFIX ${BRPC_SOURCES_DIR}
    UPDATE_COMMAND ""
    CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -50,7 +53,7 @@ ExternalProject_Add(
    -DCMAKE_POSITION_INDEPENDENT_CODE=ON
    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
    -DCMAKE_PREFIX_PATH=${prefix_path}
-   -DBRPC_WITH_GLOG=ON
+   -DWITH_GLOG=ON
    -DIOBUF_WITH_HUGE_BLOCK=ON
    -DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
    ${EXTERNAL_OPTIONAL_ARGS}
@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
+add_definitions(-DBRPC_WITH_GLOG)
LIST(APPEND external_project_dependencies brpc)
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-IF(WITH_TESTING)
-  ENABLE_TESTING()
+#FIXME:(gongwb) Move brpc's gtest dependency.
+IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
+  IF(WITH_TESTING)
+    ENABLE_TESTING()
+  ENDIF(WITH_TESTING)
  INCLUDE(ExternalProject)
  SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
@@ -76,4 +80,4 @@ IF(WITH_TESTING)
  ADD_DEPENDENCIES(gtest_main extern_gtest)
  LIST(APPEND external_project_dependencies gtest gtest_main)
-ENDIF(WITH_TESTING)
+ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
@@ -24,8 +24,8 @@ ExternalProject_Add(
    extern_leveldb
    ${EXTERNAL_PROJECT_LOG_ARGS}
    PREFIX ${LEVELDB_SOURCES_DIR}
-   URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
-   URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
+   GIT_REPOSITORY "https://github.com/google/leveldb"
+   GIT_TAG v1.18
    CONFIGURE_COMMAND ""
    BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
    INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
......
@@ -77,6 +77,8 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name']
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
+paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
+paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
......
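The two `adaptive_pool2d`/`adaptive_pool3d` entries added to the API spec above take a target output size (`pool_size`) rather than a kernel size. Below is a minimal usage sketch against the `fluid.layers` signature shown in these ArgSpec lines; the input shape and variable names are illustrative only.

```python
import paddle.fluid as fluid

# NCHW feature map; the batch dimension is implicit in fluid.layers.data.
x = fluid.layers.data(name="x", shape=[3, 32, 32], dtype="float32")

# Average-pool each channel down to a fixed 7x7 grid, whatever the input size.
avg_out = fluid.layers.adaptive_pool2d(
    input=x, pool_size=[7, 7], pool_type="avg")

# Max variant (setting require_index=True would additionally return the
# argmax indices, per the ArgSpec above).
max_out = fluid.layers.adaptive_pool2d(
    input=x, pool_size=[7, 7], pool_type="max")
```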
@@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
+    lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
  if(WITH_NGRAPH)
    if(NOT WIN32)
......
@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
  out->mutable_data(expected_kernel_type.place_, in.type());
  framework::VisitDataType(
-     framework::ToDataType(in.type()),
+     in.type(),
      CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
  out->set_layout(expected_kernel_type.data_layout_);
@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
    case mkldnn::memory::data_type::f32:
      return platform::to_void_cast(tensor.data<float>());
    case mkldnn::memory::data_type::s8:
-     return platform::to_void_cast(tensor.data<char>());
+     return platform::to_void_cast(tensor.data<int8_t>());
    case mkldnn::memory::data_type::u8:
      return platform::to_void_cast(tensor.data<unsigned char>());
    case mkldnn::memory::data_type::s16:
@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
  memory::data_type in_type = ToMKLDNNDataType(in.type());
  PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
-                "Input tensor type is not supported: ", in.type().name());
+                "Input tensor type is not supported: %s", in.type());
  memory::data_type out_type = in_type;
  auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
......
@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
  }
}

-inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
-  static const std::map<std::type_index, MKLDNNDataType> dict{
-      {std::type_index(typeid(float)), MKLDNNDataType::f32},  // NOLINT
-      {std::type_index(typeid(char)), MKLDNNDataType::s8},    // NOLINT
-      {std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
-      {std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
-      {std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
-  auto iter = dict.find(type);
+inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
+  static std::unordered_map<int, MKLDNNDataType> dict{
+      {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
+      {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
+      {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
+      {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
+      {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
+  auto iter = dict.find(static_cast<int>(type));
  if (iter != dict.end()) return iter->second;
  return MKLDNNDataType::data_undef;
}
......
@@ -26,7 +26,7 @@ struct DataTypeMap {
  std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
  std::unordered_map<int, std::type_index> proto_to_cpp_;
  std::unordered_map<int, std::string> proto_to_str_;
- std::unordered_map<std::type_index, size_t> cpp_to_size_;
+ std::unordered_map<int, size_t> proto_to_size_;
};

static DataTypeMap* InitDataTypeMap();
@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
  map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
  map->cpp_to_proto_.emplace(typeid(T), proto_type);
  map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
- map->cpp_to_size_.emplace(typeid(T), sizeof(T));
+ map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
}

static DataTypeMap* InitDataTypeMap() {
@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \
  RegisterType<cc_type>(retv, proto_type, #cc_type)
- // NOTE: Add your customize type here.
- RegType(float16, proto::VarType::FP16);
- RegType(float, proto::VarType::FP32);
- RegType(double, proto::VarType::FP64);
- RegType(int, proto::VarType::INT32);
- RegType(int64_t, proto::VarType::INT64);
- RegType(bool, proto::VarType::BOOL);
- RegType(size_t, proto::VarType::SIZE_T);
- RegType(int16_t, proto::VarType::INT16);
- RegType(uint8_t, proto::VarType::UINT8);
- RegType(int8_t, proto::VarType::INT8);
+ _ForEachDataType_(RegType);
#undef RegType
  return retv;
@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
                 static_cast<int>(type));
}

-size_t SizeOfType(std::type_index type) {
-  auto it = gDataTypeMap().cpp_to_size_.find(type);
-  if (it != gDataTypeMap().cpp_to_size_.end()) {
+size_t SizeOfType(proto::VarType::Type type) {
+  auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_size_.end()) {
    return it->second;
  }
- PADDLE_THROW("Not support %s as tensor type", type.name());
+ PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
}

}  // namespace framework
......
@@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle {
namespace framework {

template <typename T>
struct DataTypeTrait {};

// Stub handle for void
template <>
struct DataTypeTrait<void> {
  constexpr static auto DataType = proto::VarType::RAW;
};

#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
  callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);

#define _ForEachDataType_(callback)                                     \
  _ForEachDataTypeHelper_(callback, float, FP32);                       \
  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
  _ForEachDataTypeHelper_(callback, double, FP64);                      \
  _ForEachDataTypeHelper_(callback, int, INT32);                        \
  _ForEachDataTypeHelper_(callback, int64_t, INT64);                    \
  _ForEachDataTypeHelper_(callback, bool, BOOL);                        \
  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                    \
  _ForEachDataTypeHelper_(callback, int16_t, INT16);                    \
  _ForEachDataTypeHelper_(callback, int8_t, INT8)

#define DefineDataTypeTrait(cpp_type, proto_type) \
  template <>                                     \
  struct DataTypeTrait<cpp_type> {                \
    constexpr static auto DataType = proto_type;  \
  }

_ForEachDataType_(DefineDataTypeTrait);

#undef DefineDataTypeTrait

extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type);

template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
-  switch (type) {
-    case proto::VarType::FP16:
-      visitor.template apply<platform::float16>();
-      break;
-    case proto::VarType::FP32:
-      visitor.template apply<float>();
-      break;
-    case proto::VarType::FP64:
-      visitor.template apply<double>();
-      break;
-    case proto::VarType::INT32:
-      visitor.template apply<int>();
-      break;
-    case proto::VarType::INT64:
-      visitor.template apply<int64_t>();
-      break;
-    case proto::VarType::BOOL:
-      visitor.template apply<bool>();
-      break;
-    case proto::VarType::UINT8:
-      visitor.template apply<uint8_t>();
-      break;
-    case proto::VarType::INT16:
-      visitor.template apply<int16_t>();
-      break;
-    case proto::VarType::INT8:
-      visitor.template apply<int8_t>();
-      break;
-    default:
-      PADDLE_THROW("Not supported %d", type);
-  }
+#define VisitDataTypeCallback(cpp_type, proto_type) \
+  do {                                              \
+    if (type == proto_type) {                       \
+      visitor.template apply<cpp_type>();           \
+      return;                                       \
+    }                                               \
+  } while (0)
+
+  _ForEachDataType_(VisitDataTypeCallback);
+#undef VisitDataTypeCallback
+  PADDLE_THROW("Not supported %d", type);
}

extern std::string DataTypeToString(const proto::VarType::Type type);
-extern size_t SizeOfType(std::type_index type);
+extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
                                const proto::VarType::Type& type) {
  out << DataTypeToString(type);
......
@@ -26,15 +26,15 @@ TEST(DataType, float16) {
  Tensor tensor;
  CPUPlace cpu;
- tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
+ tensor.mutable_data(cpu, dtype);

  // test fp16 tensor
- EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
+ EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));

  // test fp16 size
- EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
+ EXPECT_EQ(f::SizeOfType(dtype), 2u);

  // test debug info
- std::string type = "float16";
+ std::string type = "::paddle::platform::float16";
  EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
}
@@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

+if(WITH_DISTRIBUTE)
+  if(NOT WITH_GRPC)
+    set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+    set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  endif()
+endif()

if(WITH_GPU)
  nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
    dynload_cuda variable_visitor)
  if(WITH_DISTRIBUTE)
    nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
-     ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
+     ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
  else()
    nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
      ddim dynload_cuda selected_rows_functor)
@@ -30,7 +37,7 @@ else()
    variable_visitor)
  if(WITH_DISTRIBUTE)
    cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
-     ddim selected_rows_functor sendrecvop_grpc)
+     ddim selected_rows_functor sendrecvop_rpc)
  else()
    cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
      ddim selected_rows_functor)
......
@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
      // Reduce All Tensor to trg in CPU
      ReduceLoDTensor func(lod_tensors, &trg);
-     VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+     VisitDataType(lod_tensors[0]->type(), func);

      for (size_t i = 1; i < local_scopes_.size(); ++i) {
        auto &scope =
......
@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
  FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
                   const platform::Place &place,
                   const std::unordered_map<std::string, int64_t> &inputs_numel,
-                  const std::type_index &var_type)
+                  const proto::VarType::Type var_type)
      : OpHandleBase(node),
        local_scope_(local_scope),
        place_(place),
@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
  Scope *local_scope_;
  const platform::Place place_;
  const std::unordered_map<std::string, int64_t> inputs_numel_;
- const std::type_index type_;
+ const proto::VarType::Type type_;
  int64_t total_numel_;
};
}  // namespace details
......
@@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() {
      }
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
-     if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
+     if (in_selected_rows[0]->value().type() ==
+         framework::proto::VarType::FP32) {
        GatherSelectedRows<platform::CUDADeviceContext, float>(
            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
            out_var->GetMutable<framework::SelectedRows>());
-     } else if (framework::IsType<const double>(
-                    in_selected_rows[0]->value().type())) {
+     } else if (in_selected_rows[0]->value().type() ==
+                framework::proto::VarType::FP64) {
        GatherSelectedRows<platform::CUDADeviceContext, double>(
            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
            out_var->GetMutable<framework::SelectedRows>());
      } else {
-       PADDLE_ENFORCE(false,
-                      "only support double or float when gahter SelectedRows");
+       PADDLE_THROW("only support double or float when gather SelectedRows");
      }
#endif
    });
@@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
      if (!FLAGS_cpu_deterministic) {
        ReduceLoDTensor func(lod_tensors,
                             out_var->GetMutable<framework::LoDTensor>());
-       VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+       VisitDataType(lod_tensors[0]->type(), func);
      } else {
        // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
        // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
@@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
            ->FindVar(out_var_handle->name_)
            ->GetMutable<framework::LoDTensor>();
        ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
-       VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+       VisitDataType(lod_tensors[0]->type(), func);
        auto trg = out_var->GetMutable<framework::LoDTensor>();
        if (reduce_sum_trg.data<void>() != trg->data<void>()) {
......
@@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/dlpack_tensor.h"
+#include "paddle/fluid/framework/data_type.h"

namespace paddle {
namespace framework {
@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
  return dtype;
}

-static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
-#define REG_DL_DATA_TYPE(type) \
-  { std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
-  static const std::unordered_map<std::type_index, ::DLDataType>
-      type_to_dtype_map({
-          REG_DL_DATA_TYPE(platform::float16),  // NOLINT
-          REG_DL_DATA_TYPE(float),              // NOLINT
-          REG_DL_DATA_TYPE(double),             // NOLINT
-          REG_DL_DATA_TYPE(int),                // NOLINT
-          REG_DL_DATA_TYPE(int64_t),            // NOLINT
-          REG_DL_DATA_TYPE(bool),               // NOLINT
-          REG_DL_DATA_TYPE(size_t),             // NOLINT
-          REG_DL_DATA_TYPE(int16_t),            // NOLINT
-          REG_DL_DATA_TYPE(uint8_t),            // NOLINT
-          REG_DL_DATA_TYPE(int8_t)              // NOLINT
-      });
+static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
+  static std::unordered_map<int, ::DLDataType> result;
+
+#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
+  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
+
+  _ForEachDataType_(REG_DL_DATA_TYPE);
+#undef REG_DL_DATA_TYPE
+  return result;
+}
+
+static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
+  static auto type_to_dtype_map = CreateDLDataTypeMap();
  static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
-  auto it = type_to_dtype_map.find(type);
-  PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
-                 type.name());
+  auto it = type_to_dtype_map.find(static_cast<int>(type));
+  PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
+                 type);
  return it->second;
#undef REG_DL_DATA_TYPE
}
......
@@ -91,23 +91,11 @@ void TestMainLoop() {
    }
  }
}
-#define PADDLE_DLPACK_TEST(type) \
-  TEST(dlpack, test_##type) { TestMainLoop<type>(); }
-
-using float16 = platform::float16;
-PADDLE_DLPACK_TEST(float16);
-PADDLE_DLPACK_TEST(float);
-PADDLE_DLPACK_TEST(double);
-PADDLE_DLPACK_TEST(int);
-PADDLE_DLPACK_TEST(int64_t);
-PADDLE_DLPACK_TEST(bool);
-PADDLE_DLPACK_TEST(size_t);
-PADDLE_DLPACK_TEST(int16_t);
-PADDLE_DLPACK_TEST(uint8_t);
-PADDLE_DLPACK_TEST(int8_t);
-#undef PADDLE_DLPACK_TEST
+TEST(dlpack, test_all) {
+#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
+
+  _ForEachDataType_(TestCallback);
+}

}  // namespace framework
}  // namespace paddle
@@ -157,9 +157,9 @@ void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
  // TODO(typhoonzero): complete message will need to use real trainer_id,
  // except 0.
- ::paddle::operators::distributed::RPCClient::GetInstance<
-     ::paddle::operators::distributed::GRPCClient>(0)
-     ->SendComplete();
+ auto client =
+     paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+ client->SendComplete();
#endif
}
......
@@ -139,39 +139,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
  std::cout << sstream.str() << std::endl;
}

-void print_fetch_var(Scope* scope, std::string var_name) {
-  const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
-  if (std::type_index(tensor.type()) ==
-      std::type_index(typeid(platform::float16))) {
-    print_lod_tensor<platform::float16>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
-    print_lod_tensor<float>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(double))) {
-    print_lod_tensor<double>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
-    print_lod_tensor<int>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int64_t))) {
-    print_lod_tensor<int64_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
-    print_lod_tensor<bool>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(uint8_t))) {
-    print_lod_tensor<uint8_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int16_t))) {
-    print_lod_tensor<int16_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int8_t))) {
-    print_lod_tensor<int8_t>(var_name, tensor);
-  } else {
-    VLOG(1) << "print_fetch_var: unrecognized data type:"
-            << tensor.type().name();
-  }
-  return;
-}
+static void print_fetch_var(Scope* scope, const std::string& var_name) {
+  auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
+
+#define PrintLoDTensorCallback(cpp_type, proto_type) \
+  do {                                               \
+    if (tensor.type() == proto_type) {               \
+      print_lod_tensor<cpp_type>(var_name, tensor);  \
+      return;                                        \
+    }                                                \
+  } while (0)
+
+  _ForEachDataType_(PrintLoDTensorCallback);
+  VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
+}

void ExecutorThreadWorker::TrainFiles() {
......
@@ -42,6 +42,8 @@ pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
pass_library(is_test_pass base)
+pass_library(conv_elementwise_add_act_fuse_pass inference)
+pass_library(conv_elementwise_add2_act_fuse_pass inference)
if(WITH_MKLDNN)
  pass_library(mkldnn_placement_pass base)
  pass_library(depthwise_conv_mkldnn_pass base)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(elementwise_add_op_1); \
GET_IR_NODE(elementwise_add_in_y_1); \
GET_IR_NODE(elementwise_add_out_1); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& bias1, const std::string& activation,
const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {bias1});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string bias1_name = elementwise_add_in_y_1->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
auto new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, new_conv_op); // ResidualData
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(elementwise_add_op_1); \
GET_IR_NODE(elementwise_add_in_y_1); \
GET_IR_NODE(elementwise_add_out_1); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& bias1, const std::string& activation,
const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {bias1});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string bias1_name = elementwise_add_in_y_1->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, conv_op); // Input
IR_NODE_LINK_TO(conv_filter, conv_op); // Filter
IR_NODE_LINK_TO(conv_op, conv_out); // Output
IR_NODE_LINK_TO(elementwise_add_in_y, conv_op); // Bias
IR_NODE_LINK_TO(elementwise_add_in_y_1, conv_op); // Bias
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
{conv_op, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAdd2ActFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAdd2ActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(conv_op); \
GET_IR_NODE(conv_out); \
GET_IR_NODE(conv_filter); \
GET_IR_NODE(elementwise_add_op); \
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); \
GET_IR_NODE(act_op); \
GET_IR_NODE(act_out);
// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const std::string& bias,
const std::string& activation, const std::string& output) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetType("conv2d_fusion");
desc.SetInput("Bias", {bias});
desc.SetInput("ResidualData", {});
desc.SetAttr("activation", activation);
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
desc.Flush();
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("conv2d", "Input")
->AsInput();
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
std::string bias_name = elementwise_add_in_y->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto new_op_proto =
PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op.
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
// Link inputs and outputs.
PADDLE_ENFORCE(subgraph.count(x));
auto* conv_in_node = subgraph.at(x);
IR_NODE_LINK_TO(conv_in_node, new_conv_op); // Input
IR_NODE_LINK_TO(conv_filter, new_conv_op); // Filter
IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op); // Bias
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
paddle::framework::ir::ConvElementwiseAddActFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddActFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -17,6 +17,7 @@
#include <string>
#include <vector>

+#include "graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
@@ -25,6 +26,7 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace framework {
namespace ir {
@@ -104,7 +106,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
  for (auto &node : GraphTraits::DFS(graph)) {
    for (const auto &pdnode : pattern_.nodes()) {
      if (pdnode->Tell(&node)) {
-       VLOG(4) << "pdnode " << pdnode->name() << " marked";
+       VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
        pdnodes2nodes_[pdnode.get()].insert(&node);
      }
    }
@@ -1099,6 +1101,115 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
  return out_var;
}
std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
"relu6", "relux", "tanh",
"band_pass"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput();
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
auto act_op = pattern->NewNode(act_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
auto op_type = node->Name();
return conv_act_set.count(op_type);
});
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_var()
// is activation op's output.
->assert_more([&](Node *node) {
for (auto *in_op : node->inputs) {
if (conv_act_set.count(in_op->Name())) {
return true;
}
}
return false;
})
->AsOutput();
conv_op->LinksFrom({conv_in, conv_filter});
conv_out->LinksFrom({conv_op});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
act_op->LinksFrom({elementwise_add_out}).LinksTo({act_out});
return act_out;
}
PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
->AsInput();
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise_add_op_1 = pattern->NewNode(elementwise_add_op_1_repr())
->assert_is_op("elementwise_add");
auto elementwise_add_in_y_1 = pattern->NewNode(elementwise_add_in_y_1_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
auto elementwise_add_out_1 = pattern->NewNode(elementwise_add_out_1_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
auto act_op = pattern->NewNode(act_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
auto op_type = node->Name();
return conv_act_set.count(op_type);
});
auto act_out = pattern->NewNode(act_out_repr())
->assert_is_var()
// the node is an activation op's output.
->assert_more([&](Node *node) {
for (auto *in_op : node->inputs) {
if (conv_act_set.count(in_op->Name())) {
return true;
}
}
return false;
})
->AsOutput();
conv_op->LinksFrom({conv_in, conv_filter}).LinksTo({conv_out});
elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
.LinksTo({elementwise_add_out});
elementwise_add_op_1->LinksFrom(
{elementwise_add_out, elementwise_add_in_y_1});
act_op->LinksFrom({elementwise_add_out_1}).LinksTo({act_out});
return act_out;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -671,6 +671,51 @@ struct ElementwiseAdd : public PatternBase { ...@@ -671,6 +671,51 @@ struct ElementwiseAdd : public PatternBase {
PATTERN_DECL_NODE(elementwise_add_y); PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out); PATTERN_DECL_NODE(elementwise_add_out);
}; };
// Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv-related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_in_y); // input
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
};
// Conv + ElementwiseAdd + ElementwiseAdd + Activation
struct ConvElementwiseadd2Act : public PatternBase {
ConvElementwiseadd2Act(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope,
"conv_elementwiseadd2_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_in_y); // input
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(elementwise_add_op_1);
PATTERN_DECL_NODE(elementwise_add_in_y_1); // input
PATTERN_DECL_NODE(elementwise_add_out_1);
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
...@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { ...@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements // only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10; int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
if (IsType<float>(t.type())) { if (t.type() == proto::VarType::FP32) {
os << t.data<float>()[i] << " "; os << t.data<float>()[i] << " ";
} else if (IsType<int64_t>(t.type())) { } else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " "; os << t.data<int64_t>()[i] << " ";
} else { } else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]"); PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
...@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor( ...@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty()); PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims(); framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type(); auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout(); framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod(); LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) { for (size_t i = 1; i < lod_tensors.size(); ++i) {
......
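Both hunks above reflect this change's broader migration of the tensor type tag from std::type_index to the proto::VarType::Type enum. As a standalone illustration of why an enum tag is convenient here (it can be switched on and written straight into a protobuf field, whereas a type_index can only be compared against typeid(...)), consider the following snippet; it mirrors the idea but is not Paddle code.
```
#include <cstdint>
#include <iostream>

// Standalone illustration only; the enum mimics proto::VarType::Type
// but is not the framework definition.
enum class VarType { FP32, INT64 };

void PrintFirst(VarType type, const void* data) {
  switch (type) {  // an enum tag can be switched on and serialized directly
    case VarType::FP32:
      std::cout << *static_cast<const float*>(data) << "\n";
      break;
    case VarType::INT64:
      std::cout << *static_cast<const int64_t*>(data) << "\n";
      break;
  }
}

int main() {
  float f = 1.5f;
  int64_t i = 42;
  PrintFirst(VarType::FP32, &f);
  PrintFirst(VarType::INT64, &i);
  return 0;
}
```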
...@@ -471,27 +471,23 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const { ...@@ -471,27 +471,23 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()), PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor"); "Ensure ngraph tensor layout align with paddle tensor");
if (tensor_pd->type().hash_code() == if (tensor_pd->type() == proto::VarType::FP32) {
typeid(float).hash_code()) { // NOLINT
const float* arr = tensor_pd->data<float>(); const float* arr = tensor_pd->data<float>();
ti = backend_->create_tensor(ngraph::element::f32, sp, ti = backend_->create_tensor(ngraph::element::f32, sp,
const_cast<float*>(arr)); const_cast<float*>(arr));
} else if (tensor_pd->type().hash_code() == } else if (tensor_pd->type() == proto::VarType::INT32) {
typeid(int).hash_code()) { // NOLINT
const int* arr = tensor_pd->data<int>(); const int* arr = tensor_pd->data<int>();
ti = backend_->create_tensor(ngraph::element::i32, sp, ti = backend_->create_tensor(ngraph::element::i32, sp,
const_cast<int*>(arr)); const_cast<int*>(arr));
} else if (tensor_pd->type().hash_code() == typeid(int64_t).hash_code()) { } else if (tensor_pd->type() == proto::VarType::INT64) {
const int64_t* arr = tensor_pd->data<int64_t>(); const int64_t* arr = tensor_pd->data<int64_t>();
ti = backend_->create_tensor(ngraph::element::i64, sp, ti = backend_->create_tensor(ngraph::element::i64, sp,
const_cast<int64_t*>(arr)); const_cast<int64_t*>(arr));
} else if (tensor_pd->type().hash_code() == } else if (tensor_pd->type() == proto::VarType::FP64) {
typeid(double).hash_code()) { // NOLINT
const double* arr = tensor_pd->data<double>(); const double* arr = tensor_pd->data<double>();
ti = backend_->create_tensor(ngraph::element::f64, sp, ti = backend_->create_tensor(ngraph::element::f64, sp,
const_cast<double*>(arr)); const_cast<double*>(arr));
} else if (tensor_pd->type().hash_code() == } else if (tensor_pd->type() == proto::VarType::BOOL) {
typeid(bool).hash_code()) { // NOLINT
const bool* arr = tensor_pd->data<bool>(); const bool* arr = tensor_pd->data<bool>();
ti = backend_->create_tensor(ngraph::element::boolean, sp, ti = backend_->create_tensor(ngraph::element::boolean, sp,
const_cast<bool*>(arr)); const_cast<bool*>(arr));
......
...@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) { ...@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN); LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2), ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
"data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_" "data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
"CUDAPlace(0)]:library_"
"type[CUDNN]"); "type[CUDNN]");
} }
......
...@@ -45,10 +45,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = { ...@@ -45,10 +45,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto::VarType::Type GetDataTypeOfVar(const Variable* var) { proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type()); return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType( return var->Get<framework::SelectedRows>().value().type();
var->Get<framework::SelectedRows>().value().type());
} else { } else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows"); PADDLE_THROW("Var should be LoDTensor or SelectedRows");
} }
...@@ -95,13 +94,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { ...@@ -95,13 +94,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return ""; return "";
} }
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
auto tensor = var->Get<SelectedRows>().value(); auto tensor = var->Get<SelectedRows>().value();
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return "uninited"; return "uninited";
} else { } else {
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} }
} else { } else {
return ""; return "";
...@@ -688,7 +687,8 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -688,7 +687,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) { if (tensor.memory_size() == 0) {
return; return;
} }
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) { if (tensor.type() != proto::VarType::FP32 &&
tensor.type() != proto::VarType::FP64) {
return; return;
} }
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
...@@ -883,7 +883,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -883,7 +883,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
if (t != nullptr) { if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized: %s", PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized: %s",
ipt_name, DebugString()); ipt_name, DebugString());
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
......
...@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, ...@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if (index < 0) { if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0"; VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0)); TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else { } else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(), TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width)); index * value_width, value_width));
} }
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const { void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
...@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const { ...@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
void* Tensor::mutable_data(platform::Place place, std::type_index type, void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr, memory::Allocator::Attr attr,
size_t requested_size) { size_t requested_size) {
type_ = type; type_ = type;
......
...@@ -19,9 +19,9 @@ limitations under the License. */ ...@@ -19,9 +19,9 @@ limitations under the License. */
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -67,7 +67,7 @@ class Tensor { ...@@ -67,7 +67,7 @@ class Tensor {
friend struct EigenVector; friend struct EigenVector;
public: public:
Tensor() : type_(typeid(float)), offset_(0) {} Tensor() : type_(proto::VarType::FP32), offset_(0) {}
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
...@@ -88,7 +88,7 @@ class Tensor { ...@@ -88,7 +88,7 @@ class Tensor {
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type, void* mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
...@@ -138,7 +138,7 @@ class Tensor { ...@@ -138,7 +138,7 @@ class Tensor {
return holder_->place(); return holder_->place();
} }
std::type_index type() const { proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called."); holder_, "Tensor not initialized yet when Tensor::type() is called.");
return type_; return type_;
...@@ -165,7 +165,7 @@ class Tensor { ...@@ -165,7 +165,7 @@ class Tensor {
private: private:
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_; std::shared_ptr<memory::Allocation> holder_;
std::type_index type_; proto::VarType::Type type_;
/** /**
* @brief points to elements dimensions. * @brief points to elements dimensions.
* *
......
...@@ -24,9 +24,8 @@ template <typename T> ...@@ -24,9 +24,8 @@ template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
type_.name());
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
...@@ -38,9 +37,8 @@ template <typename T> ...@@ -38,9 +37,8 @@ template <typename T>
inline T* Tensor::data() { inline T* Tensor::data() {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
type_.name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
...@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place, ...@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t requested_size) { size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>( return reinterpret_cast<T*>(
mutable_data(place, typeid(T), attr, requested_size)); mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
} }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
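Both data<T>() overloads and mutable_data<T>() above now map the C++ element type to the enum tag through DataTypeTrait<T>::DataType instead of typeid(T). The real trait lives in the framework's data_type header; the snippet below is a standalone sketch of what such a mapping looks like, with placeholder enum values.
```
#include <cstdint>

// Standalone sketch of a type -> enum-tag trait, in the spirit of
// DataTypeTrait<T>::DataType; the enum and its values are placeholders.
enum class VarType { FP32, FP64, INT32, INT64, BOOL };

template <typename T>
struct DataTypeTrait;  // left undefined for unsupported element types

template <> struct DataTypeTrait<float>   { static constexpr VarType DataType = VarType::FP32; };
template <> struct DataTypeTrait<double>  { static constexpr VarType DataType = VarType::FP64; };
template <> struct DataTypeTrait<int32_t> { static constexpr VarType DataType = VarType::INT32; };
template <> struct DataTypeTrait<int64_t> { static constexpr VarType DataType = VarType::INT64; };
template <> struct DataTypeTrait<bool>    { static constexpr VarType DataType = VarType::BOOL; };

// A typed accessor can then validate the stored tag with one comparison.
template <typename T>
bool HoldsType(VarType stored_tag) {
  return stored_tag == DataTypeTrait<T>::DataType;
}
```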
...@@ -186,8 +186,8 @@ struct AnyDTypeVisitor { ...@@ -186,8 +186,8 @@ struct AnyDTypeVisitor {
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) { const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>( VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out)); predicate, tensor, ctx, out));
} }
template <typename Predicate> template <typename Predicate>
...@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, ...@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
proto::VarType::TensorDesc desc; proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(tensor.type());
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims(); auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0); pb_dims->Resize(static_cast<int>(dims.size()), 0);
...@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
void* buf; void* buf;
auto ctx = platform::CPUDeviceContext(); auto ctx = platform::CPUDeviceContext();
size_t size = size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
if (platform::is_gpu_place(dev_ctx.GetPlace())) { if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor; Tensor cpu_tensor;
......
...@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {
......
...@@ -55,7 +55,12 @@ TEST(AnalysisPredictor, analysis_off) { ...@@ -55,7 +55,12 @@ TEST(AnalysisPredictor, analysis_off) {
} }
TEST(AnalysisPredictor, analysis_on) { TEST(AnalysisPredictor, analysis_on) {
AnalysisConfig config(false); #ifdef PADDLE_WITH_CUDA
AnalysisConfig config(true);
config.fraction_of_gpu_memory = 0.15;
#else
AnalysisConfig config;
#endif
config.model_dir = FLAGS_dirname; config.model_dir = FLAGS_dirname;
config.enable_ir_optim = true; config.enable_ir_optim = true;
......
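Pulling the scattered hunks together, the updated test presumably configures the predictor roughly as follows. All fields shown (model_dir, fraction_of_gpu_memory, enable_ir_optim) and CreatePaddlePredictor appear in this diff; the comments and exact defaults are assumptions.
```
#ifdef PADDLE_WITH_CUDA
AnalysisConfig config(true);           // analysis + inference on GPU
config.fraction_of_gpu_memory = 0.15;  // limit the initial GPU memory pool
#else
AnalysisConfig config;                 // CPU-only builds fall back to CPU
#endif
config.model_dir = FLAGS_dirname;
config.enable_ir_optim = true;         // run the IR fuse passes registered in GpuPassStrategy
auto predictor = CreatePaddlePredictor(config);
```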
...@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs, ...@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::DataTypeTrait<float>::DataType) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {
......
...@@ -36,10 +36,10 @@ namespace paddle { ...@@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt; PaddleTensor pt;
if (t->type() == typeid(int64_t)) { if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64; pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) { } else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else { } else {
......
...@@ -118,7 +118,10 @@ class GpuPassStrategy : public PassStrategy { ...@@ -118,7 +118,10 @@ class GpuPassStrategy : public PassStrategy {
public: public:
GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
"infer_clean_graph_pass", "conv_bn_fuse_pass", "infer_clean_graph_pass", //
"conv_bn_fuse_pass", //
"conv_elementwise_add_act_fuse_pass", //
"conv_elementwise_add2_act_fuse_pass", //
}); });
} }
......
...@@ -79,7 +79,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope, ...@@ -79,7 +79,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
for (auto* var : global_block.AllVars()) { for (auto* var : global_block.AllVars()) {
if (IsPersistable(var)) { if (IsPersistable(var)) {
VLOG(3) << "persistable variable's name: " << var->Name(); VLOG(4) << "persistable variable's name: " << var->Name();
framework::VarDesc* new_var = load_block->Var(var->Name()); framework::VarDesc* new_var = load_block->Var(var->Name());
new_var->SetShape(var->GetShape()); new_var->SetShape(var->GetShape());
......
...@@ -373,7 +373,7 @@ static bool CompareTensorData(const framework::LoDTensor &a, ...@@ -373,7 +373,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
} }
for (size_t i = 0; i < a_size; i++) { for (size_t i = 0; i < a_size; i++) {
if (a.type() == typeid(float)) { if (a.type() == framework::proto::VarType::FP32) {
const auto *a_data = a.data<float>(); const auto *a_data = a.data<float>();
const auto *b_data = b.data<float>(); const auto *b_data = b.data<float>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) { if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
...@@ -382,7 +382,7 @@ static bool CompareTensorData(const framework::LoDTensor &a, ...@@ -382,7 +382,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
b_data[i]); b_data[i]);
return false; return false;
} }
} else if (a.type() == typeid(int64_t)) { } else if (a.type() == framework::proto::VarType::INT64) {
const auto *a_data = a.data<int64_t>(); const auto *a_data = a.data<int64_t>();
const auto *b_data = b.data<int64_t>(); const auto *b_data = b.data<int64_t>();
if (std::abs(a_data[i] - b_data[i]) > 1e-3) { if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
......
...@@ -78,6 +78,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) { ...@@ -78,6 +78,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
if (use_analysis || use_tensorrt) { if (use_analysis || use_tensorrt) {
contrib::AnalysisConfig config(true); contrib::AnalysisConfig config(true);
config.pass_builder()->TurnOnDebug();
SetConfig<contrib::AnalysisConfig>(&config, model_dir, true, use_tensorrt, SetConfig<contrib::AnalysisConfig>(&config, model_dir, true, use_tensorrt,
FLAGS_batch_size); FLAGS_batch_size);
TestPrediction(reinterpret_cast<PaddlePredictor::Config*>(&config), TestPrediction(reinterpret_cast<PaddlePredictor::Config*>(&config),
...@@ -141,9 +142,31 @@ TEST(TensorRT_resnext50, profile) { ...@@ -141,9 +142,31 @@ TEST(TensorRT_resnext50, profile) {
profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt); profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
} }
TEST(resnext50, compare_analysis_native) {
std::string model_dir = FLAGS_infer_model + "/resnext50";
compare(model_dir, false /*use tensorrt*/);
}
TEST(TensorRT_mobilenet, analysis) { TEST(TensorRT_mobilenet, analysis) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet"; std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
compare(model_dir, /* use_tensorrt */ false); compare(model_dir, false /* use_tensorrt */);
}
TEST(AnalysisPredictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
AnalysisConfig config(true);
config.model_dir = model_dir;
config.fraction_of_gpu_memory = 0.15;
config.pass_builder()->TurnOnDebug();
std::vector<std::vector<PaddleTensor>> inputs_all;
auto predictor = CreatePaddlePredictor(config);
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
std::vector<PaddleTensor> outputs;
for (auto& input : inputs_all) {
ASSERT_TRUE(predictor->Run(input, &outputs));
}
} }
} // namespace inference } // namespace inference
......
...@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel { ...@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
} }
#endif #endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type()); auto data_type = ctx.Input<Tensor>("Theta")->type();
return framework::OpKernelType(data_type, ctx.GetPlace(), return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library); framework::DataLayout::kAnyLayout, library);
} }
...@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()), ctx.GetPlace(),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
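The same mechanical simplification repeats through the operator files that follow: because Tensor::type() now returns proto::VarType::Type directly, GetExpectedKernelType no longer wraps it in framework::ToDataType. The recurring shape of the new code, taken from the hunks below, is:
```
// Recurring post-change form of GetExpectedKernelType across these ops:
// the input tensor's type() feeds the kernel key directly.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                 ctx.device_context());
}
```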
...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);
...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);
...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);
...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);
...@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> { ...@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
ArrayToLoDFunctorImpl<DeviceContext> functor; ArrayToLoDFunctorImpl<DeviceContext> functor;
functor.dev_ctx_ = dev_ctx; functor.dev_ctx_ = dev_ctx;
functor.prev_functor_ = this; functor.prev_functor_ = this;
framework::VisitDataType(framework::ToDataType(out->type()), functor); framework::VisitDataType(out->type(), functor);
} }
}; };
...@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
PADDLE_ENFORCE(!x.empty(), "There's no element in the input array."); PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
int rank = x[0].dims().size(); int rank = x[0].dims().size();
platform::Place place = x[0].place(); platform::Place place = x[0].place();
std::type_index data_type = x[0].type(); auto data_type = x[0].type();
int64_t batch_size = x[0].dims()[0]; int64_t batch_size = x[0].dims()[0];
framework::DDim ins_dims = rank > 1 framework::DDim ins_dims = rank > 1
? framework::slice_ddim(x[0].dims(), 1, rank) ? framework::slice_ddim(x[0].dims(), 1, rank)
......
...@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
void AttentionLSTMOpMaker::Make() { void AttentionLSTMOpMaker::Make() {
......
...@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
framework::ToDataType(ctx.Input<Tensor>("param")->type()), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("X")->type();
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// By default, the type of the scale, bias, mean, // By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor) // and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor). // or double (For double input tensor).
...@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP64) { if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64; bn_param_type = framework::proto::VarType::FP64;
} }
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type"); "Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
"Bias input should be of float type"); "Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
"Mean input should be of float type"); "Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType( PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type"); "Variance input should be of float type");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
...@@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(), ctx.GetPlace(), layout, library);
layout, library);
} }
}; };
......
...@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores"); LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(scores->at(0).type()), scores->at(0).type(),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores, BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id)); beam_size, end_id));
} }
......
...@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("pre_ids")->type(),
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
platform::CPUPlace()); platform::CPUPlace());
return kt; return kt;
} }
......
...@@ -47,9 +47,8 @@ class BprLossOp : public framework::OperatorWithKernel { ...@@ -47,9 +47,8 @@ class BprLossOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
...@@ -94,9 +93,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel { ...@@ -94,9 +93,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
include(operators) include(operators)
register_operators() register_operators(DEPS naive_executor)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
PADDLE_THROW("should have one initialized input as condition"); PADDLE_THROW("should have one initialized input as condition");
} }
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
ips[0]->numel() == 1)) { PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
PADDLE_THROW( ips[0]->numel() == 1,
"condition input's data type should be bool, " "condition input's data type should be bool, "
"numel should be 1, actual numel is %d", "numel should be 1, actual numel is %d",
ips[0]->numel()); ips[0]->numel());
}
bool res = false; bool res = false;
if (platform::is_gpu_place(ips[0]->place())) { if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -261,7 +261,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -261,7 +261,7 @@ class WhileGradOp : public framework::OperatorBase {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>(); auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["dtype"] = framework::ToDataType(inside_tensor.type()); attrs["dtype"] = inside_tensor.type();
attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f; attrs["value"] = 0.0f;
......
...@@ -44,7 +44,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -44,7 +44,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations"); std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"Conv intput should be 4-D or 5-D tensor."); "Conv intput should be 4-D or 5-D tensor, get %u",
in_dims.size());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(), in_dims.size(), filter_dims.size(),
"Conv input dimension and filter dimension should be the same."); "Conv input dimension and filter dimension should be the same.");
...@@ -95,10 +97,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -95,10 +97,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
#endif #endif
auto input_data_type = auto input_data_type = ctx.Input<Tensor>("Input")->type();
framework::ToDataType(ctx.Input<Tensor>("Input")->type()); auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
auto filter_data_type =
framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
"input and filter data type should be consistent"); "input and filter data type should be consistent");
...@@ -382,9 +382,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -382,9 +382,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_,
layout_, library_, customized_type_value); customized_type_value);
} }
} // namespace operators } // namespace operators
......
...@@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
void Conv2DTransposeOpMaker::Make() { void Conv2DTransposeOpMaker::Make() {
...@@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( ...@@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), ctx.GetPlace(), layout_, library_);
layout_, library_);
} }
} // namespace operators } // namespace operators
......
...@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel { ...@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel { ...@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -300,9 +300,11 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -300,9 +300,11 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
} }
CudnnRNNCache *cudnn_rnn_cache = nullptr; CudnnRNNCache *cudnn_rnn_cache = nullptr;
if (cache_var->IsInitialized()) { if (cache_var->IsInitialized()) {
// const_cast is usually bad; it is needed here because the cache variable is a (const) input that must be mutated to hold the cuDNN RNN workspace.
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var) cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>(); ->GetMutable<CudnnRNNCache>();
} else { } else {
// const_cast is usually bad; see the note on the initialized branch above.
cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var) cudnn_rnn_cache = const_cast<framework::Variable *>(cache_var)
->GetMutable<CudnnRNNCache>(); ->GetMutable<CudnnRNNCache>();
std::random_device rnd; std::random_device rnd;
......
...@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
ctx.GetPlace());
} }
}; };
......
...@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { ...@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()), ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
platform::CPUPlace());
} }
}; };
......
...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Scores")->type(),
ctx.Input<framework::LoDTensor>("Scores")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()), ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { ...@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
...@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { ...@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::LoDTensor>("Anchor")->type(),
ctx.Input<framework::LoDTensor>("Anchor")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), ctx.device_context());
ctx.device_context());
} }
}; };
......
...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( ctx.Input<framework::Tensor>("DetectRes")->type(),
ctx.Input<framework::Tensor>("DetectRes")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -12,7 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")

if(WITH_GRPC)
  grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
      request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
      PROTO send_recv.proto
      DEPS lod_tensor selected_rows_functor memory)
...@@ -20,36 +20,43 @@ if(WITH_GRPC)
  set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
  cc_test(grpc_serde_test SRCS grpc_serde_test.cc
      DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)
  cc_test(rpc_server_test SRCS rpc_server_test.cc
      DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
  cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
  if(WITH_GPU)
    cc_test(collective_server_test SRCS collective_server_test.cc
        DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
        selected_rows_functor scope math_function SERIAL)
  endif()
  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
else()
  set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
      brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
      collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
  brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
      brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc collective_client.cc collective_server.cc
      PROTO send_recv.proto
      DEPS lod_tensor selected_rows memory)
  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
  set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
      proto_desc lookup_sparse_table_op snappystream snappy zlib)
  cc_test(rpc_server_test SRCS rpc_server_test.cc
      DEPS ${brpc_test_depends} SERIAL)
  cc_test(brpc_serde_test SRCS brpc_serde_test.cc
      DEPS ${brpc_test_depends} SERIAL)
  if(WITH_GPU)
    cc_test(collective_server_test SRCS collective_server_test.cc
        DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
  endif()
endif()
...@@ -14,135 +14,316 @@
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {

DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");

BRPCClient::~BRPCClient() { Wait(); }

void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
                        VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
                        ChannelContextPtr ch_ctx, BRPCClient* cls) {
  // std::unique_ptr makes sure cntl/response will be deleted before returning.
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
// this channel can be used by other now.
ch_ptr->Push(ch_ctx);
  if (cntl->Failed()) {
    LOG(FATAL) << "Fail to send SendVar: " << var_h->name()
               << ", error text: " << cntl->ErrorText();
    var_h->Finish(false);
    cls->DecreaseReqCount();
    return;
  }
  var_h->Finish(true);
  cls->DecreaseReqCount();
VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleSendResponse";
}

VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch_ptr = GetChannel(ep_val);
  const std::string method = "SendRPC";
  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));

  framework::AsyncIO([=] {
    auto ch_ctx = ch_ptr->Pop();
    brpc::Controller* cntl = new brpc::Controller();
    sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
    cntl->set_timeout_ms(time_out);

    auto* var = p_scope->FindVar(var_name_val);
    sendrecv::VariableMessage request;
    distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
                                  &cntl->request_attachment(), "", false,
                                  trainer_id_);

    google::protobuf::Closure* done = brpc::NewCallback(
        &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);

    platform::RecordRPCEvent record_event(method, p_ctx);

    ch_ctx->stub->SendVariable(cntl, &request, response, done);

    if (UNLIKELY(platform::IsProfileEnabled())) {
      var_h->Wait();
    }
  });
  req_count_++;

  return var_h;
}
void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(FATAL) << "Fail to get HandleFetchBarrierResponse: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleFetchBarrierResponse";
}
void HandleGetResponse(brpc::Controller* cntl,
                       sendrecv::VariableMessage* response, VarHandlePtr var_h,
                       ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
                       BRPCClient* cls) {
  // std::unique_ptr makes sure cntl/response will be deleted before returning.
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  std::unique_ptr<sendrecv::VariableMessage> response_guard(response);

  // this channel can be used other now.
  ch_ptr->Push(ch_ctx);

  if (cntl->Failed()) {
    LOG(FATAL) << "Fail to GetVar: " << var_h->name()
               << ", error text: " << cntl->ErrorText();
    cls->DecreaseReqCount();
    var_h->Finish(false);
    return;
  }

  VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
          << ", varname: " << var_h->name()
          << ", latency: " << cntl->latency_us() << "us";

  framework::Variable* outvar = nullptr;
  int trainer_id;
  distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
                                    *var_h->ctx(), var_h->scope(), &outvar,
                                    &trainer_id);
  VLOG(4) << "Finish HandleGetResponse";
  cls->DecreaseReqCount();
  var_h->Finish(true);
}
VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& method_name,
                                      int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string var_name_val = var_name;
  const framework::Scope* p_scope = &scope;
  const auto ch_ptr = GetChannel(ep_val);
  const std::string method = "GetRPC";
  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));

  framework::AsyncIO([=] {
    auto ch_ctx = ch_ptr->Pop();
    brpc::Controller* cntl = new brpc::Controller();
    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
    cntl->set_timeout_ms(time_out);

    sendrecv::VariableMessage req;
    req.set_varname(var_name_val);
req.set_trainer_id(trainer_id_);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
platform::RecordRPCEvent record_event(method, p_ctx);
if (method_name == "GetMonomerVariable") {
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
} else {
ch_ctx->stub->GetVariable(cntl, &req, response, done);
}
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
  req_count_++;
  return var_h;
}
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const std::string& var_name,
int64_t time_out) {
return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
}

VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
                                     const platform::DeviceContext& ctx,
                                     const framework::Scope& scope,
                                     const std::string& var_name,
                                     int64_t time_out) {
  return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
}

VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
                                          const platform::DeviceContext& ctx,
                                          const framework::Scope& scope,
                                          const std::string& in_var_name,
                                          const std::string& out_var_name,
                                          const std::string& table_name,
                                          int64_t time_out) {
  const platform::DeviceContext* p_ctx = &ctx;
  const std::string ep_val = ep;
  const std::string in_var_name_val = in_var_name;
  const std::string out_var_name_val = out_var_name;
  const std::string table_name_val = table_name;
  const framework::Scope* p_scope = &scope;
  const auto ch_ptr = GetChannel(ep_val);
const std::string method = "PrefetchRPC";
VarHandlePtr var_h(
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(in_var_name_val);
sendrecv::VariableMessage req;
distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
&cntl->request_attachment(), out_var_name_val,
false, 0, table_name_val);
platform::RecordRPCEvent record_event(method, p_ctx);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
    ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);

    if (UNLIKELY(platform::IsProfileEnabled())) {
      var_h->Wait();
    }
  });

  req_count_++;
  return var_h;
}
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
                                               int64_t time_out) {
  return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
                          time_out);
}

VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
                                               int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
const std::string method = "FetchBarrierRPC";
// var handle
VarHandlePtr var_h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
platform::RecordRPCEvent record_event(method, nullptr);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
ch_ctx->stub->GetVariable(cntl, &req, response, done);
  req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}

bool BRPCClient::Wait() {
  VLOG(9) << "begin to brpcclient wait";
  {
    std::unique_lock<std::mutex> lk(sync_mutex_);
    sync_cond_.wait(lk, [this] { return req_count_ == 0; });
  }
  VLOG(9) << "end to brpcclient wait";
  return true;
}
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
  VLOG(4) << "begin to GetChannel:" << ep;
  {
    std::lock_guard<std::mutex> guard(chan_mutex_);
    auto it = channels_.find(ep);
    if (it != channels_.end()) {
      VLOG(4) << "end to GetChannel:" << ep;
      return it->second;
    }
  }
...@@ -150,12 +331,20 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
  ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());

  brpc::ChannelOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
  options.use_rdma = true;
#endif
  options.protocol = "baidu_std";
  // don't use pooled type. the server can't afford that.
  options.connection_type = "single";
  options.connect_timeout_ms = 1000;
  options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
  options.max_retry = FLAGS_max_retry;

  VLOG(1) << "create " << brpc_channel_num_per_server_
          << " brpc channels to pserver:" << ep;

  for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
    std::shared_ptr<ChannelContext> c(new ChannelContext());
    if (c->channel.Init(ep.c_str(), &options) != 0) {
      LOG(FATAL) << "Fail to initialize channel";
...@@ -172,9 +361,75 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
    channels_[ep] = q;
  }

  VLOG(4) << "end to GetChannel:" << ep;
  return q;
}
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
}
void BRPCClient::SendComplete() {
for (auto& kv : channels_) {
AsyncSendComplete(kv.first);
}
}
VarHandlePtr BRPCClient::AsyncSendVarMessage(
const std::string& ep, const std::string& method_name,
const sendrecv::VariableMessage& req, int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
platform::RecordRPCEvent record_event(method_name, nullptr);
VarHandlePtr var_h(
new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
if (method_name == "CheckPointNotifyRPC") {
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
} else if (method_name == "GetMonomerBarrier") {
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->SendVariable(cntl, &req, response, done);
}
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(message);
return AsyncSendVarMessage(ep, method_name, req, time_out);
}
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
}
} // namespace distributed
} // namespace operators
} // namespace paddle
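With this change the brpc client mirrors the gRPC client interface: every Async* call now returns a VarHandlePtr whose Finish()/Wait() pair reports per-request status, instead of a bare bool. A rough caller-side sketch under that assumption; the endpoint, variable name, and surrounding objects are placeholders, not part of this commit:

```
// Hypothetical usage of the VarHandlePtr-based client API; `client`, `ctx`,
// `scope`, the endpoint and the variable name are placeholders.
#include "paddle/fluid/operators/distributed/rpc_client.h"

void SendOneVar(paddle::operators::distributed::RPCClient* client,
                const paddle::platform::DeviceContext& ctx,
                const paddle::framework::Scope& scope) {
  auto h = client->AsyncSendVar("127.0.0.1:6174", ctx, scope, "w@GRAD");
  // HandleSendResponse() calls h->Finish(true/false); Wait() blocks on that.
  if (!h->Wait()) {
    LOG(ERROR) << "AsyncSendVar of w@GRAD failed";
  }
}
```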
...@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient { ...@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
BRPCClient() {} BRPCClient() {}
virtual ~BRPCClient(); virtual ~BRPCClient();
  VarHandlePtr AsyncSendVar(const std::string& ep,
                            const platform::DeviceContext& ctx,
                            const framework::Scope& scope,
                            const std::string& var_name,
                            int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncGetVar(const std::string& ep,
                           const platform::DeviceContext& ctx,
                           const framework::Scope& scope,
                           const std::string& var_name,
                           int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncGetMonomerBarrier(
      const std::string& ep, const std::string& var_name,
      int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncGetMonomerVariable(
      const std::string& ep, const platform::DeviceContext& ctx,
      const framework::Scope& scope, const std::string& var_name,
      int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncPrefetchVar(const std::string& ep,
                                const platform::DeviceContext& ctx,
                                const framework::Scope& scope,
                                const std::string& in_var_name,
                                const std::string& out_var_name,
                                const std::string& table_name = "",
                                int64_t time_out = FLAGS_rpc_deadline) override;

  VarHandlePtr AsyncSendBatchBarrier(
      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendComplete() override;
 private:
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& method_name,
int64_t time_out = FLAGS_rpc_deadline);
  void Proceed();
  ChannelQueuePtr GetChannel(const std::string& ep);
VarHandlePtr AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message, int64_t time_out);
VarHandlePtr AsyncSendVarMessage(const std::string& ep,
const std::string& method_name,
const sendrecv::VariableMessage& req,
int64_t time_out);
friend void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h,
ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx,
BRPCClient* cls);
void DecreaseReqCount() {
if (--req_count_ <= 0) {
sync_cond_.notify_all();
}
}
 private:
  std::unordered_map<std::string, ChannelQueuePtr> channels_;
...@@ -88,6 +151,8 @@ class BRPCClient : public RPCClient {
  std::condition_variable sync_cond_;
  std::atomic<int64_t> req_count_{0};
static constexpr int brpc_channel_num_per_server_ = 4;
  // mutex for GetChannel thread safety
  std::mutex chan_mutex_;
  DISABLE_COPY_AND_ASSIGN(BRPCClient);
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
RdmaMemPool& RdmaMemPool::Instance() {
static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
return *g_rdma_mem_pool;
}
void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
pthread_rwlock_rdlock(&access_);
auto it = pool_.find(varname);
if (it == pool_.end()) {
pthread_rwlock_unlock(&access_);
return nullptr;
}
auto info = it->second;
if (info.data_size != size) {
pthread_rwlock_unlock(&access_);
PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size,
info.data_size);
return nullptr;
}
pthread_rwlock_unlock(&access_);
return info.data;
}
void RdmaMemPool::Register(const std::string& varname, void* data,
int64_t data_size) {
void* old = Find(varname, data_size);
if (old != nullptr) {
if (data != old) {
PADDLE_ENFORCE(false, "var:%s data:%ld != %ld", varname, data, old);
}
VLOG(7) << "Find on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
return;
}
VarInfo info;
info.data = data;
info.data_size = data_size;
pthread_rwlock_wrlock(&access_);
pool_[varname] = info;
pthread_rwlock_unlock(&access_);
if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
LOG(FATAL) << "register " << varname << " data:" << data
<< " data_size:" << data_size << " error";
}
VLOG(4) << "register on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace distributed {
/*
* This class is used to avoid duplicated registration of brpc::rdma.
*/
class RdmaMemPool {
public:
static RdmaMemPool& Instance();
RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}
virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }
void Register(const std::string& varname, void* data, int64_t size);
void* Find(const std::string& varname, int64_t size);
private:
struct VarInfo {
void* data;
int64_t data_size;
VarInfo() : data(nullptr), data_size(0) {}
};
private:
std::unordered_map<std::string, VarInfo> pool_;
pthread_rwlock_t access_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
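RdmaMemPool above memoizes brpc RDMA registrations by variable name, so repeated sends of the same persistent buffer register it with brpc::rdma only once. A minimal sketch of the intended call pattern, assuming a PADDLE_WITH_BRPC_RDMA build; the variable name, buffer, and length are hypothetical caller-side values, not code from this commit:

```
// Illustrative only: register a persistent buffer once and look it up later.
// `varname`, `send_buf` and `len` are hypothetical caller-side values.
void RegisterSendBuffer(const std::string& varname, void* send_buf,
                        int64_t len) {
  using paddle::operators::distributed::RdmaMemPool;
  // First call registers the memory with brpc::rdma; a second call with the
  // same name and size returns early inside Register().
  RdmaMemPool::Instance().Register(varname, send_buf, len);
  // Find() returns the registered pointer, or nullptr for an unknown name.
  void* cached = RdmaMemPool::Instance().Find(varname, len);
  PADDLE_ENFORCE(cached == send_buf, "unexpected buffer for %s", varname);
}
```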
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
class IOBufWriter {
public:
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen);
}
static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
int64_t vlen, bool in_cuda_pinned,
void (*destroy)(void*), void* user_data) {
VLOG(7) << "AppendTCPZeroCopy "
<< " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
// FIXME(gongwb): use append_zerocopy
/*
if (in_cuda_pinned) {
iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
} else {
iobuf->append_zerocopy(v, vlen, nullptr);
}
*/
iobuf->append(v, vlen);
destroy(user_data);
}
#ifdef PADDLE_WITH_BRPC_RDMA
static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
RdmaMemPool::Instance().Register(
varname, static_cast<void*>(const_cast<char*>(v)), vlen);
// FIXME(gongwb): use append_zerocopy
// iobuf->append_zerocopy(v, vlen, nullptr);
iobuf->append(v, vlen);
destroy(user_data);
return;
}
#endif
static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data);
#else
IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
user_data);
#endif
}
};
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, int trainer_id,
const std::string& table_name) {
std::unique_ptr<TensorPayload> payload;
request->set_varname(name);
request->set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request->set_profile(platform::kEnableProfiler);
} else {
request->set_profile(platform::kDisableProfiler);
}
}
if (!out_varname.empty()) {
request->set_out_varname(out_varname);
}
if (!table_name.empty()) {
request->set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request->set_type(::sendrecv::LOD_TENSOR);
payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
} else if (var->IsType<framework::SelectedRows>()) {
request->set_type(::sendrecv::SELECTED_ROWS);
payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES);
return;
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
PADDLE_ENFORCE_NOT_NULL(payload);
// FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) {
IOBufWriter::Append(
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size());
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
#endif
} else {
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
}
}
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size));
}
}
void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
operators::distributed::BRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!");
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
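SerializeToIOBuf relies on IOBufWriter to frame every attachment record as a 4-byte field number, an 8-byte payload length, and then the raw payload bytes; BRPCVariableResponse::Parse later in this commit reads the records back in the same order. A rough standalone sketch of that framing, for illustration only, with std::string standing in for butil::IOBuf:

```
// Sketch of the record layout produced by IOBufWriter::Append:
// [int32 field number][int64 payload length][payload bytes], repeated.
// Not part of this commit; butil::IOBuf is replaced by std::string here.
#include <cstdint>
#include <string>

void AppendRecord(std::string* buf, int field, const char* v, int64_t vlen) {
  buf->append(reinterpret_cast<const char*>(&field), 4);  // field number
  buf->append(reinterpret_cast<const char*>(&vlen), 8);   // payload length
  buf->append(v, vlen);                                   // payload bytes
}
```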
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 564 * 128;
// serialize var to IOBuf
{
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// deserialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
}
void RunTestLodTensor(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 512 * 8 * 4 * 2;
{
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// check sendrecv::VariableMessage meta data
{
EXPECT_EQ(msg.varname(), "myvar");
EXPECT_EQ(msg.type(), 0);
EXPECT_EQ(msg.dims()[0], 512);
EXPECT_EQ(msg.dims()[1], 8);
EXPECT_EQ(msg.dims()[2], 4);
EXPECT_EQ(msg.dims()[3], 2);
EXPECT_EQ(msg.lod_level(), 1);
EXPECT_EQ(msg.lod(0).lod_data(0), 1);
EXPECT_EQ(msg.lod(0).lod_data(1), 3);
EXPECT_EQ(msg.lod(0).lod_data(2), 8);
}
// deserialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
for (int i = 0; i < tensor_numel; ++i)
EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
}
TEST(LodTensor, Run) {
platform::CPUPlace place;
RunTestLodTensor(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
#endif
}
TEST(SelectedRows, Run) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
...@@ -13,84 +13,287 @@
// limitations under the License.

#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv { namespace sendrecv {
typedef std::unordered_map<std::string, namespace distributed = paddle::operators::distributed;
paddle::operators::distributed::RequestHandler*>
typedef std::unordered_map<std::string, distributed::RequestHandler*>
HandlerMap; HandlerMap;
class BRPCServiceImpl : public SendRecvService { class BRPCServiceImpl : public SendRecvService {
public: public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map) explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
: request_send_h_(nullptr), distributed::RPCServer* rpc_server)
request_get_h_(nullptr), : rpc_server_(rpc_server) {
request_prefetch_h_(nullptr) { VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend); auto it = rpc_call_map.find(distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_send_h_ = it->second; request_send_h_ = it->second;
send_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestSend)));
    }

    it = rpc_call_map.find(distributed::kRequestGet);
    if (it != rpc_call_map.end()) {
      request_get_h_ = it->second;
      get_threads_.reset(new paddle::framework::ThreadPool(
          rpc_server_->GetThreadNum(distributed::kRequestGet)));
    }

    it = rpc_call_map.find(distributed::kRequestPrefetch);
    if (it != rpc_call_map.end()) {
      request_prefetch_h_ = it->second;
prefetch_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestCheckpoint);
if (it != rpc_call_map.end()) {
request_checkpoint_h_ = it->second;
checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
if (it != rpc_call_map.end()) {
request_get_monomer_handler_h_ = it->second;
}
it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
if (it != rpc_call_map.end()) {
request_get_monomer_barrier_handler_h_ = it->second;
} }
  }

  virtual ~BRPCServiceImpl() {}

  void SendVariable(google::protobuf::RpcController* cntl_butil,
                    const VariableMessage* request, VoidMessage* response,
                    google::protobuf::Closure* done) override {
send_threads_->Run(
[=] { _SendVariable(cntl_butil, request, response, done); });
}
void _SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
    PADDLE_ENFORCE(request_send_h_ != nullptr,
                   "RequestSend handler should be registed first!");
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

    std::string varname = request->varname();

    VLOG(3) << "RequestSend var_name:" << varname
            << ", trainer_id:" << request->trainer_id()
            << ", from:" << cntl->remote_side();

    distributed::BRPCVariableResponse resp(request_send_h_->scope(),
                                           request_send_h_->dev_ctx(),
                                           !request_send_h_->sync_mode());
    PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
                   "parse iobuf to tensor error!");

    auto scope = resp.GetMutableLocalScope();
    auto invar = resp.GetVar();
    int trainer_id = request->trainer_id();
    paddle::framework::Variable* outvar = nullptr;

    request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
  }
  void GetVariable(google::protobuf::RpcController* cntl_butil,
                   const VariableMessage* request, VariableMessage* response,
                   google::protobuf::Closure* done) override {
    get_threads_->Run(
        [=] { _GetVariable(cntl_butil, request, response, done); });
  }

  void _GetVariable(google::protobuf::RpcController* cntl_butil,
                    const VariableMessage* request, VariableMessage* response,
                    google::protobuf::Closure* done) {
    PADDLE_ENFORCE(request_get_h_ != nullptr,
                   "RequestGet handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGet varname:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
auto scope = request_get_h_->scope();
auto invar = scope->FindVar(varname);
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(),
response, &cntl->response_attachment(), "",
false);
}
}
  void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
                        const VariableMessage* request,
                        VariableMessage* response,
                        google::protobuf::Closure* done) override {
prefetch_threads_->Run(
[=] { _PrefetchVariable(cntl_butil, request, response, done); });
}
void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
    PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
                   "kRequestPrefetch handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// prefetch process...
std::string in_var_name = request->varname();
std::string out_var_name = request->out_varname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< ", out_var_name: " << out_var_name
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
distributed::BRPCVariableResponse resp(
request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
auto scope = resp.GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
std::string table_name = request->table_name();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = scope->Var(out_var_name);
request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
distributed::SerializeToIOBuf(out_var_name, outvar,
*request_prefetch_h_->dev_ctx(), response,
&cntl->response_attachment(), "", true);
}
void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
checkpoint_notify_threads_->Run(
[=] { _CheckpointNotify(cntl_butil, request, response, done); });
}
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(
request_checkpoint_h_ != nullptr,
"kRequestCheckpointNotify handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
request_checkpoint_h_->dev_ctx());
auto scope = resp.GetMutableLocalScope();
std::string checkpoint_notify = request->varname();
std::string checkpoint_dir = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
trainer_id, checkpoint_dir);
}
void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_handler_h_ != nullptr,
"kRequestGetMonomerVariable handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// proc request.
std::string varname = request->varname();
VLOG(3) << "GetMonomerVariable " << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
auto scope = h.scope_;
auto invar = scope->FindVar(varname);
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
request->trainer_id());
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
&cntl->response_attachment(), "", false);
}
}
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_barrier_handler_h_ != nullptr,
"RequestGetMonomerBarrier handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
paddle::framework::Scope* scope = nullptr;
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_barrier_handler_h_->Handle(
varname, scope, invar, &outvar, request->trainer_id());
  }

 private:
  distributed::RequestHandler* request_send_h_{nullptr};
  distributed::RequestHandler* request_get_h_{nullptr};
  distributed::RequestHandler* request_prefetch_h_{nullptr};
distributed::RequestHandler* request_checkpoint_h_{nullptr};
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
distributed::RPCServer* rpc_server_{nullptr};
// FIXME(gongwb): brpc should support process one rpce use one threadpool.
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
};
} // namespace sendrecv
...@@ -100,7 +303,7 @@ namespace distributed {
void AsyncBRPCServer::StartServer() {
  // Instance of your service.
  sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);

  // Add the service into server. Notice the second parameter, because the
  // service is put on stack, we don't want server to delete it, otherwise
...@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() {
  }

  brpc::ServerOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
  options.idle_timeout_sec = idle_timeout_s_;
  options.max_concurrency = max_concurrency_;
  if (server_.Start(bind_address_.c_str(), &options) != 0) {
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;
int BRPCVariableResponse::Parse(Source* source) {
pb::io::ZeroCopyInputStream* input_stream = source->contents();
pb::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
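// Each record in the attachment is framed as: a little-endian uint32 field
// number, a little-endian uint64 byte count, then the payload bytes.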
while (1) {
unsigned int tag = 0;
if (!input.ReadLittleEndian32(&tag)) {
break;
}
uint64_t num_bytes = 0;
if (!input.ReadLittleEndian64(&num_bytes)) {
break;
}
int field = static_cast<int>(tag);
int ret = field == 0 ? -1 : field;
switch (field) {
case vr::kSerializedFieldNumber: {
if (!ProcSerializedField(field, &input, num_bytes)) {
return ret;
}
break;
}
case vr::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return ret;
}
break;
}
default: {
PADDLE_ENFORCE(false, "unsupported field number %u", field);
return ret;
}
}
}
return 0;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
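The Parse loop above consumes a simple length-prefixed framing rather than regular protobuf wire format: each record is a little-endian uint32 field number, a little-endian uint64 byte count, then the raw bytes. A minimal sketch of a matching encoder (illustrative only; `FrameAttachment` is a made-up helper, not part of this change):

```
#include <cstdint>
#include <string>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

// Frame one record the way BRPCVariableResponse::Parse expects it.
std::string FrameAttachment(uint32_t field_number, const std::string& payload) {
  std::string out;
  {
    google::protobuf::io::StringOutputStream raw(&out);
    google::protobuf::io::CodedOutputStream coded(&raw);
    coded.WriteLittleEndian32(field_number);  // e.g. vr::kSerializedFieldNumber
    coded.WriteLittleEndian64(payload.size());
    coded.WriteRaw(payload.data(), static_cast<int>(payload.size()));
  }  // CodedOutputStream flushes into `out` on destruction
  return out;
}
```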
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class BRPCSourceWrapper : public Source {
public:
explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
::google::protobuf::io::ZeroCopyInputStream* contents() override {
return &source_;
}
private:
butil::IOBufAsZeroCopyInputStream source_;
};
class BRPCVariableResponse : public VariableResponse {
public:
BRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~BRPCVariableResponse() {}
// parse attachment from iobuf
int Parse(Source* source) override;
int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
BRPCSourceWrapper wrapper(iobuf);
return VariableResponse::Parse(&wrapper, meta);
}
};
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
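For orientation, a hedged usage sketch of this header (the handler plumbing, scope, and meta message are assumptions about the surrounding service code, not part of the patch):

```
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/platform/device_context.h"

// Sketch: decode one variable from a brpc request attachment. `scope` and
// `meta` would come from the surrounding BRPCServiceImpl handler; the data is
// written into the variable named by meta.varname() inside `scope`.
bool DecodeVariable(const paddle::framework::Scope* scope,
                    const sendrecv::VariableMessage& meta,
                    const butil::IOBuf& attachment) {
  paddle::platform::CPUDeviceContext cpu_ctx;
  paddle::operators::distributed::BRPCVariableResponse resp(scope, &cpu_ctx);
  return resp.Parse(attachment, meta) == 0;  // 0 means the attachment decoded
}
```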
...@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC";
- VarHandlePtr h(
-     new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
+ VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
......
...@@ -32,13 +32,6 @@ namespace paddle {
namespace operators {
namespace distributed {
- static void SerializeDestroyCallback(void* payload) {
-   if (payload != nullptr) {
-     auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
-     delete shared_payload;
-   }
- }
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,

...@@ -122,8 +115,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
- size_t rows_memory_size =
-     slr->rows().size() * framework::SizeOfType(typeid(int64_t));
+ size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
...@@ -75,6 +75,10 @@ class RPCServer {
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
int GetThreadNum(const std::string& rpc_name) {
return rpc_thread_num_[rpc_name];
}
// Wait until all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
......
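A small sketch of how the new accessor pairs with RegisterRPC (the server and handler pointers, and the kRequestSend name from request_handler.h, are assumptions about the surrounding setup, not part of this hunk):

```
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"

namespace distributed = paddle::operators::distributed;

// Register the send handler with an explicit thread count, then read it back.
int RegisterSendThreads(distributed::RPCServer* server,
                        distributed::RequestHandler* handler, int thread_num) {
  server->RegisterRPC(distributed::kRequestSend, handler, thread_num);
  return server->GetThreadNum(distributed::kRequestSend);  // == thread_num
}
```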
...@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread>  // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
...@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor(
memory::Copy(cuda_pinned, result->ptr(),
boost::get<platform::CUDAPlace>(tensor.place()),
tensor.data<void>(), copy_size, gpu_dev_ctx.stream());
ctx.Wait();
return TensorPayload(result);
#else

...@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
- request->set_data_type(
-     static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+ request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim);
}

...@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>();
- request->set_data_type(
-     static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+ request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
request->set_lod_level(0);
request->set_slr_height(slr->height());
......
...@@ -50,6 +50,13 @@ class TensorPayload final {
size_t memory_size_;
};
inline void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
TensorPayload GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);

...@@ -58,18 +65,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);

- inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
+ inline framework::proto::VarType::Type ToVarType(
+     sendrecv::VariableMessage::Type type) {
switch (type) {
case sendrecv::VariableMessage::FP32:
-   return typeid(float);  // NOLINT
+   return framework::proto::VarType::FP32;  // NOLINT
case sendrecv::VariableMessage::FP64:
-   return typeid(double);  // NOLINT
+   return framework::proto::VarType::FP64;  // NOLINT
case sendrecv::VariableMessage::INT32:
-   return typeid(int);  // NOLINT
+   return framework::proto::VarType::INT32;  // NOLINT
case sendrecv::VariableMessage::INT64:
-   return typeid(int64_t);  // NOLINT
+   return framework::proto::VarType::INT64;  // NOLINT
case sendrecv::VariableMessage::BOOL:
-   return typeid(bool);  // NOLINT
+   return framework::proto::VarType::BOOL;  // NOLINT
default:
PADDLE_THROW("Not support type %d", type);
}
......
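As a quick illustration of the ToTypeIndex-to-ToVarType switch above (a sketch, not code from the patch; the helper name is made up): the wire-level type now maps straight to the framework enum that Tensor::mutable_data() accepts.

```
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/place.h"

// Allocate a tensor buffer for the data type announced in a VariableMessage.
void* AllocateForWireType(paddle::framework::Tensor* tensor,
                          const paddle::platform::Place& place,
                          sendrecv::VariableMessage::Type wire_type) {
  return tensor->mutable_data(
      place, paddle::operators::distributed::ToVarType(wire_type));
}
```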
...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
tensor->set_lod(lod);
void* tensor_data =
-     tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
+     tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length;

...@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
- PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
-                   length / framework::SizeOfType(
-                                paddle::operators::distributed::ToTypeIndex(
-                                    meta_.data_type())));
+ PADDLE_ENFORCE_EQ(
+     static_cast<size_t>(tensor->numel()),
+     length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
+                  meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
-     paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
+     paddle::operators::distributed::ToVarType(meta_.data_type()));
if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
return false;

...@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
- slr->mutable_rows()->resize(length /
-                             framework::SizeOfType(typeid(int64_t)));  // int64
+ slr->mutable_rows()->resize(length / sizeof(int64_t));  // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
......
...@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
- set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
+ set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
- set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
+ set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
...@@ -26,10 +26,11 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
- DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send");
- DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get");
- DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch");
+ DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
+ DEFINE_int32(rpc_get_thread_num, 12, "number of threads for rpc get");
+ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");

namespace paddle {
namespace operators {
......
...@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
-     framework::ToDataType(
-         ctx.MultiInput<framework::Tensor>("X").front()->type()),
-     ctx.GetPlace());
+     ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
}
};
......
...@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
-     framework::ToDataType(
-         ctx.MultiInput<framework::Tensor>("X")[0]->type()),
-     ctx.GetPlace());
+     ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
}
};
......
...@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase {
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << " from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << " from " << epmap[i];
}
}
} }
......