diff --git a/CMakeLists.txt b/CMakeLists.txt
index 997672169fbb4d24028a4529b1a97880b7480503..23bb27e77b9eab0c322a71a8ff570d12d1050377 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocol" OFF)
+option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
diff --git a/README.md b/README.md
index 63abca069a6629ac59739224ded9cd9f06207d0a..eb99ed21d02650ef16cc7da91836909c02895be9 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,8 @@ learning to many products at Baidu.
Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
+### Latest PaddlePaddle Version: [Fluid](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid)
+
## Features
- **Flexibility**
diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py
index 99c9d79b068f5886012fd702d84d0666b9d197b5..a79f25ccc6ace1594f3f331633130eaace5e175b 100644
--- a/benchmark/fluid/args.py
+++ b/benchmark/fluid/args.py
@@ -125,6 +125,10 @@ def parse_args():
parser.add_argument(
'--use_inference_transpiler',
action='store_true',
- help='If set, uses inference transpiler to optimize the program.')
+        help='If set, use the inference transpiler to optimize the program.')
+ parser.add_argument(
+ '--no_random',
+ action='store_true',
+        help='If set, fix the random seed and do not shuffle the data.')
args = parser.parse_args()
return args
diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py
old mode 100755
new mode 100644
index dcd4d9ea95d816029317a29055b5ca8273ac9f43..94ea7bd6aca7c9595037a2dacc5e36d4c77827e7
--- a/benchmark/fluid/fluid_benchmark.py
+++ b/benchmark/fluid/fluid_benchmark.py
@@ -132,10 +132,6 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
exe.run(startup_prog)
# Use inference_transpiler to speedup
- if args.use_inference_transpiler:
- t = fluid.InferenceTranspiler()
- t.transpile(infer_prog, place)
-
if not args.use_reader_op:
feed_var_list = [
var for var in train_prog.global_block().vars.itervalues()
@@ -186,6 +182,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
# evaluation
if not args.no_test and batch_acc and not args.use_reader_op:
+ if args.use_inference_transpiler:
+ t = fluid.InferenceTranspiler()
+ t.transpile(infer_prog, place)
+
pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc)
print(", Test Accuracy: %f" % pass_test_acc)
@@ -316,6 +316,8 @@ def main():
args = parse_args()
print_arguments(args)
print_paddle_envs()
+ if args.no_random:
+ fluid.default_startup_program().random_seed = 1
# the unique trainer id, starting from 0, needed by trainer
# only
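
For context, a minimal sketch (assuming the 2018-era `paddle.fluid` API; `seed_if_requested` and `transpiled_for_test` are illustrative helper names, not part of the patch) of the flow this change establishes: the seed is pinned before the startup program runs, and the inference transpiler is applied lazily, right before evaluation, so training never executes the rewritten program.

```python
import paddle.fluid as fluid

def seed_if_requested(args):
    # With --no_random, pin the seed before exe.run(startup_prog) so weight
    # initialization is deterministic; reader shuffling is disabled separately.
    if args.no_random:
        fluid.default_startup_program().random_seed = 1

def transpiled_for_test(args, infer_prog, place):
    # Apply the inference transpiler only when evaluation actually happens,
    # mirroring the move of this block into the test branch above.
    if args.use_inference_transpiler:
        t = fluid.InferenceTranspiler()
        t.transpile(infer_prog, place)
    return infer_prog
```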
diff --git a/benchmark/fluid/models/resnet.py b/benchmark/fluid/models/resnet.py
index 9ed1093c54a501cc93dbbf9c3651fe70914ce26b..d44a9c07d31cfae9d54ad5949b85c77e60eae258 100644
--- a/benchmark/fluid/models/resnet.py
+++ b/benchmark/fluid/models/resnet.py
@@ -197,12 +197,12 @@ def get_model(args):
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
batched_train_reader = paddle.batch(
- paddle.reader.shuffle(
+ train_reader if args.no_random else paddle.reader.shuffle(
train_reader, buf_size=5120),
batch_size=args.batch_size * args.gpus,
drop_last=True)
batched_test_reader = paddle.batch(
- train_reader, batch_size=args.batch_size, drop_last=True)
+ test_reader, batch_size=args.batch_size, drop_last=True)
return avg_cost, inference_program, optimizer, batched_train_reader,\
batched_test_reader, batch_acc
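
The second hunk fixes a real bug: the test reader previously re-batched `train_reader`, so "test accuracy" was measured on training data. A sketch of the corrected wiring (names follow the benchmark code above):

```python
import paddle

def build_readers(args, train_reader, test_reader):
    # --no_random bypasses the shuffle buffer entirely; the test reader is
    # batched as-is and, after this fix, actually reads the test set.
    batched_train_reader = paddle.batch(
        train_reader if args.no_random
        else paddle.reader.shuffle(train_reader, buf_size=5120),
        batch_size=args.batch_size * args.gpus,
        drop_last=True)
    batched_test_reader = paddle.batch(
        test_reader, batch_size=args.batch_size, drop_last=True)
    return batched_train_reader, batched_test_reader
```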
diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake
index e3b9d94215a858c5c9a34e1b7e97540f1876801d..6ed51c648478efb9784d0c43b169c285e740e0f3 100644
--- a/cmake/cblas.cmake
+++ b/cmake/cblas.cmake
@@ -83,18 +83,20 @@ else()
set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
endif()
-find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
+if(WITH_SYSTEM_BLAS)
+ find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS
${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS})
-find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
+ find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS
${REFERENCE_CBLAS_LIB_SEARCH_PATHS})
-if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
- set(CBLAS_FOUND ON)
- set(CBLAS_PROVIDER REFERENCE)
- set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR})
- set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY})
- add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
- message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
+ if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY)
+ set(CBLAS_FOUND ON)
+ set(CBLAS_PROVIDER REFERENCE)
+ set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR})
+ set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY})
+ add_definitions(-DPADDLE_USE_REFERENCE_CBLAS)
+ message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})")
+ endif()
endif()
if(IOS_USE_VECLIB_FOR_BLAS AND VECLIB_FOUND)
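
With this guard, the reference-CBLAS search only runs when explicitly requested. A minimal sketch (the `build` directory layout is an assumption) of driving a configure step with the new option from a script:

```python
import subprocess

# -DWITH_SYSTEM_BLAS=ON opts in to the system cblas search guarded above;
# the default OFF keeps the previous behavior of preferring bundled BLAS.
subprocess.check_call(["cmake", "..", "-DWITH_SYSTEM_BLAS=ON"], cwd="build")
```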
diff --git a/doc/v2/design/cluster_train/large_model_dist_train.md b/doc/v2/design/cluster_train/large_model_dist_train.md
index 0c4b5bc24c854b7062d509249bea9c50d42bd5f1..edb0245ea083e791b7f32ac57a330698299fceda 100644
--- a/doc/v2/design/cluster_train/large_model_dist_train.md
+++ b/doc/v2/design/cluster_train/large_model_dist_train.md
@@ -52,7 +52,7 @@ In `trainer_internal.cpp:L93 trainOneBatch`:
When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver.
-In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
+In `legacy/trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
```c++
if (fullSize) {
diff --git a/doc/v2/design/mkl/mkldnn.md b/doc/v2/design/mkl/mkldnn.md
index bd5bcf6f67168c21cebb046a629b948d1661e75c..4876de0045979be20fa45bdc84d2594516f71c03 100644
--- a/doc/v2/design/mkl/mkldnn.md
+++ b/doc/v2/design/mkl/mkldnn.md
@@ -18,20 +18,20 @@ Figure 1. PaddlePaddle on IA
The detailed completion status can be found [here](https://github.com/PaddlePaddle/Paddle/projects/21).
## Contents
-
-- [Overview](#overview)
-- [Actions](#actions)
- - [CMake](#cmake)
- - [Matrix](#matrix)
- - [Layers](#layers)
- - [Activations](#activations)
- - [Parameters](#parameters)
- - [Gradients](#gradients)
- - [Unit Tests](#unit-tests)
- - [Python API](#python-api)
- - [Benchmarking](#benchmarking)
- - [Others](#others)
-- [Design Concerns](#design-concerns)
+
+- [Overview](#overview)
+- [Actions](#actions)
+ - [CMake](#cmake)
+ - [Matrix](#matrix)
+ - [Layers](#layers)
+ - [Activations](#activations)
+ - [Parameters](#parameters)
+ - [Gradients](#gradients)
+ - [Unit Tests](#unit-tests)
+ - [Python API](#python-api)
+ - [Benchmarking](#benchmarking)
+ - [Others](#others)
+- [Design Concerns](#design-concerns)
## Overview
@@ -218,20 +218,20 @@ if use_mkldnn
We summarize several points that require special attention:
1. Use **deviceId_**. To add as few variables or functions as possible to the parent Layer class,
-we decided to reuse the existing `deviceId_` variable to distinguish layer attributes, defining `-2` as the device ID specific to `MKLDNNLayer`.
-2. Override the parent Layer's **init** function and set `deviceId_` to `-2`, indicating that this layer runs in the MKL-DNN environment.
+we decided to reuse the existing `deviceId_` variable to distinguish layer attributes, defining `-2` as the device ID specific to `MKLDNNLayer`.
+2. Override the parent Layer's **init** function and set `deviceId_` to `-2`, indicating that this layer runs in the MKL-DNN environment.
3. Create `MKLDNNBase`, defining classes and functions beyond those related to layers and memory.
-This includes `MKLDNNStream` and `CPUEngine` used by MKL-DNN, and possibly `FPGAEngine` and others in the future.
+This includes `MKLDNNStream` and `CPUEngine` used by MKL-DNN, and possibly `FPGAEngine` and others in the future.
4. If an MKL-DNN layer is followed by a cpu device, `output_.value` and `extOutVal_` will share memory,
and the data format will be `NCHW`, so that the next cpu device receives correct data.
When ordinary CPU layers are present, the formats of `extOutVal_` and `extOutGrad_` are always `NCHW` or `NC`.
## References
1. The [MKL small library](https://github.com/01org/mkl-dnn#linking-your-application) is a subset of [Intel MKL](https://software.intel.com/en-us/mkl).
-It mainly contains the math primitives and operations related to deep learning, and is generally updated together with each MKL-DNN [release](https://github.com/01org/mkl-dnn/releases).
+It mainly contains the math primitives and operations related to deep learning, and is generally updated together with each MKL-DNN [release](https://github.com/01org/mkl-dnn/releases).
2. [MKL-DNN System Requirements](https://github.com/01org/mkl-dnn#system-requirements).
Currently, PaddlePaddle uses MKL-DNN only on machines that support the AVX2 instruction set or above.
3. [The original proposal](https://github.com/PaddlePaddle/Paddle/pull/3096) would introduce information about the **nextLayer**.
-However, in PaddlePaddle, neither the pre-refactor layers nor the post-refactor ops are meant to know anything about the next layer/op.
+However, in PaddlePaddle, neither the pre-refactor layers nor the post-refactor ops are meant to know anything about the next layer/op.
4. MKL-DNN's high-performance formats differ from PaddlePaddle's native `NCHW` (the cuDNN parts of PaddlePaddle also use `NCHW`, so this problem does not arise there).
-Therefore a conversion method is needed, converting the format only when necessary, so that MKL-DNN's performance can be fully exploited.
+Therefore a conversion method is needed, converting the format only when necessary, so that MKL-DNN's performance can be fully exploited.
diff --git a/doc/v2/dev/new_layer_en.rst b/doc/v2/dev/new_layer_en.rst
index 6a848a020df343c14601b9c3fcb5fb6fcde7f880..ad723738801908a5f48343574c204bdbfc97ee08 100644
--- a/doc/v2/dev/new_layer_en.rst
+++ b/doc/v2/dev/new_layer_en.rst
@@ -339,7 +339,7 @@ If you are creating a new file for the test, such as :code:`paddle/legacy/gserve
Implement Python Wrapper
========================
-Implementing Python wrapper allows us to use the added layer in configuration files. All the Python wrappers are in file :code:`python/paddle/trainer/config_parser.py`. An example of the Python wrapper for fully connected layer is listed below. It has the following steps:
+Implementing Python wrapper allows us to use the added layer in configuration files. All the Python wrappers are in file :code:`python/paddle/legacy/trainer/config_parser.py`. An example of the Python wrapper for fully connected layer is listed below. It has the following steps:
- Use :code:`@config_layer('fc')` at the decorator for all the Python wrapper class. :code:`fc` is the identifier of the layer.
- Implements :code:`__init__` constructor function.
diff --git a/doc/v2/howto/capi/compile_paddle_lib_cn.md b/doc/v2/howto/capi/compile_paddle_lib_cn.md
index e223fd33a8420abcdfdad53d1cfc5ed160a1b37e..2c87e9afc6911526cd51d6c691f262960accc9e8 100644
--- a/doc/v2/howto/capi/compile_paddle_lib_cn.md
+++ b/doc/v2/howto/capi/compile_paddle_lib_cn.md
@@ -18,7 +18,7 @@
cpu_avx_openblas |
-Not available |
+paddle.tgz |
cpu_noavx_openblas |
@@ -35,7 +35,12 @@
cuda8.0_cudnn7_avx_mkl |
paddle.tgz |
-
+
+
+cuda9.0_cudnn7_avx_mkl |
+paddle.tgz |
+
+
### 从源码编译
diff --git a/doc/v2/howto/capi/compile_paddle_lib_en.md b/doc/v2/howto/capi/compile_paddle_lib_en.md
index 6212a3081116d988630706e83d2349dd200b73ab..3fa8a18a9fbea21b494c416e6b938990fbb68337 100644
--- a/doc/v2/howto/capi/compile_paddle_lib_en.md
+++ b/doc/v2/howto/capi/compile_paddle_lib_en.md
@@ -17,7 +17,7 @@
cpu_avx_openblas |
-- |
+paddle.tgz |
cpu_noavx_openblas |
@@ -34,7 +34,12 @@
cuda8.0_cudnn7_avx_mkl |
paddle.tgz |
-
+
+
+cuda9.0_cudnn7_avx_mkl |
+paddle.tgz |
+
+
### From source
diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index efa59fc4a5cf21e885435f564d2a19f892cb534b..6653244507742b33d9524a7a0e4a5b2b575d358a 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -1,24 +1,24 @@
if(NOT WITH_FLUID_ONLY)
add_subdirectory(legacy/cuda)
add_subdirectory(legacy/function)
- add_subdirectory(utils)
+ add_subdirectory(legacy/utils)
add_subdirectory(legacy/math)
add_subdirectory(legacy/gserver)
add_subdirectory(legacy/parameter)
if(MOBILE_INFERENCE)
- add_subdirectory(capi)
+ add_subdirectory(legacy/capi)
else()
add_subdirectory(legacy/pserver)
- add_subdirectory(trainer)
+ add_subdirectory(legacy/trainer)
add_subdirectory(scripts)
if(WITH_C_API)
- add_subdirectory(capi)
+ add_subdirectory(legacy/capi)
endif()
if(WITH_SWIG_PY)
- add_subdirectory(api)
+ add_subdirectory(legacy/api)
endif()
endif()
endif()
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 3c73b6cc55c187c3f6e7edd1ce38cc58f4e8413d..4fb4ec38ee965a2790d11378a1ce6befa0ef5a00 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -25,11 +25,12 @@ else()
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
endif()
+cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
- scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle)
+ scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker)
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 64e83acb4dc1995800c4ca3caf81668b24a7c9fe..9c2c845c6efb206fb1ad5150189430b9a6fe9ea3 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -33,6 +33,8 @@ struct BuildStrategy {
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
std::string debug_graphviz_path_{""};
+
+ bool enable_data_balance_{true};
};
} // namespace details
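
The new flag defaults to on in C++. A hedged sketch of how it would be toggled from Python once exposed through pybind (the attribute name `enable_data_balance` is an assumption; this diff only adds the C++ field):

```python
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# Assumed pybind attribute mirroring BuildStrategy::enable_data_balance_.
build_strategy.enable_data_balance = False

exe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name="loss",            # placeholder loss variable name
    build_strategy=build_strategy)
```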
diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d07235df5856591f8ad707c86fa5b3b65868c3d1
--- /dev/null
+++ b/paddle/fluid/framework/details/data_balance_op_handle.cc
@@ -0,0 +1,154 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/data_balance_op_handle.h"
+#include <algorithm>
+#include "paddle/fluid/framework/details/container_cast.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+#ifdef PADDLE_WITH_CUDA
+DataBalanceOpHandle::DataBalanceOpHandle(
+    const std::vector<Scope *> &local_scopes,
+    const std::vector<platform::Place> &places,
+ const platform::NCCLContextMap *ctxs)
+ : local_scopes_(local_scopes), places_(places) {
+ if (ctxs) {
+ for (auto &p : places_) {
+ this->dev_ctxes_[p] = ctxs->DevCtx(p);
+ }
+ }
+}
+#else
+DataBalanceOpHandle::DataBalanceOpHandle(
+    const std::vector<Scope *> &local_scopes,
+    const std::vector<platform::Place> &places)
+ : local_scopes_(local_scopes), places_(places) {}
+#endif
+
+std::string DataBalanceOpHandle::Name() const { return "data balance"; }
+
+std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
+    const std::vector<int> &device_sizes) {
+ int device_num = device_sizes.size();
+ int total_size = 0;
+ int empty_num = 0;
+  std::vector<std::array<int, 2>> size_device_vec;
+ size_device_vec.reserve(device_num);
+ for (int i = 0; i < device_num; ++i) {
+ if (device_sizes[i] == 0) {
+ ++empty_num;
+ }
+ total_size += device_sizes[i];
+ size_device_vec.push_back({{device_sizes[i], i}});
+ }
+  std::vector<std::array<int, 3>> res;
+ if (empty_num == 0) {
+ // No need to do data balance.
+ return res;
+ }
+ if (total_size < device_num) {
+    // Not enough data.
+ PADDLE_THROW_EOF();
+ }
+ std::sort(size_device_vec.begin(), size_device_vec.end(),
+            [](const std::array<int, 2> &a, const std::array<int, 2> &b) {
+ return a[0] > b[0];
+ });
+ int expected_device_size = total_size / device_num;
+ int src_idx = 0;
+ for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) {
+ if (size_device_vec[src_idx][0] <= expected_device_size) {
+ ++src_idx;
+ PADDLE_ENFORCE_LT(
+ src_idx, device_num - empty_num,
+ "In current srategy an empty tensor should not be copy source.");
+ }
+ size_device_vec[src_idx][0] -= expected_device_size;
+ size_device_vec[dst_idx][0] += expected_device_size;
+ res.push_back({{size_device_vec[src_idx][1], size_device_vec[dst_idx][1],
+ expected_device_size}});
+ }
+ return res;
+}
+
+void DataBalanceOpHandle::RunImpl() {
+ if (places_.size() == 1) {
+ return;
+ }
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+ PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
+ PADDLE_ENFORCE_EQ(
+ in_var_handles.size(), out_var_handles.size(),
+ "The NoDummyInputSize and NoDummyOutputSize should be equal.");
+ int data_num = in_var_handles.size() / places_.size();
+ WaitInputVarGenerated();
+  std::vector<std::vector<LoDTensor *>> lod_tensors(data_num);
+  std::vector<int> device_sizes;
+  for (int i = 0; i < static_cast<int>(in_var_handles.size()); ++i) {
+ PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
+ "The name of input and output should be equal.");
+ int place_idx = i / data_num;
+ int data_idx = i % data_num;
+ auto *local_scope =
+        local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name_);
+    PADDLE_ENFORCE(tensor_var->IsType<LoDTensor>());
+    auto *tensor = tensor_var->GetMutable<LoDTensor>();
+ lod_tensors[data_idx].push_back(tensor);
+ int ins_size =
+ tensor->lod().empty() ? tensor->dims()[0] : tensor->NumElements();
+ if (data_idx == 0) {
+ device_sizes.emplace_back(ins_size);
+ } else {
+ PADDLE_ENFORCE_EQ(
+ ins_size, device_sizes.at(place_idx),
+ "All data on the same device shall have the same batch size.");
+ }
+ }
+ const auto &balance_plan = GetBalancePlan(device_sizes);
+
+ for (const auto &trans : balance_plan) {
+ for (int data_idx = 0; data_idx < data_num; ++data_idx) {
+ LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]];
+ LoDTensor *dst_tensor = lod_tensors[data_idx][trans[1]];
+ int trans_ins_size = trans[2];
+ LoD src_lod = src_tensor->lod();
+ int src_ins_size =
+ src_lod.empty() ? src_tensor->dims()[0] : src_tensor->NumElements();
+ int cut_point = src_ins_size - trans_ins_size;
+ if (!src_lod.empty()) {
+ for (auto &level : src_lod) {
+ cut_point = level[cut_point];
+ }
+ }
+ TensorCopySync(src_tensor->Slice(cut_point, src_tensor->dims()[0]),
+ dst_tensor->place(), dst_tensor);
+ src_tensor->ShareDataWith(src_tensor->Slice(0, cut_point));
+ if (!src_lod.empty()) {
+ dst_tensor->set_lod(SliceInLevel(
+ src_lod, 0, src_ins_size - trans_ins_size, src_ins_size));
+ src_tensor->set_lod(
+ SliceInLevel(src_lod, 0, 0, src_ins_size - trans_ins_size));
+ }
+ }
+ }
+}
+
+} // namespace details
+} // namespace framework
+} // namespace paddle
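
To make the balancing policy easy to check by hand, here is a pure-Python re-sketch of `GetBalancePlan` (illustrative only; the C++ above is authoritative). It returns `(src_device, dst_device, size)` moves that top up devices whose readers hit EOF:

```python
def get_balance_plan(device_sizes):
    device_num = len(device_sizes)
    total = sum(device_sizes)
    empty_num = device_sizes.count(0)
    if empty_num == 0:
        return []                     # nothing to balance
    if total < device_num:
        raise EOFError("not enough instances to cover every device")
    # Largest devices first, mirroring the std::sort comparator above.
    pairs = [[s, i] for s, i in
             sorted(((s, i) for i, s in enumerate(device_sizes)), reverse=True)]
    expected = total // device_num
    plan, src = [], 0
    for dst in range(device_num - empty_num, device_num):
        if pairs[src][0] <= expected:
            src += 1
            assert src < device_num - empty_num, \
                "an empty tensor must not become the copy source"
        pairs[src][0] -= expected
        pairs[dst][0] += expected
        plan.append((pairs[src][1], pairs[dst][1], expected))
    return plan

# Example: sizes [4, 2, 0] on 3 devices -> expected = 6 // 3 = 2, so two
# instances move from device 0 to the empty device 2.
assert get_balance_plan([4, 2, 0]) == [(0, 2, 2)]
```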
diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h
new file mode 100644
index 0000000000000000000000000000000000000000..76a407e3610e8bb48facf1f814779f4c23f92d98
--- /dev/null
+++ b/paddle/fluid/framework/details/data_balance_op_handle.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct DataBalanceOpHandle : public OpHandleBase {
+ public:
+#ifdef PADDLE_WITH_CUDA
+  DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
+                      const std::vector<platform::Place> &places,
+ const platform::NCCLContextMap *ctxs);
+#else
+  DataBalanceOpHandle(const std::vector<Scope *> &local_scopes,
+                      const std::vector<platform::Place> &places);
+#endif
+
+ std::string Name() const override;
+
+  bool IsMultiDeviceTransfer() override { return false; }
+
+ protected:
+ void RunImpl() override;
+
+ private:
+ // std::vector<(src_dev_id, dst_dev_id, trans_size)>
+  std::vector<std::array<int, 3>> GetBalancePlan(
+      const std::vector<int> &batch_size_per_device);
+
+  const std::vector<Scope *> local_scopes_;
+  const std::vector<platform::Place> places_;
+};
+
+} // namespace details
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index 224e8e1f6efd7a894591ac51c929517cae7539ce..d646c944601e81477787740189d7ac60ae97fa80 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -67,8 +67,8 @@ void FetchOpHandle::RunImpl() {
#endif
} else {
tensors_[i].ShareDataWith(t);
- tensors_[i].set_lod(t.lod());
}
+ tensors_[i].set_lod(t.lod());
}
this->WaitAndMergeCPUTensors();
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index cc7b94d0653e34c8ac711a7db7ab6ab1a9ac46a2..46d0c2769cb334f5cb75ae0ef5e48da45448c48f 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
@@ -215,7 +216,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else {
// This op runs on all devices, and its output may have parameter's
// gradients.
- CreateComputationalOps(&result, *op, places_.size());
+ if (op->Type() == "read" && strategy_.enable_data_balance_) {
+ op->SetAttr("throw_eof_exp", false);
+ CreateComputationalOps(&result, *op, places_.size());
+ const auto &data_var_names = op->Output("Out");
+ InsertDataBalanceOp(&result, data_var_names);
+ } else {
+ CreateComputationalOps(&result, *op, places_.size());
+ }
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
@@ -360,6 +368,29 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
}
}
+void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
+    SSAGraph *result, const std::vector<std::string> &datas) const {
+#ifdef PADDLE_WITH_CUDA
+ result->ops_.emplace_back(
+ new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
+#else
+ result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
+#endif
+ auto *op_handle = result->ops_.back().get();
+ for (size_t i = 0; i < places_.size(); ++i) {
+ auto &p = places_[i];
+ SetCommunicationContext(op_handle, p);
+ for (const std::string &d_name : datas) {
+ auto &vars = result->vars_[i][d_name];
+ PADDLE_ENFORCE(!vars.empty());
+ op_handle->AddInput(vars.back().get());
+ auto var = new VarHandle(vars.size(), i, d_name, p);
+ vars.emplace_back(var);
+ op_handle->AddOutput(var);
+ }
+ }
+}
+
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
const std::string &og,
std::unordered_set *og_has_been_broadcast) const {
@@ -512,7 +543,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
// the variable name which contains .block means it was split by
// split_byref op
- // so that we can balance the variable blocks to all the pserver instances.
+ // so that we can balance the variable blocks to all the pserver
+ // instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 0b6347bf51dc1c347073a0fdcf4ddd91865d846d..a964e024885e56693224a6199e00ff30beaa1df4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -101,6 +101,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
+ void InsertDataBalanceOp(SSAGraph *result,
+                           const std::vector<std::string> &datas) const;
+
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
index 1f84c3b9e2d7ee9ae51959988fceeb3451b7b3b8..d80bdcf15d798925c137460125964d3d7e65f67e 100644
--- a/paddle/fluid/framework/details/op_handle_base.cc
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -58,8 +58,10 @@ void OpHandleBase::Run(bool use_cuda) {
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA
+ PADDLE_ENFORCE_NOT_NULL(waited_ctx);
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) {
+ PADDLE_ENFORCE_NOT_NULL(dev_ctx.second);
dev_ctx.second->Wait();
}
} else {
@@ -122,16 +124,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
- // NOTE(zcd): device context must be ordered here because RecordEvent
- // will use a mutex to ensure the safe of multi-threads.
-    std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
for (auto &p : dev_ctxes_) {
- ordered_ctxes.emplace(p.second, p.first);
- }
- for (auto &p : ordered_ctxes) {
method = [method, p, this]() {
-      static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
-          events_.at(boost::get<platform::CUDAPlace>(p.second).device),
+      static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
+          events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
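
The deleted `ordered_ctxes` copy existed only to fix the iteration order; after this patch the loop walks `dev_ctxes_` directly (presumably now an ordered `std::map`, cf. the `<map>` include added to `op_handle_base.h` below). The wrapping pattern itself is worth spelling out: each context layers its event recording around the previous callable, so one call records on every context with the original work innermost. A small sketch:

```python
def chain_record_events(callback, dev_ctxes, record_event):
    # Mirrors RunAndRecordEvent: repeatedly wrap `method` so that calling the
    # final wrapper records an event per context, running the work innermost.
    method = callback
    for place, ctx in dev_ctxes.items():
        method = lambda m=method, p=place, c=ctx: record_event(c, p, m)
    return method

calls = []
chained = chain_record_events(
    lambda: calls.append("work"),
    {"gpu:0": "ctx0", "gpu:1": "ctx1"},
    lambda ctx, place, inner: (calls.append("record@" + place), inner()),
)
chained()
# The last-wrapped context runs outermost.
assert calls == ["record@gpu:1", "record@gpu:0", "work"]
```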
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index fbd90a3296bca92b097cab925b218b91e7f4752f..6aec178831161f8ac1306fc3ed72e3267ca3c7e5 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -13,9 +13,9 @@
// limitations under the License.
#pragma once
+#include <map>