Commit 02aaecca authored by Yu Yang

Fix CPU compile

Parent 54bd17fe
@@ -8,8 +8,14 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
+if(WITH_GPU)
+  set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
+else()
+  set(multi_devices_graph_builder_deps)
+endif()
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        nccl_all_reduce_op_handle scale_loss_grad_op_handle)
+        scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
@@ -14,14 +14,18 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/platform/nccl_helper.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
+#endif
 namespace paddle {
 namespace framework {
 namespace details {
+#ifdef PADDLE_WITH_CUDA
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
@@ -32,6 +36,16 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
       places_(places),
       local_scopes_(local_scopes),
       nccl_ctxs_(nccl_ctxs) {
+#else
+MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
+    const std::vector<platform::Place> &places,
+    const std::string &loss_var_name,
+    const std::unordered_set<std::string> &params,
+    const std::vector<Scope *> &local_scopes)
+    : loss_var_name_(loss_var_name),
+      places_(places),
+      local_scopes_(local_scopes) {
+#endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
@@ -78,9 +92,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
           // Insert ScaleCost OpHandle
+#ifdef PADDLE_WITH_CUDA
+          auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p);
+#else
+          auto *communication_dev_ctx =
+              platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
+#endif
           op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
-                                                nccl_ctxs_->DevCtx(p));
+                                                communication_dev_ctx);
           result.ops_.emplace_back(op_handle);
           // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -103,7 +124,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       auto var_names = op->OutputArgumentNames();
       for (auto &og : var_names) {
         if (grad_names_.count(og) != 0) {  // is param grad
           // Insert NCCL AllReduce Op
+#ifdef PADDLE_WITH_CUDA
           result.ops_.emplace_back(
               new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
           auto *op_handle = result.ops_.back().get();
@@ -125,6 +147,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
             op_handle->AddOutput(&var);
           }
+#else
+          PADDLE_ENFORCE("Not implemented");
+#endif
         }
       }
     }
@@ -143,7 +168,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   }
   return std::unique_ptr<SSAGraph>(graph);
-}
+}  // namespace details
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
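Note on the pattern above: the .cc changes pick both the communication device context and the all-reduce insertion path at compile time — CUDA builds use nccl_ctxs_->DevCtx(p) and NCCLAllReduceOpHandle, while CPU-only builds fall back to the CPU device context and report that all-reduce is not implemented. The following standalone C++ sketch shows the same guard pattern in miniature; the macro USE_GPU and every type and function in it are invented stand-ins for illustration, not PaddlePaddle APIs.

// Minimal sketch of compile-time selection between a GPU and a CPU code path.
// USE_GPU, DeviceContext, PickCommunicationContext, and InsertAllReduceOp are
// illustrative names only.
#include <iostream>
#include <stdexcept>
#include <string>

struct DeviceContext {
  std::string name;  // stand-in for a real device context object
};

DeviceContext *CpuContext() {
  static DeviceContext ctx{"CPU"};
  return &ctx;
}

#ifdef USE_GPU
DeviceContext *NcclContext() {
  static DeviceContext ctx{"GPU/NCCL"};
  return &ctx;
}
#endif

// Mirrors the diff: GPU builds take the NCCL-backed context, CPU builds take
// the CPU context (the diff uses the DeviceContextPool for this).
DeviceContext *PickCommunicationContext() {
#ifdef USE_GPU
  return NcclContext();
#else
  return CpuContext();
#endif
}

// Mirrors the all-reduce branch: only GPU builds insert the op; the CPU path
// signals "Not implemented" instead.
void InsertAllReduceOp() {
#ifdef USE_GPU
  std::cout << "inserting NCCL all-reduce op\n";
#else
  throw std::runtime_error("Not implemented");
#endif
}

int main() {
  std::cout << "communication context: " << PickCommunicationContext()->name
            << "\n";
  try {
    InsertAllReduceOp();
  } catch (const std::exception &e) {
    std::cout << "all-reduce unavailable: " << e.what() << "\n";
  }
  return 0;
}

Compiling the sketch with -DUSE_GPU exercises the first branch; without it the program takes the CPU fallback, which is the behavior this commit enables for NCCL-free builds.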
@@ -26,11 +26,18 @@ class Scope;
 namespace details {
 class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
  public:
+#ifdef PADDLE_WITH_CUDA
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
                           platform::NCCLContextMap *nccl_ctxs);
+#else
+  MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
+                          const std::string &loss_var_name,
+                          const std::unordered_set<std::string> &params,
+                          const std::vector<Scope *> &local_scopes);
+#endif
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -38,8 +45,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::string loss_var_name_;
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
-  platform::NCCLContextMap *nccl_ctxs_;
   std::unordered_set<std::string> grad_names_;
+#ifdef PADDLE_WITH_CUDA
+  platform::NCCLContextMap *nccl_ctxs_;
+#endif
 };
 }  // namespace details
 }  // namespace framework
......
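The header change above gives MultiDevSSAGraphBuilder two constructor signatures and makes the nccl_ctxs_ member exist only in CUDA builds. Below is a small, self-contained sketch of that conditional-member-and-constructor pattern; the macro USE_GPU, the class GraphBuilder, and FakeNcclContextMap are invented for illustration and are not Paddle types.

// Sketch: a class whose constructor signature and data members depend on a
// compile-time macro. All names are illustrative.
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct FakeNcclContextMap {};  // stands in for platform::NCCLContextMap

class GraphBuilder {
 public:
#ifdef USE_GPU
  GraphBuilder(std::vector<std::string> places, FakeNcclContextMap *nccl_ctxs)
      : places_(std::move(places)), nccl_ctxs_(nccl_ctxs) {}
  bool HasNcclContexts() const { return nccl_ctxs_ != nullptr; }
#else
  explicit GraphBuilder(std::vector<std::string> places)
      : places_(std::move(places)) {}
#endif

  std::size_t NumPlaces() const { return places_.size(); }

 private:
  std::vector<std::string> places_;
#ifdef USE_GPU
  FakeNcclContextMap *nccl_ctxs_ = nullptr;  // absent entirely in CPU builds
#endif
};

int main() {
#ifdef USE_GPU
  FakeNcclContextMap ctxs;
  GraphBuilder builder({"gpu:0", "gpu:1"}, &ctxs);
#else
  GraphBuilder builder({"cpu"});
#endif
  std::cout << "builder covers " << builder.NumPlaces() << " place(s)\n";
  return 0;
}

The call-site change in parallel_executor.cc (next diff) applies the same #ifdef to decide which of the two constructors to invoke.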
@@ -16,7 +16,9 @@ limitations under the License. */
 #include "ThreadPool.h"
+#ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/nccl_helper.h"
+#endif
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
@@ -64,13 +66,18 @@ ParallelExecutor::ParallelExecutor(
       member_->local_scopes_.size() != 1) {  // Is CUDA
     BCastParamsToGPUs(startup_program);
   }
   // Startup Program has been run. All local scopes has correct parameters.
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
+#ifdef PADDLE_WITH_CUDA
   details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
                                            params, member_->local_scopes_,
                                            member_->nccl_ctxs_.get());
+#else
+  details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
+                                           params, member_->local_scopes_);
+#endif
   auto graph = builder.Build(main_program);
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
@@ -137,3 +144,4 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
 }  // namespace framework
 }  // namespace paddle
\ No newline at end of file
@@ -21,8 +21,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
 #include "paddle/fluid/platform/device_context.h"
 namespace paddle {
......
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include <mutex>
+#include <thread>
 #include "paddle/fluid/operators/reader/reader_op_registry.h"
 #include "paddle/fluid/recordio/scanner.h"
......
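The last hunk adds <mutex> and <thread> explicitly rather than relying on them arriving through other headers. A translation unit that names std::mutex or std::thread should include those headers itself, since transitive includes differ between standard-library implementations and can make a CPU-only build fail where a CUDA build happened to work. A minimal, Paddle-free illustration:

// Include-what-you-use in miniature: this file compiles on any conforming
// toolchain because it includes <mutex> and <thread> directly rather than
// hoping another header drags them in.
#include <iostream>
#include <mutex>
#include <thread>

std::mutex mu;
int counter = 0;

void Bump() {
  std::lock_guard<std::mutex> guard(mu);  // std::lock_guard needs <mutex>
  ++counter;
}

int main() {
  std::thread t1(Bump);  // std::thread needs <thread>
  std::thread t2(Bump);
  t1.join();
  t2.join();
  std::cout << "counter = " << counter << "\n";  // prints 2
  return 0;
}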