提交 02aaecca 编写于 作者: Y Yu Yang

Fix CPU compile

上级 54bd17fe
......@@ -8,8 +8,14 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
if(WITH_GPU)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
else()
set(multi_devices_graph_builder_deps)
endif()
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
nccl_all_reduce_op_handle scale_loss_grad_op_handle)
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......@@ -14,14 +14,18 @@
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
namespace paddle {
namespace framework {
namespace details {
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
......@@ -32,6 +36,16 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes) {
#endif
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
......@@ -78,9 +92,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
// Insert ScaleCost OpHandle
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p);
#else
auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
nccl_ctxs_->DevCtx(p));
communication_dev_ctx);
result.ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
......@@ -104,6 +125,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
for (auto &og : var_names) {
if (grad_names_.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
#ifdef PADDLE_WITH_CUDA
result.ops_.emplace_back(
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
auto *op_handle = result.ops_.back().get();
......@@ -125,6 +147,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
op_handle->AddOutput(&var);
}
#else
PADDLE_ENFORCE("Not implemented");
#endif
}
}
}
......@@ -143,7 +168,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
return std::unique_ptr<SSAGraph>(graph);
}
} // namespace details
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -26,11 +26,18 @@ class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
......@@ -38,8 +45,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
platform::NCCLContextMap *nccl_ctxs_;
std::unordered_set<std::string> grad_names_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
#endif
};
} // namespace details
} // namespace framework
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#include "ThreadPool.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
......@@ -64,13 +66,18 @@ ParallelExecutor::ParallelExecutor(
member_->local_scopes_.size() != 1) { // Is CUDA
BCastParamsToGPUs(startup_program);
}
// Startup Program has been run. All local scopes has correct parameters.
// Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_,
member_->nccl_ctxs_.get());
#else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_);
#endif
auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
......@@ -137,3 +144,4 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
} // namespace framework
} // namespace paddle
A
\ No newline at end of file
......@@ -21,8 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册