提交 fab1b54d 编写于 作者: Q Qiao Longfei

Merge branch 'add-communicator' of ssh://github.com/jacquesqiao/Paddle into...

Merge branch 'add-communicator' of ssh://github.com/jacquesqiao/Paddle into add-async-ssa-graph-executor-communicator
...@@ -186,10 +186,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -186,10 +186,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
VLOG(3) << "apply all passes";
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "apply " << pass->Type();
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......
...@@ -19,6 +19,7 @@ namespace paddle { ...@@ -19,6 +19,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const { std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
VLOG(3) << "apply pass -> " << Type();
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) { for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
......
...@@ -77,6 +77,8 @@ Scope& Scope::NewScope() const { ...@@ -77,6 +77,8 @@ Scope& Scope::NewScope() const {
return *child; return *child;
} }
Scope* Scope::NewTmpScope() const { return new Scope(this); }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
return VarInternal(name); return VarInternal(name);
......
...@@ -55,6 +55,8 @@ class Scope { ...@@ -55,6 +55,8 @@ class Scope {
/// Mark it to const because that new kid scope cannot change parent scope. /// Mark it to const because that new kid scope cannot change parent scope.
Scope& NewScope() const; Scope& NewScope() const;
Scope* NewTmpScope() const;
/// Create a variable with given name if it doesn't exist. /// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
Variable* Var(const std::string& name); Variable* Var(const std::string& name);
......
...@@ -27,7 +27,7 @@ limitations under the License. */ ...@@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) { void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) { } else if (var_type == proto::VarType::SELECTED_ROWS) {
...@@ -37,7 +37,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -37,7 +37,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope*>>(); var->GetMutable<std::vector<framework::Scope *>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
...@@ -56,5 +56,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -56,5 +56,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var_type); var_type);
} }
} }
void CopyVariable(const Variable &src_var, Variable *dst_var) {
// only support cpu now
auto cpu_place = platform::CPUPlace();
if (src_var.IsType<framework::LoDTensor>()) {
auto *tmp_grad_tensor = dst_var->GetMutable<framework::LoDTensor>();
auto &src_tensor = src_var.Get<framework::LoDTensor>();
tmp_grad_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor);
} else if (src_var.IsType<framework::SelectedRows>()) {
auto &src_slr = src_var.Get<framework::SelectedRows>();
auto *tmp_grad_slr = dst_var->GetMutable<framework::SelectedRows>();
tmp_grad_slr->set_rows(src_slr.rows());
tmp_grad_slr->set_height(src_slr.height());
auto &src_t = src_slr.value();
auto *dst_t = tmp_grad_slr->mutable_value();
framework::TensorCopy(src_t, cpu_place, dst_t);
} else {
PADDLE_THROW("unknown var type to copy");
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable *var, proto::VarType::Type var_type); void InitializeVariable(Variable* var, proto::VarType::Type var_type);
void CopyVariable(const Variable& src_var, Variable* dst_var);
} }
} }
...@@ -30,7 +30,7 @@ if(WITH_GRPC) ...@@ -30,7 +30,7 @@ if(WITH_GRPC)
else() else()
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc) set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib) set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
...@@ -50,8 +50,11 @@ endif() ...@@ -50,8 +50,11 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL) DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool)
if(WITH_GPU) if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor ${RPC_DEPS} DEPS sendrecvop_rpc executor ${RPC_DEPS}
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
namespace distributed {
// Merge the gradient snapshots gathered for one variable into a single
// variable named `var_name` inside `scope`, by element-wise / row-wise sum.
// - LoDTensor:    dense element-wise sum into a freshly allocated output.
// - SelectedRows: merged with math::scatter::MergeAdd (duplicate row ids
//                 are added together).
// Only float data on CPU is handled (see mutable_data<float> / CPUPlace).
static inline void MergeVars(const std::string &var_name,
                             const std::vector<std::shared_ptr<Variable>> &vars,
                             Scope *scope) {
  PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
  auto cpu_place = platform::CPUPlace();
  // The first snapshot decides the type and shape of the merged output.
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);
  if (var0->IsType<framework::LoDTensor>()) {
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    auto *out_ptr = out_t->mutable_data<float>(
        var0->Get<framework::LoDTensor>().dims(), cpu_place);
    auto numel = out_t->numel();
    // Element-wise sum over every gathered tensor.
    // NOTE(review): assumes all inputs are CPU float tensors of identical
    // shape; the merged output's lod is never set -- confirm callers do not
    // rely on it.
    for (auto i = 0; i < numel; ++i) {
      out_ptr[i] = 0;
      for (auto &var : vars) {
        auto &var_t = var->Get<framework::LoDTensor>();
        PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims");
        out_ptr[i] += var_t.data<float>()[i];
      }
    }
  } else if (var0->IsType<framework::SelectedRows>()) {
    auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
    // Reset rows/value so repeated merges into the same scope var start clean.
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
    std::vector<const paddle::framework::SelectedRows *> inputs;
    inputs.reserve(vars.size());
    for (auto &var : vars) {
      inputs.push_back(&var->Get<framework::SelectedRows>());
    }
    math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
        merge_add;
    auto dev_ctx = paddle::platform::CPUDeviceContext();
    merge_add(dev_ctx, inputs, out_slr, false);
  } else {
    PADDLE_THROW("unsupported var type!");
  }
}
// Body of the background send thread; loops until running_ is cleared.
// Each pass: for every registered send variable, enqueue one task on
// send_threadpool_ that drains up to max_merge_var_num queued gradient
// snapshots, merges them into send_scope_ (MergeVars), and pushes the merged
// gradient to the parameter servers via ParameterSend. The pass waits for
// every task of the round before starting the next round.
// NOTE(review): when all queues are empty this loop spins without sleeping
// or blocking -- confirm that is intended.
void Communicator::SendThread() {
  while (running_) {
    std::vector<std::future<void>> task_futures;
    task_futures.reserve(send_varname_to_ctx_.size());
    for (auto &iter : send_varname_to_queue_) {
      // &iter binds to the map element itself, which outlives the task
      // because send_varname_to_queue_ is only filled in the constructor.
      auto send_task = [this, &iter] {
        auto &var_name = iter.first;
        VLOG(3) << "merge var " << var_name << " and send";
        auto &var_queue = iter.second;
        std::vector<std::shared_ptr<Variable>> vars;
        // TODO(qiao): need to be configurable
        const size_t max_merge_var_num = 20;
        size_t merged_var_num = 0;
        while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
          vars.push_back(var_queue->Pop());
          merged_var_num++;
        }
        MergeVars(var_name, vars, send_scope_.get());
        auto send_functor = distributed::ParameterSend<float>();
        auto &ctx = send_varname_to_ctx_.at(var_name);
        // third argument is ParameterSend's `sync` flag (see its signature)
        send_functor(ctx, *send_scope_, true);
      };
      task_futures.emplace_back(
          send_threadpool_->enqueue(std::move(send_task)));
    }
    // barrier: finish this round for every variable before looping again
    for (auto &task_f : task_futures) {
      task_f.wait();
    }
  }
}
// Body of the background recv thread; loops until running_ is cleared.
// Each pass launches one ParameterRecv task per registered recv variable on
// recv_threadpool_, pulling the latest parameter values from the parameter
// servers into recv_scope_ (the trainer's global scope, supplied to the
// constructor), then waits for all of them before the next pass.
void Communicator::RecvThread() {
  while (running_) {
    // parallel run recv graph
    std::vector<std::future<void>> task_futures;
    task_futures.reserve(recv_varname_to_ctx_.size());
    for (auto &iter : recv_varname_to_ctx_) {
      auto recv_task = [this, &iter] {
        auto &var_name = iter.first;
        VLOG(3) << "recv var " << var_name;
        auto recv_functor = distributed::ParameterRecv<float>();
        recv_functor(iter.second, *recv_scope_);
      };
      task_futures.emplace_back(
          recv_threadpool_->enqueue(std::move(recv_task)));
    }
    // barrier: complete the round before starting another
    for (auto &task : task_futures) {
      task.wait();
    }
  }
}
// Enqueue the gradient `var_name` for asynchronous sending.
// Copies the variable out of `scope` (so the trainer may keep mutating the
// original) and pushes the copy onto the per-variable blocking queue that
// SendThread() drains. Blocks while the queue is full.
void Communicator::Send(const std::string &var_name,
                        const framework::Scope &scope) {
  // push var into send queue by var_name
  auto *grad_var = scope.FindVar(var_name);
  // FindVar returns nullptr for unknown names; check before dereferencing.
  PADDLE_ENFORCE(grad_var != nullptr, "can not find grad var in scope");
  PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
  // Look the queue up with find() instead of operator[]: an unregistered
  // name must fail loudly instead of inserting a null queue and then
  // crashing on Push().
  auto queue_it = send_varname_to_queue_.find(var_name);
  PADDLE_ENFORCE(queue_it != send_varname_to_queue_.end(),
                 "grad var should be registered in send_varname_to_queue_");
  auto tmp_grad_var = std::make_shared<Variable>();
  framework::CopyVariable(*grad_var, tmp_grad_var.get());
  queue_it->second->Push(tmp_grad_var);
}
void Communicator::Start() {
running_ = true;
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include <ThreadPool.h>

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
// A fixed-capacity, thread-safe FIFO queue.
// Producers block inside Push() while the queue is full; consumers block
// inside Pop() while it is empty. Capacity is fixed at construction.
template <typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0.");
  }

  // Copies an element in; waits until a slot is free. Always returns true.
  bool Push(const T& elem) {
    std::unique_lock<std::mutex> guard(mutex_);
    send_cv_.wait(guard, [this] { return queue_.size() < capacity_; });
    PADDLE_ENFORCE_LT(queue_.size(), capacity_);
    queue_.push_back(elem);
    recv_cv_.notify_one();
    return true;
  }

  // Moves an element in; waits until a slot is free. Always returns true.
  bool Push(T&& elem) {
    std::unique_lock<std::mutex> guard(mutex_);
    send_cv_.wait(guard, [this] { return queue_.size() < capacity_; });
    PADDLE_ENFORCE_LT(queue_.size(), capacity_);
    queue_.emplace_back(std::move(elem));
    recv_cv_.notify_one();
    return true;
  }

  // Removes and returns the oldest element; waits until one is available.
  T Pop() {
    std::unique_lock<std::mutex> guard(mutex_);
    recv_cv_.wait(guard, [this] { return !queue_.empty(); });
    T front_elem(std::move(queue_.front()));
    queue_.pop_front();
    return front_elem;
  }

  // Maximum number of elements the queue can hold.
  size_t Cap() const {
    std::lock_guard<std::mutex> guard(mutex_);
    return capacity_;
  }

  // Current number of queued elements.
  size_t Size() const {
    std::lock_guard<std::mutex> guard(mutex_);
    return queue_.size();
  }

 private:
  const size_t capacity_;
  std::deque<T> queue_;

  mutable std::mutex mutex_;
  std::condition_variable recv_cv_;  // signalled when an element is pushed
  std::condition_variable send_cv_;  // signalled when a slot frees up
};
// Async distributed-training communicator.
// Owns one background thread that merges and sends gradients and one that
// receives updated parameters. Gradients are handed over through
// per-variable blocking queues (Send()), merged inside an independent
// send_scope_, and updated parameters are written into the trainer's global
// recv_scope_.
class Communicator {
 public:
  Communicator(
      const std::unordered_map<std::string, RpcContext>& send_varname_to_ctx,
      const std::unordered_map<std::string, RpcContext>& recv_varname_to_ctx,
      Scope* recv_scope)
      : send_varname_to_ctx_(send_varname_to_ctx),
        recv_varname_to_ctx_(recv_varname_to_ctx),
        recv_scope_(recv_scope) {
    // get all send information from graph, build vars_to_send
    send_scope_.reset(new Scope());
    for (auto& iter : send_varname_to_ctx_) {
      send_varname_to_queue_[iter.first] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
    }
    // TODO(qiao): default 5, need to config
    send_threadpool_.reset(new ::ThreadPool(5));
    recv_threadpool_.reset(new ::ThreadPool(5));
  }

  ~Communicator() {
    VLOG(3) << "~Communicator";
    running_ = false;
    // Join only if Start() was actually called; the thread pointers are
    // still null otherwise and joining would dereference nullptr.
    if (send_thread_) send_thread_->join();
    if (recv_thread_) recv_thread_->join();
    VLOG(3) << "~Communicator done";
  }

  // Launch the background send and recv threads.
  void Start();

  // send grad: copy var_name's gradient out of scope and enqueue it for the
  // send thread
  void Send(const std::string& var_name, const framework::Scope& scope);

 private:
  void SendThread();
  void RecvThread();

  // atomic: written by ~Communicator while the worker loops read it, so a
  // plain bool would be a data race
  std::atomic<bool> running_{false};
  std::unordered_map<std::string,
                     std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
      send_varname_to_queue_;
  std::unordered_map<std::string, RpcContext> send_varname_to_ctx_;
  std::unordered_map<std::string, RpcContext> recv_varname_to_ctx_;
  std::unique_ptr<std::thread> send_thread_;
  std::unique_ptr<std::thread> recv_thread_;
  Scope* recv_scope_;                  // should be global scope
  std::unique_ptr<Scope> send_scope_;  // an independent scope
  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
  std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
};
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -107,6 +107,9 @@ class RequestSend final : public RequestBase { ...@@ -107,6 +107,9 @@ class RequestSend final : public RequestBase {
int trainer_id = request_->GetTrainerId(); int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
if (!request_handler_->sync_mode()) {
request_->ReleaseOwnershipOfLocalScope();
}
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
......
...@@ -37,30 +37,9 @@ using LoDTensor = framework::LoDTensor; ...@@ -37,30 +37,9 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
static size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (id < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
static std::vector<std::vector<int64_t>> SplitIds( static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector, const std::vector<int64_t>& ids_vector,
const std::vector<int>& height_section, framework::Scope* scope) { const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids; std::set<int64_t> all_ids;
for (auto id : ids_vector) { for (auto id : ids_vector) {
all_ids.insert(id); all_ids.insert(id);
...@@ -78,7 +57,7 @@ static std::vector<std::vector<int64_t>> SplitIds( ...@@ -78,7 +57,7 @@ static std::vector<std::vector<int64_t>> SplitIds(
static void SplitIdsIntoMultipleVarsBySection( static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names, const std::vector<std::string>& in_var_names,
const std::vector<int>& height_section, const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) { framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
...@@ -100,7 +79,7 @@ static void SplitIdsIntoMultipleVarsBySection( ...@@ -100,7 +79,7 @@ static void SplitIdsIntoMultipleVarsBySection(
static void MergeMultipleVarsIntoOneBySection( static void MergeMultipleVarsIntoOneBySection(
const std::string& id_name, const std::vector<int64_t>& ids_vector, const std::string& id_name, const std::vector<int64_t>& ids_vector,
const std::string& out_name, const std::vector<std::string>& out_var_names, const std::string& out_name, const std::vector<std::string>& out_var_names,
const std::vector<int>& height_section, const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, framework::Scope* scope, const framework::ExecutionContext& context, framework::Scope* scope,
platform::DeviceContext* actual_ctx) { platform::DeviceContext* actual_ctx) {
...@@ -177,10 +156,10 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -177,10 +156,10 @@ static void MergeMultipleVarsIntoOneBySection(
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope) { const framework::Scope& scope) {
auto& local_scope = scope.NewScope(); framework::Scope* local_scope = scope.NewTmpScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace()); auto& cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -224,22 +203,22 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -224,22 +203,22 @@ void prefetch(const std::string& id_name, const std::string& out_name,
#endif #endif
} }
auto splited_ids = SplitIds(ids_vector, height_sections, &local_scope); auto splited_ids = SplitIds(ids_vector, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
&local_scope); local_scope);
// create output var in local scope // create output var in local scope
for (auto& name : out_var_names) { for (auto& name : out_var_names) {
local_scope.Var(name)->GetMutable<framework::LoDTensor>(); local_scope->Var(name)->GetMutable<framework::LoDTensor>();
} }
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) { for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(local_scope, in_var_names[i])) { if (NeedSend(*local_scope, in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i] VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back"; << " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar( rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, local_scope, in_var_names[i], out_var_names[i], epmap[i], cpu_ctx, *local_scope, in_var_names[i], out_var_names[i],
table_names[i])); table_names[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
...@@ -252,8 +231,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -252,8 +231,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids, out_var_names, height_sections, splited_ids,
context, &local_scope, &actual_ctx); context, local_scope, &actual_ctx);
scope.DeleteScope(&local_scope); delete local_scope;
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -26,7 +26,7 @@ namespace distributed { ...@@ -26,7 +26,7 @@ namespace distributed {
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope); const framework::Scope& scope);
...@@ -35,7 +35,7 @@ void prefetch_with_reconstruct(const std::string& id_name, ...@@ -35,7 +35,7 @@ void prefetch_with_reconstruct(const std::string& id_name,
const std::string& out_name, const std::string& out_name,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& epmap,
const std::vector<int>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope, const framework::Scope& scope,
framework::LoDTensor* original) { framework::LoDTensor* original) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
// Receive one logical parameter that is sharded across parameter servers:
// asynchronously fetch every slice named in rpc_ctx.splited_var_names into a
// temporary scope, then concatenate the slices along axis 0 into the
// original variable rpc_ctx.var_name found in `scope`.
// Only LoDTensor parameters are supported; all work happens on CPU.
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
                                  const framework::Scope &scope) {
  // temporary scope holding the received slices; deleted at the end
  framework::Scope *local_scope = scope.NewTmpScope();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx = *pool.Get(platform::CPUPlace());

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

  // NOTE(review): FindVar returns nullptr for an unknown name and the result
  // is dereferenced below -- confirm callers guarantee the var exists.
  auto *recv_var = scope.FindVar(rpc_ctx.var_name);

  std::vector<framework::Tensor *> recved_tensors;

  // recv all vars to local scope
  if (recv_var->IsType<framework::LoDTensor>()) {
    std::vector<distributed::VarHandlePtr> rets;
    for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
      auto &recv_var_name = rpc_ctx.splited_var_names[i];
      framework::Tensor *t =
          local_scope->Var(recv_var_name)->GetMutable<framework::LoDTensor>();
      recved_tensors.push_back(t);
      VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
      rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
                                             *local_scope, recv_var_name,
                                             recv_var_name));
    }
    // wait for every async RPC before reading the received tensors
    for (size_t i = 0; i < rets.size(); i++) {
      PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
    }
  } else {
    PADDLE_THROW("unsupported var type to recv!");
  }

  // concat recved tensor into one var
  {
    size_t output_offset = 0;
    framework::Tensor *recv_tensor =
        recv_var->GetMutable<framework::LoDTensor>();
    auto dev_ctx = paddle::platform::CPUDeviceContext();
    // NOTE(review): assumes recv_tensor is already allocated large enough to
    // hold the concatenation of all slices -- confirm with callers.
    for (auto *in : recved_tensors) {
      auto in_stride = framework::stride_numel(in->dims());
      auto out_stride = framework::stride_numel(recv_tensor->dims());
      // copy each slice after the previous one along axis 0
      StridedNumelCopyWithAxis<T>(
          dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
          in->data<T>(), in_stride, in_stride[0]);
      output_offset += in_stride[0];
    }
  }

  delete local_scope;
}

// explicit instantiation: only float parameters are received
template struct ParameterRecv<float>;
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
namespace paddle {
namespace operators {
namespace distributed {
// Functor that pulls one (possibly sharded) parameter of element type T from
// the parameter servers described by `rpc_ctx` and assembles it into the
// variable rpc_ctx.var_name inside `scope`. Defined in parameter_recv.cc.
template <typename T>
struct ParameterRecv {
  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
// Sends the variable `rpc_ctx.var_name` from `scope` to the parameter
// servers in `rpc_ctx.epmap`.  When the variable is distributed over
// several pservers it is first split into `rpc_ctx.splited_var_names`
// inside a temporary scope — a LoDTensor is sliced by row ranges taken
// from `rpc_ctx.height_sections`, a SelectedRows is bucketed row-by-row —
// and each piece is sent to its endpoint.
//
// Note: only synchronous send is supported for now, so `sync` is
// effectively ignored and every RPC is waited on before returning.
template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
                                  const framework::Scope &scope, bool sync) {
  // Own the temporary scope with a unique_ptr so it is released even when
  // one of the enforces below throws (a raw delete at the end leaked the
  // scope on those error paths).
  std::unique_ptr<framework::Scope> local_scope(scope.NewTmpScope());

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx = *pool.Get(platform::CPUPlace());

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

  auto *send_var = scope.FindVar(rpc_ctx.var_name);
  size_t out_num = rpc_ctx.splited_var_names.size();
  if (send_var->IsType<framework::LoDTensor>()) {
    if (out_num > 1) {
      auto &send_tensor = send_var->Get<framework::LoDTensor>();
      auto &send_tensor_dims = send_tensor.dims();
      std::vector<framework::DDim> outs_dims;
      outs_dims.reserve(out_num);

      // Infer each output slice's shape: dim 0 comes from height_sections,
      // the remaining dims are unchanged.
      PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num,
                        "tensor split sections size "
                        "should be equal to output size.");
      for (size_t i = 0; i < out_num; ++i) {
        auto dim = send_tensor_dims;
        dim[0] = rpc_ctx.height_sections[i];
        outs_dims.push_back(dim);
      }

      // Create the output vars in the temp scope; each one is a zero-copy
      // row slice of the source tensor.  (The loop index is size_t; the
      // original `auto i = 0` compared a signed int against out_num.)
      size_t row_offset = 0;
      for (size_t i = 0; i < out_num; ++i) {
        framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
                                     ->GetMutable<framework::LoDTensor>();
        *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
        row_offset += outs_dims[i][0];
      }
    }
  } else if (send_var->IsType<framework::SelectedRows>()) {
    auto &send_slr = send_var->Get<framework::SelectedRows>();
    auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);

    // Bind by reference — the original copied the whole rows vector.
    auto &send_rows = send_slr.rows();
    std::vector<std::vector<int>> outs_rows_idx;
    std::vector<std::vector<int>> outs_dense_idx;

    outs_rows_idx.resize(out_num);
    outs_dense_idx.resize(out_num);

    auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
    auto src = send_slr.value().data<T>();

    // create output vars in the temp scope
    std::vector<framework::SelectedRows *> outs;
    for (auto &name : rpc_ctx.splited_var_names) {
      auto *out = local_scope->Var(name)->GetMutable<framework::SelectedRows>();
      outs.push_back(out);
    }

    // Bucket every source row into the section (output var) it belongs to.
    for (size_t i = 0; i < send_rows.size(); ++i) {
      int out_idx = FindOutIdx(send_rows[i], abs_sections);
      outs_rows_idx[out_idx].push_back(send_rows[i]);
      outs_dense_idx[out_idx].push_back(i);
    }
    auto place = platform::CPUPlace();

    for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
      auto &rows_idx = outs_rows_idx[i];
      outs[i]->set_height(rpc_ctx.height_sections[i]);
      auto dims = send_slr.GetCompleteDims();
      dims[0] = rows_idx.size();
      outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
      outs[i]->mutable_rows()->clear();
      if (rows_idx.size() > 0) {
        // Rebase row indices so each output starts from row 0 of its section.
        for (auto idx : rows_idx) {
          outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
        }
        auto dst = outs[i]->mutable_value()->mutable_data<T>(place);
        for (size_t j = 0; j < rows_idx.size(); j++) {
          if (platform::is_cpu_place(place)) {
            memory::Copy(
                platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
                src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
          } else {
            // GPU copy is not implemented for parameter send yet.
            PADDLE_THROW("do not support GPU now");
          }
        }
      }
      PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
                        "rows should has the same size with tensor dim 0");
    }
  } else {
    PADDLE_THROW("unsupported var type to send!");
  }

  std::vector<distributed::VarHandlePtr> rets;
  for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
    auto &send_var_name = rpc_ctx.splited_var_names[i];
    auto &endpoint = rpc_ctx.epmap[i];
    // Skip variables that were never created/initialized (e.g. the
    // LoDTensor out_num == 1 case keeps everything in the source scope).
    if (NeedSend(*local_scope, send_var_name)) {
      VLOG(3) << "sending " << send_var_name << " to " << endpoint;
      rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope,
                                              send_var_name));
    } else {
      VLOG(3) << "don't send non-initialized variable: "
              << rpc_ctx.splited_var_names[i];
    }
  }

  // note!! only support sync send now: always wait, regardless of `sync`.
  if (true || sync) {
    for (size_t i = 0; i < rets.size(); i++) {
      PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
    }
  }
  // local_scope is freed automatically by the unique_ptr.
}
template struct ParameterSend<float>;
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
namespace paddle {
namespace operators {
namespace distributed {
// Functor that sends one (possibly splited) parameter/gradient variable
// from a trainer scope to remote parameter servers.
// T is the element type of the tensor payload (instantiated for float in
// parameter_send.cc).
template <typename T>
struct ParameterSend {
  // rpc_ctx: variable name, splited slice names, endpoints and per-slice
  //          heights describing how the variable is distributed.
  // scope:   the scope that holds the source variable.
  // sync:    whether to wait for all RPCs to finish (the current
  //          implementation in parameter_send.cc always waits).
  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
                  bool sync);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
...@@ -71,13 +71,15 @@ class VarHandle { ...@@ -71,13 +71,15 @@ class VarHandle {
VarHandle(const std::string ep, const std::string& method, VarHandle(const std::string ep, const std::string& method,
const std::string& name, const std::string& name,
const platform::DeviceContext* p_ctx = nullptr, const platform::DeviceContext* p_ctx = nullptr,
const framework::Scope* p_scope = nullptr) const framework::Scope* p_scope = nullptr,
bool delete_local_scope = false)
: status_(kDefaultState) { : status_(kDefaultState) {
ep_ = ep; ep_ = ep;
ctx_ = p_ctx; ctx_ = p_ctx;
scope_ = p_scope; scope_ = p_scope;
name_ = name; name_ = name;
method_ = method; method_ = method;
delete_local_scope_ = delete_local_scope;
} }
virtual ~VarHandle() {} virtual ~VarHandle() {}
...@@ -99,6 +101,7 @@ class VarHandle { ...@@ -99,6 +101,7 @@ class VarHandle {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
status_ = ok ? kFinishState : kErrorState; status_ = ok ? kFinishState : kErrorState;
} }
if (delete_local_scope_ && scope_) delete scope_;
VLOG(7) << "VarHandle finish:" << ok; VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all(); wait_cond_.notify_all();
} }
...@@ -125,6 +128,7 @@ class VarHandle { ...@@ -125,6 +128,7 @@ class VarHandle {
std::string name_; std::string name_;
// RPC method name. // RPC method name.
std::string method_; std::string method_;
bool delete_local_scope_;
protected: protected:
std::mutex sync_mutex_; std::mutex sync_mutex_;
......
...@@ -59,9 +59,11 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -59,9 +59,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or " "async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"); "COMPLETE_MESSAGE");
} }
try { try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope); scope);
delete scope;
} catch (std::exception& e) { } catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what(); LOG(ERROR) << "async: run sub program error " << e.what();
return false; return false;
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
namespace paddle {
namespace operators {
namespace distributed {
// Describes one distributed send/recv operation: the original variable
// name, the names of its splited slices, the pserver endpoint for each
// slice and the row count (height) of each slice.
struct RpcContext {
  RpcContext(const std::string& name, const std::vector<std::string>& names,
             const std::vector<std::string>& emap,
             const std::vector<int64_t>& sections)
      : var_name(name),
        splited_var_names(names),
        epmap(emap),
        height_sections(sections) {}

  // The hand-written member-wise copy constructor was redundant; the
  // defaulted one performs exactly the same member-wise copy.
  RpcContext(const RpcContext& ctx) = default;

  std::string var_name;                        // original (unsplited) name
  std::vector<std::string> splited_var_names;  // one name per slice
  std::vector<std::string> epmap;              // endpoint for each slice
  std::vector<int64_t> height_sections;        // rows in each slice
};
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -60,14 +60,12 @@ class VariableResponse { ...@@ -60,14 +60,12 @@ class VariableResponse {
bool create_scope = false) bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) { if (create_scope) {
local_scope_ = &scope->NewScope(); local_scope_ = scope->NewTmpScope();
} }
} }
virtual ~VariableResponse() { virtual ~VariableResponse() {
if (create_scope_) { if (local_scope_) delete local_scope_;
scope_->DeleteScope(local_scope_);
}
} }
int Parse(Source* source, const sendrecv::VariableMessage& meta) { int Parse(Source* source, const sendrecv::VariableMessage& meta) {
...@@ -86,6 +84,12 @@ class VariableResponse { ...@@ -86,6 +84,12 @@ class VariableResponse {
inline std::string Varname() const { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); } inline std::string TableName() const { return meta_.table_name(); }
inline void ReleaseOwnershipOfLocalScope() {
PADDLE_ENFORCE(create_scope_,
"only when create_scope_ is true can you release the "
"ownership of local scope");
local_scope_ = nullptr;
}
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
......
...@@ -2,9 +2,9 @@ include(operators) ...@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else() else()
set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA) if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs) find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
...@@ -20,6 +20,8 @@ limitations under the License. */ ...@@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -48,6 +50,14 @@ class RecvOp : public framework::OperatorBase { ...@@ -48,6 +50,14 @@ class RecvOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id")); Attr<int>("trainer_id"));
std::vector<std::string> recv_varnames =
Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) {
auto recv_functor = distributed::ParameterRecv<float>();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
recv_functor(rpc_ctx, scope);
} else {
if (with_barrier) { if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
...@@ -76,6 +86,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -76,6 +86,7 @@ class RecvOp : public framework::OperatorBase {
} }
} }
} }
}
}; };
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -110,6 +121,11 @@ This operator can get variables from server side. ...@@ -110,6 +121,11 @@ This operator can get variables from server side.
"for example: we need var named 'moment_1@127.0.0.1:1001', " "for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ") "and it real name on parameter server is 'moment_1'. ")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>(
"recv_varnames",
"(vector<string>) "
"the splited parameter varnames to be recved from pserver")
.SetDefault(std::vector<std::string>{});
} }
}; };
......
...@@ -20,6 +20,8 @@ limitations under the License. */ ...@@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -37,10 +39,21 @@ class SendOp : public framework::OperatorBase { ...@@ -37,10 +39,21 @@ class SendOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode"); int sync_send = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("sections");
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
...@@ -51,7 +64,8 @@ class SendOp : public framework::OperatorBase { ...@@ -51,7 +64,8 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
...@@ -64,6 +78,7 @@ class SendOp : public framework::OperatorBase { ...@@ -64,6 +78,7 @@ class SendOp : public framework::OperatorBase {
} }
} }
} }
}
}; };
class SendOpMaker : public framework::OpProtoAndCheckerMaker { class SendOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -88,6 +103,21 @@ This operator will send variables to listen_and_serve op at the parameter server ...@@ -88,6 +103,21 @@ This operator will send variables to listen_and_serve op at the parameter server
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
AddAttr<std::vector<int64_t>>("sections",
"(vector<int>) "
"the length of each output along the "
"specified axis.")
.SetDefault(std::vector<int64_t>{});
AddAttr<std::vector<std::string>>(
"send_varnames",
"(vector<string>) "
"the splited output varnames to send to pserver")
.SetDefault(std::vector<std::string>{});
AddAttr<int>("num",
"(int, default 0)"
"Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]")
.SetDefault(0);
} }
}; };
......
...@@ -13,8 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,5 +48,35 @@ inline bool NeedSend(const framework::Scope& scope, ...@@ -42,5 +48,35 @@ inline bool NeedSend(const framework::Scope& scope,
return false; return false;
} }
inline int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
// Converts per-section heights (row counts) into absolute start offsets,
// e.g. {10, 20, 30} -> {0, 10, 30}.
// Returns an empty vector for empty input; the original wrote
// abs_sections[0] unconditionally, which is out-of-bounds (undefined
// behavior) when height_sections is empty.
inline std::vector<int64_t> ToAbsoluteSection(
    const std::vector<int64_t>& height_sections) {
  std::vector<int64_t> abs_sections(height_sections.size());
  if (abs_sections.empty()) {
    return abs_sections;
  }
  abs_sections[0] = 0;
  for (size_t i = 1; i < height_sections.size(); ++i) {
    abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
  }
  return abs_sections;
}
// Returns which section `id` belongs to, given absolute (cumulative)
// section start offsets; ids at or beyond the last boundary map to the
// final section. Same lookup as FindOutIdx, but with an int64_t id and a
// size_t result.
inline size_t GetSectionIndex(int64_t id,
                              const std::vector<int64_t>& abs_sections) {
  size_t section = abs_sections.size() - 1;  // default: the last section
  for (size_t i = 1; i < abs_sections.size(); ++i) {
    if (id < abs_sections[i]) {
      section = i - 1;
      break;
    }
  }
  return section;
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -134,9 +134,9 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -134,9 +134,9 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch // for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false); AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"epmap", "epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -70,7 +70,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -70,7 +70,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
// if epmap is not empty, then the parameter will be fetched from remote // if epmap is not empty, then the parameter will be fetched from remote
// parameter // parameter
// server // server
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
auto table_names = ctx.Attr<std::vector<std::string>>("table_names"); auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
std::vector<int64_t> real_rows = PathToRows(*path); std::vector<int64_t> real_rows = PathToRows(*path);
framework::Scope& local_scope = ctx.scope().NewScope(); framework::Scope& local_scope = ctx.scope().NewScope();
......
...@@ -91,9 +91,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -91,9 +91,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch // for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false); AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"epmap", "epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -50,7 +50,8 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -50,7 +50,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
// for remote prefetch // for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap"); auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) { if (!epmap.empty()) {
......
...@@ -95,7 +95,7 @@ struct MergeAdd { ...@@ -95,7 +95,7 @@ struct MergeAdd {
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor // out = selected_rows_in / tensor
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct UpdateToTensor { struct UpdateToTensor {
void operator()(const DeviceContext& context, const ScatterOps& op, void operator()(const DeviceContext& context, const ScatterOps& op,
......
...@@ -156,9 +156,9 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -156,9 +156,9 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
// for parameter prefetch // for parameter prefetch
AddAttr<bool>("remote_prefetch", "").SetDefault(false); AddAttr<bool>("remote_prefetch", "").SetDefault(false);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"epmap", "epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -172,7 +172,8 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -172,7 +172,8 @@ class NCEKernel : public framework::OpKernel<T> {
framework::Scope &local_scope = context.scope().NewScope(); framework::Scope &local_scope = context.scope().NewScope();
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
auto *ids = local_scope.Var("Ids@Prefetch"); auto *ids = local_scope.Var("Ids@Prefetch");
......
...@@ -16,31 +16,12 @@ limitations under the License. */ ...@@ -16,31 +16,12 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) {
return i - 1;
}
}
return abs_sections.size() - 1;
}
static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int64_t>& height_sections) {
std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
public: public:
......
...@@ -48,6 +48,7 @@ class TestDistRunnerBase(object): ...@@ -48,6 +48,7 @@ class TestDistRunnerBase(object):
# NOTE: import fluid until runtime, or else forking processes will cause error. # NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd config.enable_dc_asgd = dc_asgd
config.runtime_split_send_recv = True
t = fluid.DistributeTranspiler(config=config) t = fluid.DistributeTranspiler(config=config)
t.transpile( t.transpile(
trainer_id=trainer_id, trainer_id=trainer_id,
...@@ -87,6 +88,9 @@ class TestDistRunnerBase(object): ...@@ -87,6 +88,9 @@ class TestDistRunnerBase(object):
args.endpoints, args.trainers, args.endpoints, args.trainers,
args.sync_mode, args.dc_asgd) args.sync_mode, args.dc_asgd)
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
with open("/tmp/trainer." + str(args.trainer_id) + ".proto",
"w") as f:
f.write(str(trainer_prog))
elif args.update_method == "nccl2": elif args.update_method == "nccl2":
# transpile for nccl2 # transpile for nccl2
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
...@@ -115,6 +119,7 @@ class TestDistRunnerBase(object): ...@@ -115,6 +119,7 @@ class TestDistRunnerBase(object):
strategy.allow_op_delay = False strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy() build_stra = fluid.BuildStrategy()
build_stra.debug_graphviz_path = "/tmp/graph-" + str(args.trainer_id)
if args.use_reduce: if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
...@@ -123,8 +128,7 @@ class TestDistRunnerBase(object): ...@@ -123,8 +128,7 @@ class TestDistRunnerBase(object):
if args.batch_merge_repeat > 1: if args.batch_merge_repeat > 1:
pass_builder = build_stra._finalize_strategy_and_create_passes() pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass( mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass")
len(pass_builder.all_passes()) - 3, "multi_batch_merge_pass")
mypass.set("num_repeats", args.batch_merge_repeat) mypass.set("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2": if args.update_method == "nccl2":
......
...@@ -156,6 +156,8 @@ class DistributeTranspilerConfig(object): ...@@ -156,6 +156,8 @@ class DistributeTranspilerConfig(object):
mode = "pserver" mode = "pserver"
print_log = False print_log = False
wait_port = True wait_port = True
# split the send recv var in runtime
runtime_split_send_recv = False
class DistributeTranspiler(object): class DistributeTranspiler(object):
...@@ -398,7 +400,9 @@ class DistributeTranspiler(object): ...@@ -398,7 +400,9 @@ class DistributeTranspiler(object):
orig_var = program.global_block().vars[splited_grad_varname] orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg( index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True) program.global_block(), splited_grad_varname, reverse=True)
self._insert_split_op(program, orig_var, index, splited_vars) if not self.config.runtime_split_send_recv:
self._insert_split_op(program, orig_var, index,
splited_vars)
index += 1 index += 1
else: else:
AssertionError("Can not insert the send op by original " AssertionError("Can not insert the send op by original "
...@@ -408,6 +412,17 @@ class DistributeTranspiler(object): ...@@ -408,6 +412,17 @@ class DistributeTranspiler(object):
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
if self.config.runtime_split_send_recv:
send_input_vars = [
program.global_block().vars[splited_grad_varname]
]
sections = self._get_splited_var_sections(splited_vars)
send_varnames = [var.name for var in splited_vars]
else:
send_input_vars = splited_vars
sections = []
send_varnames = []
# get send op_role_var, if not splited, the grad should have .trainer suffix # get send op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name (split_by_ref and send # if splited, grad should be the original grad var name (split_by_ref and send
# will be on the same place). ParallelExecutor # will be on the same place). ParallelExecutor
...@@ -415,10 +430,12 @@ class DistributeTranspiler(object): ...@@ -415,10 +430,12 @@ class DistributeTranspiler(object):
program.global_block()._insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="send", type="send",
inputs={"X": splited_vars}, inputs={"X": send_input_vars},
outputs={"Out": dummy_output}, outputs={"Out": dummy_output},
attrs={ attrs={
"epmap": eplist, "epmap": eplist,
"sections": sections,
"send_varnames": send_varnames,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [ OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname], self.grad_name_to_param_name[grad_varname],
...@@ -501,13 +518,20 @@ class DistributeTranspiler(object): ...@@ -501,13 +518,20 @@ class DistributeTranspiler(object):
self._update_remote_sparse_update_op( self._update_remote_sparse_update_op(
param_varname, height_sections, eps, table_names) param_varname, height_sections, eps, table_names)
else: else:
recv_varnames = []
if self.config.runtime_split_send_recv:
orig_param = program.global_block().vars[param_varname]
recv_varnames = [var.name for var in splited_var]
splited_var = [orig_param]
all_recv_outputs.extend(splited_var) all_recv_outputs.extend(splited_var)
program.global_block().append_op( program.global_block().append_op(
type="recv", type="recv",
inputs={"X": [recv_dep_in]}, inputs={"X": [recv_dep_in]},
outputs={"Out": splited_var}, outputs={"Out": splited_var},
attrs={ attrs={
"epmap": eps, "epmap": eps,
"recv_varnames": recv_varnames,
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME:
...@@ -532,6 +556,7 @@ class DistributeTranspiler(object): ...@@ -532,6 +556,7 @@ class DistributeTranspiler(object):
continue continue
orig_param = program.global_block().vars[param_varname] orig_param = program.global_block().vars[param_varname]
if param_varname not in self.sparse_param_to_height_sections: if param_varname not in self.sparse_param_to_height_sections:
if not self.config.runtime_split_send_recv:
program.global_block().append_op( program.global_block().append_op(
type="concat", type="concat",
inputs={"X": splited_var}, inputs={"X": splited_var},
...@@ -1376,9 +1401,8 @@ class DistributeTranspiler(object): ...@@ -1376,9 +1401,8 @@ class DistributeTranspiler(object):
# create table param and grad var in pserver program # create table param and grad var in pserver program
# create table optimize block in pserver program # create table optimize block in pserver program
table_opt_op = [ table_opt_op = [
op for op in self.optimize_ops op for op in self.optimize_ops if 'Param' in op.input_names and
if 'Param' in op.input_names and op.input("Param")[0] == op.input("Param")[0] == self.table_name
self.table_name
][0] ][0]
origin_param_var = self.origin_program.global_block().vars[ origin_param_var = self.origin_program.global_block().vars[
...@@ -1552,11 +1576,17 @@ class DistributeTranspiler(object): ...@@ -1552,11 +1576,17 @@ class DistributeTranspiler(object):
lod_level=var.lod_level, lod_level=var.lod_level,
persistable=persistable) persistable=persistable)
def _insert_split_op(self, program, orig_var, index, splited_vars): @staticmethod
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: def _get_splited_var_sections(splited_vars):
height_sections = [] height_sections = []
for v in splited_vars: for v in splited_vars:
height_sections.append(v.shape[0]) height_sections.append(v.shape[0])
return height_sections
def _insert_split_op(self, program, orig_var, index, splited_vars):
height_sections = self._get_splited_var_sections(splited_vars)
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_param_name = self.grad_name_to_param_name[orig_var.name] sparse_param_name = self.grad_name_to_param_name[orig_var.name]
if self._is_input_of_remote_sparse_update_op(sparse_param_name): if self._is_input_of_remote_sparse_update_op(sparse_param_name):
self.sparse_param_to_height_sections[ self.sparse_param_to_height_sections[
...@@ -1571,16 +1601,13 @@ class DistributeTranspiler(object): ...@@ -1571,16 +1601,13 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
}) })
elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = []
for v in splited_vars:
sections.append(v.shape[0])
program.global_block()._insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="split_byref", type="split_byref",
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
attrs={ attrs={
"sections": sections, "sections": height_sections,
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
}) })
else: else:
...@@ -2052,7 +2079,7 @@ class DistributeTranspiler(object): ...@@ -2052,7 +2079,7 @@ class DistributeTranspiler(object):
Get optimizer operators, parameters and gradients from origin_program Get optimizer operators, parameters and gradients from origin_program
Returns: Returns:
opt_ops (list): optimize operators. opt_ops (list): optimize operators.
params_grads (dict): paramter->gradient. params_grads (dict): parameter->gradient.
""" """
block = self.origin_program.global_block() block = self.origin_program.global_block()
opt_ops = [] opt_ops = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册