diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index f13ac276fca01d15c2e9057f0577da00355d5787..bf1a705ef50b663efa53393ead1f81fd6bcf8c48 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -8,8 +8,14 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) + +if(WITH_GPU) + set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) +else() + set(multi_devices_graph_builder_deps) +endif() cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - nccl_all_reduce_op_handle scale_loss_grad_op_handle) + scale_loss_grad_op_handle ${multi_devices_graph_builder_deps}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index cb02d36714d8dc7a878db0832c6a1533671f4535..67987760764cd2c1fa0f11c290bee29e21120487 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -14,14 +14,18 @@ #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/computation_op_handle.h" -#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/platform/nccl_helper.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" +#endif namespace paddle { namespace framework { namespace details { + +#ifdef PADDLE_WITH_CUDA MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, @@ -32,6 +36,16 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( places_(places), local_scopes_(local_scopes), nccl_ctxs_(nccl_ctxs) { +#else +MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( + const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes) + : loss_var_name_(loss_var_name), + places_(places), + local_scopes_(local_scopes) { +#endif for (auto &p : params) { grad_names_.insert(GradVarName(p)); } @@ -78,9 +92,16 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( if (is_forwarding) { if (var_names.size() == 1 && var_names[0] == loss_var_name_) { - // Insert ScaleCost OpHandle +// Insert ScaleCost OpHandle +#ifdef PADDLE_WITH_CUDA + auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p); +#else + auto *communication_dev_ctx = + platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); +#endif + op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p, - nccl_ctxs_->DevCtx(p)); + communication_dev_ctx); result.ops_.emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -103,7 +124,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( auto var_names = op->OutputArgumentNames(); for (auto &og : var_names) { if (grad_names_.count(og) != 0) { // is param grad - // Insert NCCL AllReduce Op + // Insert NCCL AllReduce Op +#ifdef PADDLE_WITH_CUDA result.ops_.emplace_back( new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); auto *op_handle = result.ops_.back().get(); @@ -125,6 +147,9 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( op_handle->AddOutput(&var); } +#else + PADDLE_ENFORCE("Not implemented"); +#endif } } } @@ -143,7 +168,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } return std::unique_ptr(graph); -} +} // namespace details } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 17959a94d6cf7762982778bac35da624d1cb7436..d3c8e582cf2cdf26198822e4bd2602883622df21 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -26,11 +26,18 @@ class Scope; namespace details { class MultiDevSSAGraphBuilder : public SSAGraphBuilder { public: +#ifdef PADDLE_WITH_CUDA MultiDevSSAGraphBuilder(const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, platform::NCCLContextMap *nccl_ctxs); +#else + MultiDevSSAGraphBuilder(const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶ms, + const std::vector &local_scopes); +#endif std::unique_ptr Build(const ProgramDesc &program) const override; @@ -38,8 +45,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::string loss_var_name_; const std::vector &places_; const std::vector &local_scopes_; - platform::NCCLContextMap *nccl_ctxs_; std::unordered_set grad_names_; + +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nccl_ctxs_; +#endif }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index d1e1f0ed23d995a51a6be3eb46093eafbf89efb7..4936b8b6567d2901a2a611751bafef8196a346b4 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -16,7 +16,9 @@ limitations under the License. */ #include "ThreadPool.h" +#ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" +#endif #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" @@ -64,13 +66,18 @@ ParallelExecutor::ParallelExecutor( member_->local_scopes_.size() != 1) { // Is CUDA BCastParamsToGPUs(startup_program); } - // Startup Program has been run. All local scopes has correct parameters. +// Startup Program has been run. All local scopes has correct parameters. - // Step 2. Convert main_program to SSA form and dependency graph. Also, insert - // ncclOp +// Step 2. Convert main_program to SSA form and dependency graph. Also, insert +// ncclOp +#ifdef PADDLE_WITH_CUDA details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, params, member_->local_scopes_, member_->nccl_ctxs_.get()); +#else + details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, + params, member_->local_scopes_); +#endif auto graph = builder.Build(main_program); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( @@ -137,3 +144,4 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, } // namespace framework } // namespace paddle +A \ No newline at end of file diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 8bc09c5798854feabb43fa64d160b341bfe0e7b5..503efa2e447b0ac70f6302aa0a89cc55e5afcb81 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -21,8 +21,6 @@ limitations under the License. */ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" - -#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 0e00f218f9a834111a86b7d20a864e11d15cc3df..adaa0b9e5f1ffcfbf3e9cd8fd060153575f270a6 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/recordio/scanner.h"