Unverified commit 21622ca3, authored by 乔龙飞 Qiao Longfei, committed by GitHub

Merge pull request #16172 from jacquesqiao/add-async-ssa-graph-executor-communicator

Add async ssa graph executor communicator
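This PR wires three pieces together: an AsyncSSAGraphExecutor that drives one ThreadedSSAGraphExecutor per place on its own copy of the graph, an async_multi_devices_pass that builds those per-place graphs, and a Communicator that merges locally produced gradients and exchanges them with the parameter server in the background. As a rough, self-contained sketch of the communicator's merge-and-send pattern (illustrative names only, not Paddle APIs; the real logic lives in communicator.{h,cc} below):

// toy_communicator.cc -- illustrative sketch only, not Paddle code.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

// Bounded queue: trainer threads push gradients, the send thread pops them
// (the real class is BlockingQueue in communicator.h).
class GradQueue {
 public:
  explicit GradQueue(size_t cap) : cap_(cap) {}
  void Push(float g) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return q_.size() < cap_; });
    q_.push_back(g);
    cv_.notify_one();
  }
  bool TryPop(float *g) {
    std::lock_guard<std::mutex> lk(mu_);
    if (q_.empty()) return false;
    *g = q_.front();
    q_.pop_front();
    cv_.notify_one();
    return true;
  }

 private:
  const size_t cap_;
  std::deque<float> q_;
  std::mutex mu_;
  std::condition_variable cv_;
};

int main() {
  GradQueue queue(20);       // cf. FLAGS_communicator_send_queue_size
  const int kMaxMerge = 20;  // cf. FLAGS_communicator_max_merge_var_num
  std::atomic<bool> running{true};

  // Send thread: drain up to kMaxMerge queued gradients, merge them into one
  // update, then "send" it (the real code calls MergeVars + ParameterSend).
  std::thread sender([&] {
    while (running.load()) {
      float merged = 0.0f, g = 0.0f;
      int n = 0;
      while (n < kMaxMerge && queue.TryPop(&g)) {
        merged += g;
        ++n;
      }
      if (n > 0) {
        std::cout << "send merged grad " << merged << " (" << n << " grads)\n";
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
  });

  // Trainer side: push one gradient per mini-batch and continue immediately,
  // never blocking on the RPC round trip.
  for (int step = 0; step < 100; ++step) queue.Push(0.01f);

  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  running = false;
  sender.join();
  return 0;
}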
@@ -196,7 +196,7 @@ endif()
 target_link_libraries(executor while_op_helper executor_gc_helper)
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
-        threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
+        threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
         graph build_strategy
         fast_threaded_ssa_graph_executor variable_helper)
...
@@ -96,6 +96,12 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
 cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
+
+set(ASYNC_SSA_GRAPH_EXECUTOR_DEPS threaded_ssa_graph_executor)
+if(WITH_DISTRIBUTE)
+  list(APPEND ASYNC_SSA_GRAPH_EXECUTOR_DEPS communicator)
+endif()
+cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS ${ASYNC_SSA_GRAPH_EXECUTOR_DEPS})
+
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
 cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif
namespace paddle {
namespace framework {
namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
// get the RpcContexts of the remote send and recv ops
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (size_t i = 0; i < graphs.size(); ++i) {
std::vector<ir::Node *> nodes_to_delete;
for (auto &node : graphs[i]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
}
#endif
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// set the correct number of threads in the thread pool for each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL
: strategy_.num_threads_ / places_.size();
VLOG(1) << "set num_threads: " << strategy_.num_threads_
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
}
for (auto &node : graphs_[0]->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos_.emplace_back();
var_infos_.back().name_ = node->Var()->Name();
var_infos_.back().type_ = node->Var()->GetType();
var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
for (size_t i = 1; i < places_.size(); ++i) {
auto call = [this, i]() -> void {
VLOG(3) << "start off python thread " << i;
try {
while (true) {
executors_[i]->Run({});
}
} catch (...) {
exception_holder_.Catch(std::current_exception());
VLOG(3) << "get exception type = " << exception_holder_.Type();
}
VLOG(3) << "thread " << i << " exited!";
};
run_futures_.emplace_back(pool_->enqueue(std::move(call)));
}
}
void AsyncSSAGraphExecutor::HandleException() {
if (exception_holder_.IsCaught()) {
for (auto &f : run_futures_) {
VLOG(3) << "wait future";
f.wait();
}
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
run_futures_.clear();
exception_holder_.ReThrow();
}
}
FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
exception_holder_.Clear();
StartOffPythonTrainLoop();
}
if (places_.size() == 1) {
exception_holder_.Clear();
} else {
HandleException();
}
FeedFetchList fetch_data;
fetch_data.reserve(fetch_tensors.size());
try {
fetch_data = executors_[0]->Run(fetch_tensors);
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
HandleException();
FeedFetchList ret;
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return ret;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
namespace paddle {
namespace framework {
namespace details {
struct VarInfo {
std::string name_;
proto::VarType::Type type_;
bool persistable_;
};
class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
void StartOffPythonTrainLoop();
void HandleException();
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<ir::Graph *> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_;
std::vector<std::future<void>> run_futures_;
std::vector<VarInfo> var_infos_;
};
} // namespace details
} // namespace framework
} // namespace paddle
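To make the control flow above concrete, a toy model (assumed shape, plain C++, not Paddle code) of what StartOffPythonTrainLoop() and Run() do together: places 1..N-1 loop on their own graph copies, place 0 stays driven by the Python-side Run() call, and worker exceptions are funneled back to the caller:

// toy_async_run_loop.cc -- illustrative sketch only.
#include <exception>
#include <future>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

int main() {
  const int kPlaces = 4;
  std::vector<std::future<void>> workers;

  // StartOffPythonTrainLoop(): every place except 0 runs until something
  // escapes (typically an EOF exception when the reader is drained).
  for (int i = 1; i < kPlaces; ++i) {
    workers.emplace_back(std::async(std::launch::async, [i] {
      for (int step = 0;; ++step) {
        if (step == 100) {
          throw std::runtime_error("EOF on place " + std::to_string(i));
        }
      }
    }));
  }

  // Run(): place 0 executes under the caller's control; afterwards the
  // executor waits on the worker futures and rethrows what they caught,
  // mirroring HandleException().
  try {
    for (auto &w : workers) w.get();
  } catch (const std::exception &e) {
    std::cout << "rethrown from worker: " << e.what() << "\n";
  }
  return 0;
}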
@@ -184,8 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   // Convert graph to run on multi-devices.
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass = nullptr;
-    if (strategy.is_distribution_) {
-      VLOG(10) << "Add dist_multi_devices_pass";
+
+    if (strategy_.async_mode_) {
+      multi_devices_pass = AppendPass("async_multi_devices_pass").get();
+    } else if (strategy_.is_distribution_) {
+      VLOG(10)
+          << "Add dist_multi_devices_pass, multi device parameter server mode";
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
@@ -234,10 +238,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
 #else
                                 const bool use_cuda) const {
 #endif
+  VLOG(3) << "apply all passes";
   // Create a default one if not finalized by user.
   CreatePassesFromStrategy(false);

   for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
+    VLOG(3) << "apply " << pass->Type();
     if (IsMultiDevPass(pass->Type())) {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
@@ -293,6 +299,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
     graph = pass->Apply(graph);
     VLOG(3) << "Finish Apply Pass " << pass->Type();
   }
+  VLOG(3) << "All Passes Applied";
   return graph;
 }
...
@@ -97,6 +97,7 @@ struct BuildStrategy {
   // num_trainers is 1, so the current fields of build_strategy don't tell if
   // it's a distributed model.
   bool is_distribution_{false};
+  bool async_mode_{false};
   int num_trainers_{1};
   int trainer_id_{0};
   std::vector<std::string> trainers_endpoints_;
...
@@ -14,6 +14,9 @@
 #pragma once

+#include <memory>
+#include <string>
+
 #include "glog/logging.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -64,6 +67,21 @@ class ExceptionHolder {
     ClearImpl();
   }

+  std::string Type() {
+    std::lock_guard<std::mutex> lock(mu_);
+    switch (type_) {
+      case kNone:
+        return "None";
+      case kEnforceNotMet: {
+        return "EnforceNotMet";
+      }
+      case kEOF: {
+        return "EOF";
+      }
+    }
+    return "unknown";
+  }
+
  private:
   void ClearImpl() {
     exception_.reset();
...
@@ -31,6 +31,8 @@ struct ExecutionStrategy {
   size_t num_iteration_per_drop_scope_{1};
   ExecutorType type_{kDefault};
   bool dry_run_{false};
+  size_t num_iteration_per_run_{1};  // only used with async_ssa_graph_executor
+                                     // and pyreader with data queue
 };
 }  // namespace details
...
@@ -198,8 +198,22 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
           static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                 OpProtoAndCheckerMaker::OpRoleAttrName())) &
                             static_cast<int>(OpRole::kBackward));
+      // optimizer ops are already processed in DealWithSpecialOp,
+      // here we only consider backward ops
       if (!is_bk_op) continue;
+
+      /*
+       * the op that generates the gradient of one parameter will have one
+       * attr op_role_var to record the parameter and gradient, like:
+       *   attrs {
+       *     name: "op_role_var"
+       *     type: STRINGS
+       *     strings: "fc_1.b_0"
+       *     strings: "fc_1.b_0@GRAD"
+       *   }
+       */
+
       // Currently, we assume that once gradient is generated, it can be
       // broadcast, and each gradient is only broadcast once.
       auto backward_vars =
@@ -256,6 +270,8 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
       break;
   }

+  VLOG(3) << "loss_scale: " << loss_scale;
+
   if (loss_scale) {
     // TODO(paddle-dev): Why is there no input for this op_handle?
     auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
@@ -407,7 +423,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
 void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
                                                         ir::Node *node,
-                                                        int dev_id) const {
+                                                        size_t dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id], dev_id));
@@ -494,9 +510,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
   }
 }

-VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
-                                                       const std::string &og,
-                                                       int dst_dev_id) const {
+VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
+    ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
@@ -774,6 +789,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
   } else if (OpHaveRole(*node, OpRole::kDist)) {
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
+      // the inputs (blocks of the parameter) of concat are on different
+      // devices; the output (the parameter) will be on one device.
       auto origin_param_name = node->Op()->OutputArgumentNames()[0];
       bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
     }
@@ -781,6 +798,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
   } else {
     int op_dev_id = GetOpDeviceID(node);
     if (op_dev_id != -1) {  // This op only runs on one specific device.
+      // optimizer ops will be processed here.
       CreateComputationalOp(result, node, op_dev_id);
       for (ir::Node *n : node->outputs) {
         sharded_var_device_.emplace(n->Name(), op_dev_id);
@@ -961,6 +979,7 @@ bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
 void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
                                              const std::string &p_name,
                                              const std::string &g_name) const {
+  // insert the collective op for this gradient on the chosen device
   size_t cur_device_id = 0;
   switch (strategy_.reduce_) {
     case BuildStrategy::ReduceStrategy::kReduce:
@@ -1049,3 +1068,5 @@ REGISTER_MULTI_DEVICES_PASS(
     paddle::framework::details::AllReduceSSAGraphBuilder);
 REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
                             paddle::framework::details::DistSSAGraphBuilder);
+REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
+                            paddle::framework::details::AsyncSSAGraphBuilder);
@@ -56,8 +56,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {

   bool UseGPU() const;

-  bool NeedCollectiveForGrad(const std::string &grad_name,
-                             std::vector<ir::Node *> ops) const;
+  virtual bool NeedCollectiveForGrad(const std::string &grad_name,
+                                     std::vector<ir::Node *> ops) const;

   bool IsScaleLossOp(ir::Node *node) const;
@@ -70,10 +70,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
                               proto::VarType::Type dtype) const;

   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
-                            int dst_dev_id) const;
+                            size_t dst_dev_id) const;
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
-                             int dev_id) const;
+                             size_t dev_id) const;

   bool IsSparseGradient(const std::string &og) const;
@@ -115,6 +115,35 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
   virtual void InsertPostprocessOps(ir::Graph *result) const {}
 };

+class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
+ protected:
+  void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                          const std::string &g_name) const override {}
+
+  bool NeedCollectiveForGrad(const std::string &grad_name,
+                             std::vector<ir::Node *> ops) const override {
+    return false;
+  }
+
+  bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
+    if (node->Op()->Type() == "recv") {
+      VLOG(1) << "set recv op do_not_run to true";
+      node->Op()->SetAttr("do_not_run", true);
+      node->Op()->Flush();
+    } else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
+               node->Name() == "hierarchical_sigmoid") {
+      // in async_mode, we do not need remote prefetch, because communicator
+      // will do async parameter recv.
+      VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
+      node->Op()->SetAttr("remote_prefetch", false);
+      node->Op()->Flush();
+    }
+    return false;
+  }
+
+  void InsertPostprocessOps(ir::Graph *result) const override {}
+};
+
 class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
  protected:
   int GetVarDeviceID(const std::string &varname) const;
...
@@ -31,11 +31,23 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       prepare_pool_(1),
       pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
                                        : nullptr) {
+  if (strategy_.num_iteration_per_run_ > 1) {
+    int read_op_num = 0;
+    for (auto *node : graph_->Nodes()) {
+      if (node->IsOp() && node->Name() == "read") {
+        read_op_num++;
+      }
+    }
+    if (read_op_num == 0) {
+      LOG(WARNING) << "when num_iteration_per_run_ is larger than 1, the model "
+                      "should use pyreader to feed data!";
+    }
+  }
   PrepareOpDeps();
   CopyOpDeps();
 }
-FeedFetchList ThreadedSSAGraphExecutor::Run(
+inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
     const std::vector<std::string> &fetch_tensors) {
   std::unique_ptr<platform::RecordEvent> event(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
@@ -84,6 +96,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
     if (timeout) {
       if (exception_holder_.IsCaught()) {
+        VLOG(3) << "caught exception " << exception_holder_.Type()
+                << ", rethrow it";
         for (auto &run_op_future : run_op_futures_) {
           run_op_future.wait();
         }
@@ -114,6 +128,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   return fetch_data;
 }

+FeedFetchList ThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
+    RunImpl({});
+  }
+  return RunImpl(fetch_tensors);
+}
+
 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
     std::vector<FetchOpHandle *> *fetch_ops,
...
@@ -23,7 +23,9 @@
 #include <unordered_set>
 #include <utility>
 #include <vector>
-#include "ThreadPool.h"
+
+#include <ThreadPool.h>  // ThreadPool in third party
+
 #include "paddle/fluid/framework/blocking_queue.h"
 #include "paddle/fluid/framework/details/exception_holder.h"
 #include "paddle/fluid/framework/details/execution_strategy.h"
@@ -59,6 +61,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ~ThreadedSSAGraphExecutor() final = default;

  private:
+  inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
   void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
              details::OpHandleBase *op);
...
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/ir/pass.h"
+
+#include <memory>
+#include <utility>
+
 #include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
 namespace ir {
 Graph* Pass::Apply(Graph* graph) const {
   PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
   for (const std::string& attr : required_pass_attrs_) {
...
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph.h"

 #include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
+#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
@@ -218,6 +219,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     }
   }

+  std::vector<ir::Graph *> graphs;
+  if (build_strategy.async_mode_) {
+    PADDLE_ENFORCE(!member_->use_cuda_,
+                   "gpu mode does not support async_mode_ now!");
+    graphs.push_back(graph);
+    for (size_t i = 1; i < places.size(); ++i) {
+      auto *tmp_graph = new ir::Graph(graph->OriginProgram());
+      async_graphs_.emplace_back(tmp_graph);
+      graphs.push_back(tmp_graph);
+    }
+  }
+
   // FIXME(Yancey1989): parallel graph mode gets better performance
   // in GPU allreduce distributed training. Need an elegant way to
   // choose the execution strategy.
@@ -294,19 +307,46 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   if (need_broadcast()) {
     BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_);
   }

   // Startup Program has been run. All local scopes have correct parameters.

   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
+  std::vector<ir::Graph *> async_graphs(places.size());
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
-                               member_->local_scopes_, member_->nranks_,
-                               member_->use_cuda_, member_->nccl_ctxs_.get());
+  if (build_strategy.async_mode_) {
+    VLOG(3) << "use local async mode";
+    graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
+                                 {member_->local_scopes_[0]}, 1,
+                                 member_->use_cuda_, member_->nccl_ctxs_.get());
+    for (size_t i = 1; i < member_->places_.size(); ++i) {
+      graphs[i] =
+          build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name,
+                               {member_->local_scopes_[i]}, 1,
+                               member_->use_cuda_, member_->nccl_ctxs_.get());
+      async_graphs[i] = graphs[i];
+    }
+  } else {
+    graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
+                                 member_->local_scopes_, member_->nranks_,
+                                 member_->use_cuda_, member_->nccl_ctxs_.get());
+  }
 #else
-  graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
-                               member_->local_scopes_, member_->nranks_,
-                               member_->use_cuda_);
+  if (build_strategy.async_mode_) {
+    VLOG(3) << "use local async mode";
+    graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
+                                 {member_->local_scopes_[0]}, 1,
+                                 member_->use_cuda_);
+    for (size_t i = 1; i < member_->places_.size(); ++i) {
+      graphs[i] = build_strategy.Apply(
+          graphs[i], {member_->places_[i]}, loss_var_name,
+          {member_->local_scopes_[i]}, 1, member_->use_cuda_);
+      async_graphs[i] = graphs[i];
+    }
+  } else {
+    graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
+                                 member_->local_scopes_, member_->nranks_,
+                                 member_->use_cuda_);
+  }
 #endif
   auto max_memory_size = GetEagerDeletionThreshold();
@@ -317,6 +357,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
         static_cast<size_t>(max_memory_size));
   }

+  async_graphs[0] = graph;
+
   // Step 3. Create vars in each scope. Passes may also create new vars.
   // skip control vars and empty vars
   std::vector<details::VariableInfo> var_infos;
@@ -344,7 +386,12 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     }
   }

-  if (build_strategy.enable_parallel_graph_) {
+  if (build_strategy.async_mode_) {
+    VLOG(3) << "use AsyncSSAGraphExecutor";
+    member_->executor_.reset(new details::AsyncSSAGraphExecutor(
+        exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
+  } else if (build_strategy.enable_parallel_graph_) {
+    VLOG(3) << "use ParallelSSAGraphExecutor";
 #ifdef PADDLE_WITH_CUDA
     // TODO(Yancey1989): Remove passing in the main_program when
     // allreduce_seq_pass doesn't need it as the attr.
@@ -356,21 +403,27 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
 #endif
   } else {
     if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
+      VLOG(3) << "use ThreadedSSAGraphExecutor";
       member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_, graph));
     } else {
+      VLOG(3) << "use FastThreadedSSAGraphExecutor";
       member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_, graph));
     }
   }

-  member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
-      exec_strategy, member_->local_scopes_, std::move(var_infos),
-      member_->places_, std::move(member_->executor_)));
+  VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
+  if (!build_strategy.async_mode_) {
+    member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
+        exec_strategy, member_->local_scopes_, std::move(var_infos),
+        member_->places_, std::move(member_->executor_)));
+  }
 }
 void ParallelExecutor::BCastParamsToDevices(
     const std::vector<std::string> &vars, int trainer_id) const {
+  VLOG(3) << "BCastParamsToDevices";
   // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
     framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
@@ -425,14 +478,22 @@ void ParallelExecutor::BCastParamsToDevices(
       auto local_scope = member_->local_scopes_[i];
       auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();

-      // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
-      if (member_->use_all_reduce_ || member_->use_cuda_ ||
-          var == "@LR_DECAY_COUNTER@") {
+      auto copy_memory = [&] {
         t->Resize(dims);
         t->mutable_data(cpu, main_tensor.type());
         paddle::framework::TensorCopy(main_tensor, cpu, t);
+      };
+
+      auto share_memory = [&] { t->ShareDataWith(main_tensor); };
+
+      // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
+      if (member_->build_strategy_.async_mode_) {
+        share_memory();
+      } else if (member_->use_all_reduce_ || member_->use_cuda_ ||
+                 var == "@LR_DECAY_COUNTER@") {
+        copy_memory();
       } else {
-        t->ShareDataWith(main_tensor);
+        share_memory();
       }
     }
   }
...
@@ -81,6 +81,7 @@ class ParallelExecutor {
                      const BuildStrategy &build_strategy) const;

   ParallelExecutorPrivate *member_;
+  std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<ncclUniqueId> local_nccl_id_;
 #endif
...
@@ -69,6 +69,9 @@ void ReaderBase::Start() {

 ReaderBase::~ReaderBase() {}

-DecoratedReader::~DecoratedReader() { reader_->Shutdown(); }
+DecoratedReader::~DecoratedReader() {
+  VLOG(1) << "~DecoratedReader";
+  reader_->Shutdown();
+}

 }  // namespace framework
 }  // namespace paddle
@@ -16,6 +16,7 @@
 #include <memory>
 #include <unordered_set>
+#include <utility>
 #include <vector>

 #include "paddle/fluid/framework/ddim.h"
@@ -77,7 +78,10 @@ class DecoratedReader : public ReaderBase,
   ~DecoratedReader();

  protected:
-  void ShutdownImpl() override { reader_->Shutdown(); }
+  void ShutdownImpl() override {
+    VLOG(1) << "ShutdownImpl";
+    reader_->Shutdown();
+  }

   void StartImpl() override { reader_->Start(); }
@@ -98,6 +102,8 @@ class ReaderHolder {
     reader_ = reader_base;
   }

+  ~ReaderHolder() { VLOG(1) << "~ReaderHolder"; }
+
   const std::shared_ptr<ReaderBase>& Get() const { return reader_; }

   void ReadNext(std::vector<LoDTensor>* out) {
@@ -106,6 +112,7 @@
   }

   void ResetAll() {
+    VLOG(1) << "ResetAll";
     auto end_readers = reader_->GetEndPoints();
     for (auto* reader : end_readers) {
       reader->Shutdown();
@@ -116,11 +123,13 @@
   }

   void Shutdown() {
+    VLOG(1) << "Shutdown";
     PADDLE_ENFORCE_NOT_NULL(reader_);
     reader_->Shutdown();
   }

   void Start() {
+    VLOG(1) << "start";
     PADDLE_ENFORCE_NOT_NULL(reader_);
     reader_->Start();
   }
...
...@@ -59,6 +59,10 @@ Scope& Scope::NewScope() const { ...@@ -59,6 +59,10 @@ Scope& Scope::NewScope() const {
return *child; return *child;
} }
std::unique_ptr<Scope> Scope::NewTmpScope() const {
return std::unique_ptr<Scope>(new Scope(this));
}
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
return VarInternal(name); return VarInternal(name);
......
@@ -52,6 +52,10 @@ class Scope {
   /// Mark it to const because that new kid scope cannot change parent scope.
   Scope& NewScope() const;

+  /// Create a sub-scope for current scope but do not record it in the kids to
+  /// avoid performance problems.
+  std::unique_ptr<Scope> NewTmpScope() const;
+
   /// Create a variable with given name if it doesn't exist.
   /// Caller doesn't own the returned Variable.
   Variable* Var(const std::string& name);
...
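A hypothetical caller of the new NewTmpScope() (assumed usage, not part of this patch): unlike NewScope(), the returned scope is not recorded in the parent's kids list, so it is freed as soon as the unique_ptr goes out of scope:

#include <memory>

#include "paddle/fluid/framework/scope.h"

void UseTmpScope(const paddle::framework::Scope &global_scope) {
  std::unique_ptr<paddle::framework::Scope> tmp = global_scope.NewTmpScope();
  tmp->Var("send_tmp");  // scratch variable, destroyed together with `tmp`
}  // `tmp` dies here; the parent never sees this scope in its kids list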
@@ -28,7 +28,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
+void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
   if (var_type == proto::VarType::LOD_TENSOR) {
     var->GetMutable<LoDTensor>();
   } else if (var_type == proto::VarType::SELECTED_ROWS) {
@@ -38,7 +38,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
   } else if (var_type == proto::VarType::FETCH_LIST) {
     var->GetMutable<FeedFetchList>();
   } else if (var_type == proto::VarType::STEP_SCOPES) {
-    var->GetMutable<std::vector<framework::Scope*>>();
+    var->GetMutable<std::vector<framework::Scope *>>();
   } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
     var->GetMutable<LoDRankTable>();
   } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
@@ -57,5 +57,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
                  var_type);
   }
 }
+
+void CopyVariable(const Variable &src_var, Variable *dst_var) {
+  // only support cpu now
+  auto cpu_place = platform::CPUPlace();
+
+  if (src_var.IsType<framework::LoDTensor>()) {
+    auto *tmp_grad_tensor = dst_var->GetMutable<framework::LoDTensor>();
+    auto &src_tensor = src_var.Get<framework::LoDTensor>();
+    tmp_grad_tensor->set_lod(src_tensor.lod());
+    framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor);
+  } else if (src_var.IsType<framework::SelectedRows>()) {
+    auto &src_slr = src_var.Get<framework::SelectedRows>();
+    auto *tmp_grad_slr = dst_var->GetMutable<framework::SelectedRows>();
+    tmp_grad_slr->set_rows(src_slr.rows());
+    tmp_grad_slr->set_height(src_slr.height());
+    auto &src_t = src_slr.value();
+    auto *dst_t = tmp_grad_slr->mutable_value();
+    framework::TensorCopy(src_t, cpu_place, dst_t);
+  } else {
+    PADDLE_THROW("unknown var type to copy");
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -17,7 +17,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/variable.h"

 namespace paddle {
 namespace framework {
-void InitializeVariable(Variable* var, proto::VarType::Type var_type);
+
+void InitializeVariable(Variable *var, proto::VarType::Type var_type);
+void CopyVariable(const Variable& src_var, Variable* dst_var);
+
 }  // end namespace framework
 }  // end namespace paddle
@@ -30,7 +30,7 @@ if(WITH_GRPC)
 else()
     set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
-    set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

     set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
@@ -50,8 +50,12 @@ endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
         DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
-cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
+cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
+cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
+cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
+cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv)
+cc_test(communicator_test SRCS communicator_test.cc DEPS communicator)
 if(WITH_GPU)
     cc_test(collective_server_test SRCS collective_server_test.cc
             DEPS sendrecvop_rpc executor ${RPC_DEPS}
...
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h>
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
DEFINE_bool(communicator_independent_recv_thread, true,
"use an independent to recv vars from parameter server");
DEFINE_int32(communicator_send_queue_size, 20,
"queue size to recv gradient before send");
DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
"max grad num to send before recv parameters");
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
DEFINE_int32(communicator_max_merge_var_num, 20,
"max var num to merge and send");
DEFINE_bool(communicator_fake_rpc, false,
"fake mode does not really send any thing");
namespace paddle {
namespace operators {
namespace distributed {
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_;
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope)
: send_varname_to_ctx_(send_varname_to_ctx),
recv_varname_to_ctx_(recv_varname_to_ctx),
recv_scope_(recv_scope) {
// get all send information from graph, build vars_to_send
VLOG(0) << "communicator_independent_recv_thread: "
<< FLAGS_communicator_independent_recv_thread;
VLOG(0) << "communicator_send_queue_size: "
<< FLAGS_communicator_send_queue_size;
VLOG(0) << "communicator_max_send_grad_num_before_recv: "
<< FLAGS_communicator_max_send_grad_num_before_recv;
VLOG(0) << "communicator_thread_pool_size: "
<< FLAGS_communicator_thread_pool_size;
VLOG(0) << "communicator_max_merge_var_num: "
<< FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
FLAGS_communicator_send_queue_size);
}
send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}
Communicator::~Communicator() {
VLOG(3) << "~Communicator";
running_ = false;
if (send_thread_) send_thread_->join();
if (recv_thread_) recv_thread_->join();
VLOG(3) << "~Communicator done";
}
void Communicator::SendThread() {
VLOG(3) << "SendThread start!";
while (running_) {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
if (var_queue->Size() > 0) {
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << var_name << " merge and send";
std::vector<std::shared_ptr<Variable>> vars;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 &&
merged_var_num < FLAGS_communicator_max_merge_var_num) {
vars.push_back(var_queue->Pop());
// only count the send number of the first var
if (var_name == send_varname_to_queue_.begin()->first) {
grad_num_.fetch_add(1, std::memory_order_relaxed);
}
merged_var_num++;
}
auto before_merge = GetCurrentUS();
MergeVars(var_name, vars, send_scope_.get());
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << var_name << " use time "
<< after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, *send_scope_, true);
}
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
} else {
VLOG(3) << var_name << " queue empty";
}
}
for (auto &task_f : task_futures) {
task_f.wait();
}
auto after_run_send_graph = GetCurrentUS();
auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
if (send_graph_use_time > 100) {
VLOG(1) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
}
if (!FLAGS_communicator_independent_recv_thread) {
RecvAll();
}
}
}
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
auto after_recv = GetCurrentUS();
VLOG(1) << "run recv graph use time " << after_recv - before_send;
}
void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) {
auto grad_num = grad_num_.load();
if (grad_num > FLAGS_communicator_max_send_grad_num_before_recv) {
VLOG(1) << "current grad num " << grad_num;
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
}
void Communicator::Send(const std::string &var_name,
const framework::Scope &scope) {
VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name
auto *grad_var = scope.FindVar(var_name);
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
auto tmp_grad_var = std::make_shared<Variable>();
framework::CopyVariable(*grad_var, tmp_grad_var.get());
auto &queue = send_varname_to_queue_.at(var_name);
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
queue->Push(tmp_grad_var);
}
Communicator *Communicator::GetInstance() { return communicator_.get(); }
void Communicator::Start() {
running_ = true;
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
if (FLAGS_communicator_independent_recv_thread) {
recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0.");
}
bool Push(const T& elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.push_back(elem);
}
cv_.notify_one();
return true;
}
bool Push(T&& elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.emplace_back(std::move(elem));
}
cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
inline void MergeVars(const std::string& var_name,
const std::vector<std::shared_ptr<Variable>>& vars,
Scope* scope) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto& var0 = vars[0];
auto* out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor " << dims;
// init output tensor
auto* out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<float>(dims, cpu_place);
// check the input dims
for (auto& var : vars) {
auto& var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(var_t.dims(), dims, "should have the same dims");
}
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
math::SetConstant<paddle::platform::CPUDeviceContext, float>
constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<float>(0));
// sum all vars to out
auto result = EigenVector<float>::Flatten(*out_t);
for (auto& var : vars) {
auto& in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<float>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto& slr0 = var0->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.reserve(vars.size());
for (auto& var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
merge_add;
auto dev_ctx = paddle::platform::CPUDeviceContext();
merge_add(dev_ctx, inputs, out_slr, false);
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims();
} else {
PADDLE_THROW("unsupported var type!");
}
}
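// Worked example (illustrative): merging three LoDTensor vars that each hold
// 1.0f elementwise writes 3.0f into every element of the output, since the
// output is zeroed first and then all inputs are summed into it. For
// SelectedRows inputs, MergeAdd instead adds up rows that share a row index.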
using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator {
public:
Communicator(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope);
~Communicator();
void Start();
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
private:
  // recv all parameters
void RecvAll();
void SendThread();
void RecvThread();
bool running_ = false;
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::unique_ptr<std::thread> send_thread_;
std::unique_ptr<std::thread> recv_thread_;
Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
  // the following code is for initializing the communicator
public:
static void Init(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
}
static Communicator* GetInstance();
private:
// Init is called by GetInstance.
static void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_ == nullptr) {
communicator_.reset(new Communicator(send_varname_to_ctx,
recv_varname_to_ctx, recv_scope));
}
}
private:
static std::once_flag init_flag_;
static std::unique_ptr<Communicator> communicator_;
};
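// A minimal lifecycle sketch (illustrative only; the RpcContext maps and the
// scope are assumed to be prepared by the caller):
//   Communicator::Init(send_ctxs, recv_ctxs, &global_scope);
//   Communicator::GetInstance()->Start();  // spawn the send/recv threads
//   Communicator::GetInstance()->Send("w@GRAD", scope);  // enqueue one grad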
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/operators/distributed/communicator.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
TEST(communicator, merge_lod_tensors) {
auto cpu_place = platform::CPUPlace();
auto dims = framework::make_ddim({2, 3});
std::vector<std::shared_ptr<framework::Variable>> in_vars;
float out_value = 0;
for (auto i = 0; i < 10; ++i) {
auto var = std::make_shared<Variable>();
in_vars.emplace_back(var);
auto *tensor = var->GetMutable<LoDTensor>();
auto *data = tensor->mutable_data<float>(dims, cpu_place);
for (auto j = 0; j < tensor->numel(); ++j) {
data[j] = static_cast<float>(i);
}
out_value += static_cast<float>(i);
}
const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
scope->Var(out_name);
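  // MergeVars rebuilds the output from scratch on every call (zero, then
  // sum), so calling it repeatedly must keep yielding the same result.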
for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get());
}
auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
auto *out_data = out_tensor.data<float>();
ASSERT_EQ(out_tensor.dims(), dims);
for (auto i = 0; i < out_tensor.numel(); ++i) {
ASSERT_EQ(out_data[i], out_value);
}
}
TEST(communicator, merge_selected_rows) {
auto cpu_place = platform::CPUPlace();
int64_t width = 10;
std::vector<std::shared_ptr<framework::Variable>> in_vars;
const int64_t height = 100;
for (auto i = 0; i < 10; ++i) {
std::vector<int64_t> rows;
for (auto k = 0; k <= i; ++k) {
rows.push_back(k);
}
auto var = std::make_shared<Variable>();
in_vars.emplace_back(var);
auto *slr = var->GetMutable<SelectedRows>();
slr->set_height(height);
slr->set_rows(rows);
auto dims =
framework::make_ddim({static_cast<int64_t>(rows.size()), width});
auto *data = slr->mutable_value()->mutable_data<float>(dims, cpu_place);
    // use a distinct index name to avoid shadowing the outer loop variable
    for (size_t r = 0; r < rows.size(); ++r) {
      for (auto j = 0; j < width; ++j) {
        data[r * width + j] = static_cast<float>(rows[r]);
      }
    }
}
const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get());
}
auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
auto &out_t = out_slr.value();
auto *out_data = out_t.data<float>();
ASSERT_EQ(out_t.dims(), framework::make_ddim({10, width}));
std::vector<float> out_values;
out_values.reserve(10);
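  // row r appears in vars r..9, i.e. (10 - r) times, each time holding the
  // value r, so after MergeAdd the merged row r should hold r * (10 - r)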
for (auto i = 0; i < 10; ++i) {
out_values.push_back(static_cast<float>(i * (10 - i)));
}
  for (size_t i = 0; i < out_slr.rows().size(); ++i) {
    ASSERT_EQ(out_slr.rows()[i], static_cast<int64_t>(i));
for (auto j = 0; j < width; ++j) {
ASSERT_EQ(out_data[i * width + j], out_values[i]);
}
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <limits> #include <limits>
#include <memory>
#include <string> #include <string>
#include "paddle/fluid/operators/distributed/grpc/grpc_serde.h" #include "paddle/fluid/operators/distributed/grpc/grpc_serde.h"
...@@ -106,7 +107,6 @@ class RequestSend final : public RequestBase { ...@@ -106,7 +107,6 @@ class RequestSend final : public RequestBase {
auto invar = request_->GetVar(); auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId(); int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -38,30 +39,9 @@ using LoDTensor = framework::LoDTensor; ...@@ -38,30 +39,9 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
static size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (id < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
static std::vector<std::vector<int64_t>> SplitIds( static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector, const std::vector<int64_t>& ids_vector,
const std::vector<int>& height_section, framework::Scope* scope) { const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids; std::set<int64_t> all_ids;
for (auto id : ids_vector) { for (auto id : ids_vector) {
all_ids.insert(id); all_ids.insert(id);
...@@ -79,7 +59,7 @@ static std::vector<std::vector<int64_t>> SplitIds( ...@@ -79,7 +59,7 @@ static std::vector<std::vector<int64_t>> SplitIds(
static void SplitIdsIntoMultipleVarsBySection( static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names, const std::vector<std::string>& in_var_names,
const std::vector<int>& height_section, const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) { framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
...@@ -101,7 +81,7 @@ static void SplitIdsIntoMultipleVarsBySection( ...@@ -101,7 +81,7 @@ static void SplitIdsIntoMultipleVarsBySection(
static void MergeMultipleVarsIntoOneBySection( static void MergeMultipleVarsIntoOneBySection(
const std::string& id_name, const std::vector<int64_t>& ids_vector, const std::string& id_name, const std::vector<int64_t>& ids_vector,
const std::string& out_name, const std::vector<std::string>& out_var_names, const std::string& out_name, const std::vector<std::string>& out_var_names,
const std::vector<int>& height_section, const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, framework::Scope* scope, const framework::ExecutionContext& context, framework::Scope* scope,
platform::DeviceContext* actual_ctx) { platform::DeviceContext* actual_ctx) {
...@@ -178,10 +158,10 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -178,10 +158,10 @@ static void MergeMultipleVarsIntoOneBySection(
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope) { const framework::Scope& scope) {
auto& local_scope = scope.NewScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace()); auto& cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -225,23 +205,23 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -225,23 +205,23 @@ void prefetch(const std::string& id_name, const std::string& out_name,
#endif #endif
} }
auto splited_ids = SplitIds(ids_vector, height_sections, &local_scope); auto splited_ids = SplitIds(ids_vector, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
&local_scope); local_scope.get());
// create output var in local scope // create output var in local scope
for (auto& name : out_var_names) { for (auto& name : out_var_names) {
local_scope.Var(name)->GetMutable<framework::LoDTensor>(); local_scope->Var(name)->GetMutable<framework::LoDTensor>();
} }
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) { for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(local_scope, in_var_names[i])) { if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i] VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back"; << " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar( rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, local_scope, in_var_names[i], out_var_names[i], epmap[i], cpu_ctx, *local_scope.get(), in_var_names[i],
table_names[i])); out_var_names[i], table_names[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
} }
...@@ -253,8 +233,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -253,8 +233,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids, out_var_names, height_sections, splited_ids,
context, &local_scope, &actual_ctx); context, local_scope.get(), &actual_ctx);
scope.DeleteScope(&local_scope);
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -26,7 +26,7 @@ namespace distributed { ...@@ -26,7 +26,7 @@ namespace distributed {
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope); const framework::Scope& scope);
...@@ -35,7 +35,7 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -35,7 +35,7 @@ void prefetch_with_reconstruct(const std::string& id_name,
const std::string& out_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope, const framework::Scope& scope,
framework::LoDTensor* original) { framework::LoDTensor* original) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in";
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
auto *recv_var = scope.FindVar(rpc_ctx.var_name);
std::vector<framework::Tensor *> recved_tensors;
// recv all vars to local scope
if (recv_var->IsType<framework::LoDTensor>()) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
framework::Tensor *t =
local_scope->Var(recv_var_name)->GetMutable<framework::LoDTensor>();
recved_tensors.push_back(t);
VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(), recv_var_name,
recv_var_name));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} else {
PADDLE_THROW("unsupported var type to recv!");
}
// concat recved tensor into one var
{
size_t output_offset = 0;
framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>();
auto dev_ctx = paddle::platform::CPUDeviceContext();
int64_t recv_numel = 0;
for (auto *in : recved_tensors) {
recv_numel += in->numel();
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(recv_tensor->dims());
StridedNumelCopyWithAxis<T>(
dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride, in_stride[0]);
output_offset += in_stride[0];
}
PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
}
VLOG(3) << "ParameterRecv out";
}
template struct ParameterRecv<float>;
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
namespace paddle {
namespace operators {
namespace distributed {
template <typename T>
struct ParameterRecv {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
};
} // namespace distributed
} // namespace operators
} // namespace paddle
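// A minimal usage sketch (illustrative only; mirrors how recv_op drives it,
// with hypothetical variable names and endpoints):
//   auto recv = distributed::ParameterRecv<float>();
//   recv(distributed::RpcContext("w", {"w.block0", "w.block1"},
//                                {"127.0.0.1:6174", "127.0.0.1:6175"}, {}),
//        scope);  // fetch the splits and concat them back into var "w"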
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope, bool sync) {
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
auto *send_var = scope.FindVar(rpc_ctx.var_name);
size_t out_num = rpc_ctx.splited_var_names.size();
if (send_var->IsType<framework::LoDTensor>()) {
if (out_num > 1) {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
auto &send_tensor_dims = send_tensor.dims();
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(out_num);
// infer output shape
PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num,
"tensor split sections size"
"should be equal to output size.");
for (size_t i = 0; i < out_num; ++i) {
auto dim = send_tensor_dims;
dim[0] = rpc_ctx.height_sections[i];
outs_dims.push_back(dim);
}
// create output var in local scope
size_t row_offset = 0;
    for (size_t i = 0; i < out_num; ++i) {
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
->GetMutable<framework::LoDTensor>();
*out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
row_offset += outs_dims[i][0];
}
}
} else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows();
std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx;
outs_rows_idx.resize(out_num);
outs_dense_idx.resize(out_num);
auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
auto *src = send_slr.value().data<T>();
// create output var in local scope
std::vector<framework::SelectedRows *> outs;
for (auto &name : rpc_ctx.splited_var_names) {
auto *out = local_scope->Var(name)->GetMutable<framework::SelectedRows>();
outs.push_back(out);
}
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
size_t out_idx = GetSectionIndex(send_rows[i], abs_sections);
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(rpc_ctx.height_sections[i]);
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[i]->mutable_rows()->clear();
outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
}
auto dst = outs[i]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(
platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
} else {
PADDLE_THROW("do not support GPU now");
/*
#ifdef PADDLE_WITH_CUDA
auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(),
src + outs_dense_idx[i][j] * row_numel,
sizeof(T) * row_numel, stream);
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
*/
}
}
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
"rows should has the same size with tensor dim 0");
}
} else {
PADDLE_THROW("unsupported var type to send!");
}
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
auto &endpoint = rpc_ctx.epmap[i];
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
if (sync) {
for (auto &handle : rets) {
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
}
}
}
template struct ParameterSend<float>;
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
namespace paddle {
namespace operators {
namespace distributed {
template <typename T>
struct ParameterSend {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
bool sync);
};
} // namespace distributed
} // namespace operators
} // namespace paddle
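// A minimal usage sketch (illustrative only; mirrors how send_op drives it,
// with hypothetical variable names, endpoints, and height sections):
//   auto send = distributed::ParameterSend<float>();
//   distributed::RpcContext ctx("w@GRAD", {"w@GRAD.block0", "w@GRAD.block1"},
//                               {"127.0.0.1:6174", "127.0.0.1:6175"},
//                               {64, 64});
//   send(ctx, scope, true);  // split by height, send, and wait for all RPCs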
...@@ -59,13 +59,8 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -59,13 +59,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or " "async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"); "COMPLETE_MESSAGE");
} }
try { executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), scope);
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true; return true;
} else { // sync } else { // sync
rpc_server_->WaitCond(kRequestSend); rpc_server_->WaitCond(kRequestSend);
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
namespace distributed {
struct RpcContext {
RpcContext() = default;
RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections)
: var_name(name),
splited_var_names(names),
epmap(emap),
height_sections(sections) {}
RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
}
std::string var_name;
std::vector<std::string> splited_var_names;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
};
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os << "{";
os << "var_name: " << rpc_ctx.var_name << "\n";
os << "splited_var_names: [";
for (auto &name : rpc_ctx.splited_var_names) {
os << name << ", ";
}
os << "]\n";
os << "epmap: [";
for (auto &ep : rpc_ctx.epmap) {
os << ep << ", ";
}
os << "]\n";
os << "height_sections: [";
for (auto &section : rpc_ctx.height_sections) {
os << section << ", ";
}
os << "]\n";
os << "}";
return os;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
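// Example (illustrative): a 128-row parameter "w" split across two pservers,
// where splited_var_names[i] lives on epmap[i] and holds height_sections[i]
// rows:
//   RpcContext ctx("w", {"w.block0", "w.block1"},
//                  {"127.0.0.1:6174", "127.0.0.1:6175"}, {64, 64});
//   VLOG(3) << ctx;  // prints all four fields via the operator<< above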
...@@ -60,13 +60,14 @@ class VariableResponse { ...@@ -60,13 +60,14 @@ class VariableResponse {
bool create_scope = false) bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) { if (create_scope) {
local_scope_ = &scope->NewScope(); local_scope_ = scope->NewTmpScope().release();
} }
} }
virtual ~VariableResponse() { virtual ~VariableResponse() {
if (create_scope_) { if (local_scope_) {
scope_->DeleteScope(local_scope_); delete local_scope_;
local_scope_ = nullptr;
} }
} }
......
...@@ -2,9 +2,9 @@ include(operators) ...@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else() else()
set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA) if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs) find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
...@@ -20,6 +20,8 @@ limitations under the License. */ ...@@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -34,6 +36,11 @@ class RecvOp : public framework::OperatorBase { ...@@ -34,6 +36,11 @@ class RecvOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
bool do_not_run = Attr<bool>("do_not_run");
if (do_not_run) {
VLOG(3) << "recv do not run!";
return;
}
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames = std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames"); Attr<std::vector<std::string>>("varnames");
...@@ -48,32 +55,41 @@ class RecvOp : public framework::OperatorBase { ...@@ -48,32 +55,41 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id")); Attr<int>("trainer_id"));
if (with_barrier) { std::vector<std::string> recv_varnames =
std::vector<distributed::VarHandlePtr> rets; Attr<std::vector<std::string>>("recv_varnames");
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; if (recv_varnames.size() > 0) {
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " auto recv_functor = distributed::ParameterRecv<float>();
<< varname << " and with AsyncGetVar"; auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
rets.push_back( recv_functor(rpc_ctx, scope);
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i])); } else {
} if (with_barrier) {
if (sync_mode) { std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
} }
} }
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} }
} }
}; };
...@@ -110,6 +126,12 @@ This operator can get variables from server side. ...@@ -110,6 +126,12 @@ This operator can get variables from server side.
"for example: we need var named 'moment_1@127.0.0.1:1001', " "for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ") "and it real name on parameter server is 'moment_1'. ")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>(
"recv_varnames",
"(vector<string>) "
"the splited parameter varnames to be recved from pserver")
.SetDefault(std::vector<std::string>{});
AddAttr<bool>("do_not_run", "if recv need to really run").SetDefault(false);
} }
}; };
......
...@@ -19,7 +19,10 @@ limitations under the License. */ ...@@ -19,7 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -37,30 +40,47 @@ class SendOp : public framework::OperatorBase { ...@@ -37,30 +40,47 @@ class SendOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode"); int sync_send = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto& ctx = *pool.Get(place); auto height_sections = Attr<std::vector<int64_t>>("sections");
distributed::RPCClient* rpc_client = if (send_varnames.size() > 0) {
distributed::RPCClient::GetInstance<RPCCLIENT_T>( PADDLE_ENFORCE_EQ(ins.size(), 1, "");
Attr<int>("trainer_id")); if (distributed::Communicator::GetInstance() == nullptr) {
auto send_functor = distributed::ParameterSend<float>();
std::vector<distributed::VarHandlePtr> rets; auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
for (size_t i = 0; i < ins.size(); i++) { height_sections);
if (NeedSend(scope, ins[i])) { send_functor(rpc_ctx, scope, true);
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; distributed::Communicator::GetInstance()->Send(ins[0], scope);
} }
} } else {
if (sync_send) { platform::DeviceContextPool& pool =
for (size_t i = 0; i < rets.size(); i++) { platform::DeviceContextPool::Instance();
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; auto& ctx = *pool.Get(place);
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i]; distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
} }
} }
} }
...@@ -88,6 +108,21 @@ This operator will send variables to listen_and_serve op at the parameter server ...@@ -88,6 +108,21 @@ This operator will send variables to listen_and_serve op at the parameter server
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
AddAttr<std::vector<int64_t>>("sections",
"(vector<int>) "
"the length of each output along the "
"specified axis.")
.SetDefault(std::vector<int64_t>{});
AddAttr<std::vector<std::string>>(
"send_varnames",
"(vector<string>) "
"the splited output varnames to send to pserver")
.SetDefault(std::vector<std::string>{});
AddAttr<int>("num",
"(int, default 0)"
"Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]")
.SetDefault(0);
} }
}; };
......
...@@ -13,8 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,5 +48,26 @@ inline bool NeedSend(const framework::Scope& scope, ...@@ -42,5 +48,26 @@ inline bool NeedSend(const framework::Scope& scope,
return false; return false;
} }
inline std::vector<int64_t> ToAbsoluteSection(
const std::vector<int64_t>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
inline size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (id < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
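// Worked example (illustrative): height_sections = {20, 30, 50} gives
// ToAbsoluteSection -> {0, 20, 50}, and GetSectionIndex(35, {0, 20, 50}) == 1,
// i.e. row id 35 falls into the second shard at offset 35 - 20 = 15.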
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -134,9 +134,9 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -134,9 +134,9 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch // for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false); AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"epmap", "epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/clip_op.h"
...@@ -65,12 +68,13 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -65,12 +68,13 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes")); size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch // for remote prefetch
auto remote_prefetch = ctx.Attr<bool>("remote_prefetch");
auto epmap = ctx.Attr<std::vector<std::string>>("epmap"); auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) { if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote // if epmap is not empty, then the parameter will be fetched from remote
// parameter // parameter
// server // server
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
auto table_names = ctx.Attr<std::vector<std::string>>("table_names"); auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
std::vector<int64_t> real_rows = PathToRows(*path); std::vector<int64_t> real_rows = PathToRows(*path);
framework::Scope& local_scope = ctx.scope().NewScope(); framework::Scope& local_scope = ctx.scope().NewScope();
......
...@@ -91,9 +91,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -91,9 +91,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch // for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false); AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"epmap", "epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -84,7 +84,8 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -84,7 +84,8 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
// for remote prefetch // for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap"); auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) { if (!epmap.empty()) {
......
...@@ -50,10 +50,12 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -50,10 +50,12 @@ class LookupTableKernel : public framework::OpKernel<T> {
// for remote prefetch // for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap"); auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) { if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote // if epmap is not empty, then the parameter will be fetched from remote
// parameter // parameter
// server // server
......
...@@ -95,7 +95,7 @@ struct MergeAdd { ...@@ -95,7 +95,7 @@ struct MergeAdd {
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor // out = selected_rows_in / tensor
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct UpdateToTensor { struct UpdateToTensor {
void operator()(const DeviceContext& context, const ScatterOps& op, void operator()(const DeviceContext& context, const ScatterOps& op,
......
...@@ -156,9 +156,9 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -156,9 +156,9 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch // for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false); AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"epmap", "epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -156,9 +156,10 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -156,9 +156,10 @@ class NCEKernel : public framework::OpKernel<T> {
auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input"))); auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
// for remote prefetch // for remote prefetch
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto epmap = context.Attr<std::vector<std::string>>("epmap"); auto epmap = context.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) { if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote // if epmap is not empty, then the parameter will be fetched from remote
// parameter // parameter
// server // server
...@@ -172,7 +173,8 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -172,7 +173,8 @@ class NCEKernel : public framework::OpKernel<T> {
framework::Scope &local_scope = context.scope().NewScope(); framework::Scope &local_scope = context.scope().NewScope();
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
auto *ids = local_scope.Var("Ids@Prefetch"); auto *ids = local_scope.Var("Ids@Prefetch");
......
...@@ -80,12 +80,14 @@ class BlockingQueue { ...@@ -80,12 +80,14 @@ class BlockingQueue {
return true; return true;
} else { } else {
PADDLE_ENFORCE(closed_); PADDLE_ENFORCE(closed_);
VLOG(3) << "queue is closed! return nothing.";
return false; return false;
} }
} }
void ReOpen() { void ReOpen() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
VLOG(1) << "reopen queue";
closed_ = false; closed_ = false;
std::deque<T> new_deque; std::deque<T> new_deque;
queue_.swap(new_deque); queue_.swap(new_deque);
...@@ -95,6 +97,7 @@ class BlockingQueue { ...@@ -95,6 +97,7 @@ class BlockingQueue {
void Close() { void Close() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
VLOG(1) << "close queue";
closed_ = true; closed_ = true;
send_cv_.notify_all(); send_cv_.notify_all();
receive_cv_.notify_all(); receive_cv_.notify_all();
......
...@@ -22,6 +22,7 @@ namespace paddle { ...@@ -22,6 +22,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
BufferedReader::~BufferedReader() { BufferedReader::~BufferedReader() {
VLOG(1) << "~BufferedReader";
reader_->Shutdown(); reader_->Shutdown();
while (!position_.empty()) { while (!position_.empty()) {
position_.front().wait(); position_.front().wait();
...@@ -45,6 +46,7 @@ BufferedReader::BufferedReader( ...@@ -45,6 +46,7 @@ BufferedReader::BufferedReader(
thread_pool_(1), thread_pool_(1),
place_(place), place_(place),
buffer_size_(buffer_size) { buffer_size_(buffer_size) {
VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device); platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
...@@ -131,6 +133,7 @@ void BufferedReader::ReadAsync(size_t i) { ...@@ -131,6 +133,7 @@ void BufferedReader::ReadAsync(size_t i) {
} }
void BufferedReader::ShutdownImpl() { void BufferedReader::ShutdownImpl() {
VLOG(1) << "ShutdownImpl";
reader_->Shutdown(); reader_->Shutdown();
while (!position_.empty()) { while (!position_.empty()) {
position_.pop(); position_.pop();
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
...@@ -57,7 +58,10 @@ class LoDTensorBlockingQueue { ...@@ -57,7 +58,10 @@ class LoDTensorBlockingQueue {
inline void ReOpen() { queue_.ReOpen(); } inline void ReOpen() { queue_.ReOpen(); }
inline void Close() { queue_.Close(); } inline void Close() {
VLOG(1) << "LoDTensorBlockingQueue close";
queue_.Close();
}
inline bool IsClosed() const { return queue_.IsClosed(); } inline bool IsClosed() const { return queue_.IsClosed(); }
......
...@@ -16,31 +16,12 @@ limitations under the License. */ ...@@ -16,31 +16,12 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int64_t>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -51,7 +32,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -51,7 +32,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
auto abs_sections = ToAbsoluteSection(height_sections); auto abs_sections = ToAbsoluteSection(height_sections);
auto x_rows = x->rows(); auto& x_rows = x->rows();
auto height = x->height();
std::vector<std::vector<int>> outs_rows_idx; std::vector<std::vector<int>> outs_rows_idx;
std::vector<std::vector<int>> outs_dense_idx; std::vector<std::vector<int>> outs_dense_idx;
...@@ -63,8 +45,10 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -63,8 +45,10 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
// split rows index into output sparse vars // split rows index into output sparse vars
for (size_t i = 0; i < x_rows.size(); ++i) { for (size_t i = 0; i < x_rows.size(); ++i) {
int out_idx = FindOutIdx(x_rows[i], abs_sections); auto& id = x_rows[i];
outs_rows_idx[out_idx].push_back(x_rows[i]); PADDLE_ENFORCE_LT(id, height);
int out_idx = GetSectionIndex(id, abs_sections);
outs_rows_idx[out_idx].push_back(id);
outs_dense_idx[out_idx].push_back(i); outs_dense_idx[out_idx].push_back(i);
} }
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -78,7 +62,9 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -78,7 +62,9 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
outs[i]->mutable_rows()->clear(); outs[i]->mutable_rows()->clear();
if (rows_idx.size() > 0) { if (rows_idx.size() > 0) {
for (auto idx : rows_idx) { for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); auto id_offset = idx - abs_sections[i];
PADDLE_ENFORCE_LT(id_offset, height_sections[i]);
outs[i]->mutable_rows()->push_back(id_offset);
} }
auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace()); auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
for (size_t j = 0; j < rows_idx.size(); j++) { for (size_t j = 0; j < rows_idx.size(); j++) {
......
...@@ -625,6 +625,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -625,6 +625,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_lod_tensor_blocking_queue", m.def("init_lod_tensor_blocking_queue",
[](Variable &var, [](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> { size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
VLOG(1) << "init_lod_tensor_blocking_queue";
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>(); auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode); holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue(); return holder->GetQueue();
...@@ -1144,6 +1145,17 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1144,6 +1145,17 @@ All parameter, weight, gradient are variables in Paddle.
2. In some NLP model, it may cause the GPU memory is insufficient, 2. In some NLP model, it may cause the GPU memory is insufficient,
in this case, you should reduce `num_iteration_per_drop_scope`. in this case, you should reduce `num_iteration_per_drop_scope`.
)DOC") )DOC")
.def_property(
"num_iteration_per_run",
[](const ExecutionStrategy &self) {
return self.num_iteration_per_run_;
},
[](ExecutionStrategy &self, size_t num_iteration_per_run) {
self.num_iteration_per_run_ = num_iteration_per_run;
},
R"DOC(This config that how many iteration the executor will run when
user call pe.run() in python
)DOC")
.def_property("_dry_run", .def_property("_dry_run",
[](const ExecutionStrategy &self) { return self.dry_run_; }, [](const ExecutionStrategy &self) { return self.dry_run_; },
[](ExecutionStrategy &self, bool dry_run) { [](ExecutionStrategy &self, bool dry_run) {
...@@ -1320,6 +1332,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1320,6 +1332,9 @@ All parameter, weight, gradient are variables in Paddle.
"is_distribution", "is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; }, [](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) { self.is_distribution_ = b; }) [](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
.def_property("async_mode",
[](const BuildStrategy &self) { return self.async_mode_; },
[](BuildStrategy &self, bool b) { self.async_mode_ = b; })
.def_property( .def_property(
"enable_inplace", "enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; }, [](const BuildStrategy &self) { return self.enable_inplace_; },
......
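From the Python side the two new knobs combine roughly like this (a usage sketch only; the value 10 is illustrative and matches the test added later in this PR):

exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_run = 10   # each pe.run() advances 10 iterations

build_strategy = fluid.BuildStrategy()
build_strategy.async_mode = True           # select the async SSA graph executor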
...@@ -157,6 +157,7 @@ def __bootstrap__():
        read_env_flags.append('use_ngraph')
    if core.is_compiled_with_dist():
+       # env for rpc
        read_env_flags.append('rpc_deadline')
        read_env_flags.append('rpc_server_profile_path')
        read_env_flags.append('enable_rpc_profiler')
...@@ -164,6 +165,14 @@ def __bootstrap__():
        read_env_flags.append('rpc_get_thread_num')
        read_env_flags.append('rpc_prefetch_thread_num')
        read_env_flags.append('rpc_disable_reuse_port')
+       # env for communicator
+       read_env_flags.append('communicator_independent_recv_thread')
+       read_env_flags.append('communicator_send_queue_size')
+       read_env_flags.append('communicator_max_send_grad_num_before_recv')
+       read_env_flags.append('communicator_thread_pool_size')
+       read_env_flags.append('communicator_max_merge_var_num')
+       read_env_flags.append('communicator_fake_rpc')
        if core.is_compiled_with_brpc():
            read_env_flags.append('max_body_size')
            #set brpc max body size
...
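Since __bootstrap__ only whitelists these flag names, the actual values come from FLAGS_-prefixed environment variables read when paddle.fluid is imported. A hedged sketch (the numeric values here are made up for illustration, not defaults):

import os
os.environ['FLAGS_communicator_send_queue_size'] = '20'
os.environ['FLAGS_communicator_thread_pool_size'] = '10'
os.environ['FLAGS_communicator_fake_rpc'] = '0'

import paddle.fluid as fluid  # the flags are parsed during this import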
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
import unittest
import numpy
import time

import paddle
import paddle.fluid as fluid

BATCH_SIZE = 64
def convolutional_neural_network(use_py_reader):
    with fluid.unique_name.guard():
        img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')

        py_reader = None
        if use_py_reader:
            py_reader = fluid.layers.create_py_reader_by_data(
                capacity=64,
                feed_list=[img, label],
                name='py_reader',
                use_double_buffer=False)
            img, label = fluid.layers.read_file(py_reader)

        conv_pool_1 = fluid.nets.simple_img_conv_pool(
            input=img,
            filter_size=5,
            num_filters=20,
            pool_size=2,
            pool_stride=2,
            act="relu")
        conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
        conv_pool_2 = fluid.nets.simple_img_conv_pool(
            input=conv_pool_1,
            filter_size=5,
            num_filters=50,
            pool_size=2,
            pool_stride=2,
            act="relu")

        prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        avg_loss = fluid.layers.mean(loss)
        acc = fluid.layers.accuracy(input=prediction, label=label)
        return img, label, prediction, avg_loss, acc, py_reader
def test():
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    test_reader = paddle.batch(
        paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)

    img, label, prediction, avg_loss, acc, py_reader = convolutional_neural_network(
        use_py_reader=False)

    feeder = fluid.DataFeeder(feed_list=[img, label], place=place)

    def train_test(train_test_program, train_test_feed, train_test_reader):
        acc_set = []
        avg_loss_set = []
        for test_data in train_test_reader():
            acc_np, avg_loss_np = exe.run(program=train_test_program,
                                          feed=train_test_feed.feed(test_data),
                                          fetch_list=[acc, avg_loss])
            acc_set.append(float(acc_np))
            avg_loss_set.append(float(avg_loss_np))
        # get test acc and loss
        acc_val_mean = numpy.array(acc_set).mean()
        avg_loss_val_mean = numpy.array(avg_loss_set).mean()
        return avg_loss_val_mean, acc_val_mean

    # test for epoch
    avg_loss_val, acc_val = train_test(
        train_test_program=fluid.default_main_program(),
        train_test_reader=test_reader,
        train_test_feed=feeder)

    print("Test: avg_cost: %s, acc: %s" % (avg_loss_val, acc_val))
    assert acc_val > 0.96
def train(use_cuda, thread_num, cpu_num):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        print("paddle is not compiled with cuda, exit!")
        return

    img, label, prediction, avg_loss, acc, py_reader = convolutional_neural_network(
        use_py_reader=True)

    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
    optimizer.minimize(avg_loss)

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.mnist.train(), buf_size=500),
        batch_size=BATCH_SIZE)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    os.environ['CPU_NUM'] = str(cpu_num)
    print("cpu_num:" + str(cpu_num))
    print("thread_num:" + str(thread_num))

    build_strategy = fluid.BuildStrategy()
    build_strategy.async_mode = True

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_threads = thread_num
    exec_strategy.num_iteration_per_run = 10

    main_program = fluid.default_main_program()
    pe = fluid.ParallelExecutor(
        use_cuda=False,
        loss_name=avg_loss.name,
        main_program=main_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    py_reader.decorate_paddle_reader(train_reader)

    for pass_id in range(2):
        step = 0
        py_reader.start()
        try:
            while True:
                loss_val = pe.run(fetch_list=[avg_loss.name])
                loss_val = numpy.mean(loss_val)
                if step % 10 == 0:
                    print("Pass %d, Batch %d, Cost %f, queue size %d" %
                          (pass_id, step, loss_val, py_reader.queue.size()))
                step += 1
        except fluid.core.EOFException:
            print("train end pass = " + str(pass_id))
            py_reader.reset()

    return step
class TestAsyncSSAGraphExecutor(unittest.TestCase):
    def test_check_async_ssa_exe_train(self):
        step_list = []
        for cpu_num in [1, 2, 4]:
            print("run cpu_num -> " + str(cpu_num))
            with fluid.scope_guard(fluid.core.Scope()):
                with fluid.program_guard(
                        main_program=fluid.Program(),
                        startup_program=fluid.Program()):
                    start_time = time.time()
                    step = train(
                        use_cuda=False, thread_num=cpu_num, cpu_num=cpu_num)
                    end_time = time.time()
                step_list.append(step)
                print("cpu_num -> " + str(cpu_num) + " step -> " + str(step) +
                      " time -> " + str(end_time - start_time))
                with fluid.program_guard(
                        main_program=fluid.Program(),
                        startup_program=fluid.Program()):
                    test()
        assert abs(int(step_list[0] / 2) - int(step_list[1])) < 5
        assert abs(int(step_list[1] / 2) - int(step_list[2])) < 5


if __name__ == "__main__":
    unittest.main()
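The two assertions encode the expected scaling: each pass still covers the full MNIST training set, so the per-replica step count should fall roughly in proportion to cpu_num. A back-of-envelope sketch, assuming the shared queue spreads batches evenly across replicas:

TOTAL_BATCHES = 60000 // BATCH_SIZE           # MNIST train set, 60000 images -> ~937 batches
for cpu_num in [1, 2, 4]:
    print(cpu_num, TOTAL_BATCHES // cpu_num)  # ~937, ~468, ~234 steps per replica

The < 5 slack absorbs uneven draining of the queue on the final batches of a pass.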
...@@ -52,6 +52,7 @@ class TestDistRunnerBase(object):
        # NOTE: import fluid until runtime, or else forking processes will cause error.
        config = fluid.DistributeTranspilerConfig()
        config.enable_dc_asgd = dc_asgd
+       # config.runtime_split_send_recv = True
        t = fluid.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
...@@ -139,8 +140,7 @@ class TestDistRunnerBase(object):
        pass_builder = None
        if args.batch_merge_repeat > 1:
            pass_builder = build_stra._finalize_strategy_and_create_passes()
-           mypass = pass_builder.insert_pass(
-               len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
+           mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
            mypass.set("num_repeats", args.batch_merge_repeat)
        if args.update_method == "nccl2" or args.update_method == "nccl2_reduce_layer":
...
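In isolation, the rewritten pass insertion is exercised like this (a minimal sketch mirroring the diff above; insert_pass(0, ...) now pins the pass to the front of the pipeline instead of counting back from the end, and the repeat count 4 is a hypothetical value):

build_stra = fluid.BuildStrategy()
pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
mypass.set("num_repeats", 4)  # hypothetical repeat count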
...@@ -38,7 +38,7 @@ class TestSpliteSelectedRows(unittest.TestCase):
    def check_with_place(self, place):
        scope = core.Scope()
        rows = [0, 5, 7, 4, 20]
-       height = 20
+       height = 21
        row_numel = 2
        # initialize input variable X
...
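The height bump from 20 to 21 follows directly from the new PADDLE_ENFORCE_LT(id, height) check in the kernel: the test's row ids include 20, and every id must now be strictly less than the height. In sketch form:

rows = [0, 5, 7, 4, 20]
height = 21
# the new kernel check requires id < height for every row id
assert all(r < height for r in rows)  # max id is 20, so height must be >= 21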
...@@ -156,6 +156,8 @@ class DistributeTranspilerConfig(object):
    mode = "pserver"
    print_log = False
    wait_port = True
+   # split the send/recv var at runtime
+   runtime_split_send_recv = False


class DistributeTranspiler(object):
...@@ -398,8 +400,10 @@ class DistributeTranspiler(object):
                orig_var = program.global_block().vars[splited_grad_varname]
                index = find_op_by_output_arg(
                    program.global_block(), splited_grad_varname, reverse=True)
-               self._insert_split_op(program, orig_var, index, splited_vars)
-               index += 1
+               if not self.config.runtime_split_send_recv:
+                   self._insert_split_op(program, orig_var, index,
+                                         splited_vars)
+                   index += 1
            else:
                AssertionError("Can not insert the send op by original "
                               "variable name :", splited_grad_varname)
...@@ -408,6 +412,17 @@ class DistributeTranspiler(object):
                name=framework.generate_control_dev_var_name())
            self.grad_name_to_send_dummy_out[grad_varname] = dummy_output

+           if self.config.runtime_split_send_recv:
+               send_input_vars = [
+                   program.global_block().vars[splited_grad_varname]
+               ]
+               sections = self._get_splited_var_sections(splited_vars)
+               send_varnames = [var.name for var in splited_vars]
+           else:
+               send_input_vars = splited_vars
+               sections = []
+               send_varnames = []

            # get send op_role_var, if not splited, the grad should have .trainer suffix
            # if splited, grad should be the original grad var name (split_by_ref and send
            # will be on the same place). ParallelExecutor
...@@ -415,10 +430,12 @@ class DistributeTranspiler(object):
            program.global_block()._insert_op(
                index=index + 1,
                type="send",
-               inputs={"X": splited_vars},
+               inputs={"X": send_input_vars},
                outputs={"Out": dummy_output},
                attrs={
                    "epmap": eplist,
+                   "sections": sections,
+                   "send_varnames": send_varnames,
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                    OP_ROLE_VAR_ATTR_NAME: [
                        self.grad_name_to_param_name[grad_varname],
...@@ -501,13 +518,20 @@ class DistributeTranspiler(object):
                    self._update_remote_sparse_update_op(
                        param_varname, height_sections, eps, table_names)
                else:
+                   recv_varnames = []
+                   if self.config.runtime_split_send_recv:
+                       orig_param = program.global_block().vars[param_varname]
+                       recv_varnames = [var.name for var in splited_var]
+                       splited_var = [orig_param]
                    all_recv_outputs.extend(splited_var)
                    program.global_block().append_op(
                        type="recv",
                        inputs={"X": [recv_dep_in]},
                        outputs={"Out": splited_var},
                        attrs={
                            "epmap": eps,
+                           "recv_varnames": recv_varnames,
                            "trainer_id": self.trainer_id,
                            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                            OP_ROLE_VAR_ATTR_NAME:
...@@ -532,14 +556,15 @@ class DistributeTranspiler(object):
                continue
            orig_param = program.global_block().vars[param_varname]
            if param_varname not in self.sparse_param_to_height_sections:
-               program.global_block().append_op(
-                   type="concat",
-                   inputs={"X": splited_var},
-                   outputs={"Out": [orig_param]},
-                   attrs={
-                       "axis": 0,
-                       RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
-                   })
+               if not self.config.runtime_split_send_recv:
+                   program.global_block().append_op(
+                       type="concat",
+                       inputs={"X": splited_var},
+                       outputs={"Out": [orig_param]},
+                       attrs={
+                           "axis": 0,
+                           RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
+                       })

        self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
...@@ -1552,11 +1577,17 @@ class DistributeTranspiler(object):
            lod_level=var.lod_level,
            persistable=persistable)

+   @staticmethod
+   def _get_splited_var_sections(splited_vars):
+       height_sections = []
+       for v in splited_vars:
+           height_sections.append(v.shape[0])
+       return height_sections

    def _insert_split_op(self, program, orig_var, index, splited_vars):
+       height_sections = self._get_splited_var_sections(splited_vars)

        if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
-           height_sections = []
-           for v in splited_vars:
-               height_sections.append(v.shape[0])
            sparse_param_name = self.grad_name_to_param_name[orig_var.name]
            if self._is_input_of_remote_sparse_update_op(sparse_param_name):
                self.sparse_param_to_height_sections[
...@@ -1571,16 +1602,13 @@ class DistributeTranspiler(object):
                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
                })
        elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
-           sections = []
-           for v in splited_vars:
-               sections.append(v.shape[0])
            program.global_block()._insert_op(
                index=index + 1,
                type="split_byref",
                inputs={"X": orig_var},
                outputs={"Out": splited_vars},
                attrs={
-                   "sections": sections,
+                   "sections": height_sections,
                    RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
                })
        else:
...@@ -2052,7 +2080,7 @@ class DistributeTranspiler(object):
        Get optimizer operators, parameters and gradients from origin_program
        Returns:
            opt_ops (list): optimize operators.
-           params_grads (dict): paramter->gradient.
+           params_grads (dict): parameter->gradient.
        """
        block = self.origin_program.global_block()
        opt_ops = []
...
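To exercise the new transpiler path end to end, the config flag is flipped before transpiling (a hedged usage sketch; the endpoint strings and trainer count are placeholders):

import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()
config.runtime_split_send_recv = True  # let send/recv ops split vars at runtime

t = fluid.DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174,127.0.0.1:6175",  # placeholder endpoints
    trainers=2)

With the flag on, the transpiler skips the static split_byref/concat ops and instead passes "sections", "send_varnames" and "recv_varnames" to the send/recv ops, which perform the split and merge themselves at runtime.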