提交 3691a46f 编写于 作者: Q Qiao Longfei

improve communicator

上级 02425b2f
...@@ -27,7 +27,7 @@ limitations under the License. */ ...@@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) { void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>(); var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) { } else if (var_type == proto::VarType::SELECTED_ROWS) {
...@@ -37,7 +37,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -37,7 +37,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope*>>(); var->GetMutable<std::vector<framework::Scope *>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
...@@ -56,5 +56,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -56,5 +56,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var_type); var_type);
} }
} }
void CopyVariable(const Variable &src_var, Variable *dst_var) {
// only support cpu now
auto cpu_place = platform::CPUPlace();
if (src_var.IsType<framework::LoDTensor>()) {
auto *tmp_grad_tensor = dst_var->GetMutable<framework::LoDTensor>();
auto &src_tensor = src_var.Get<framework::LoDTensor>();
tmp_grad_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor);
} else if (src_var.IsType<framework::SelectedRows>()) {
auto &src_slr = src_var.Get<framework::SelectedRows>();
auto *tmp_grad_slr = dst_var->GetMutable<framework::SelectedRows>();
tmp_grad_slr->set_rows(src_slr.rows());
tmp_grad_slr->set_height(src_slr.height());
auto &src_t = src_slr.value();
auto *dst_t = tmp_grad_slr->mutable_value();
framework::TensorCopy(src_t, cpu_place, dst_t);
} else {
PADDLE_THROW("unknown var type to copy");
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable *var, proto::VarType::Type var_type); void InitializeVariable(Variable* var, proto::VarType::Type var_type);
void CopyVariable(const Variable& src_var, Variable* dst_var);
} }
} }
...@@ -54,6 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) ...@@ -54,6 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor)
if(WITH_GPU) if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor ${RPC_DEPS} DEPS sendrecvop_rpc executor ${RPC_DEPS}
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
namespace distributed {
static void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
auto *out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
auto *out_ptr = out_t->mutable_data<float>(
var0->Get<framework::LoDTensor>().dims(), cpu_place);
auto numel = out_t->numel();
for (auto i = 0; i < numel; ++i) {
out_ptr[i] = 0;
for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims");
out_ptr[i] += var_t.data<float>()[i];
}
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
std::vector<const paddle::framework::SelectedRows *> inputs;
inputs.reserve(vars.size());
for (auto &var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
merge_add;
auto dev_ctx = paddle::platform::CPUDeviceContext();
merge_add(dev_ctx, inputs, out_slr, false);
} else {
PADDLE_THROW("unsupported var type!");
}
}
void Communicator::SendThread() {
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second;
std::vector<std::shared_ptr<Variable>> vars;
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
vars.push_back(var_queue->Pop());
merged_var_num++;
}
MergeVars(var_name, vars, send_scope_.get());
auto send_functor = distributed::ParameterSend<float>();
// send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
// send_scope_, true);
}
}
void Communicator::RecvThread() {
// parallel run recv graph
for (auto &iter : recv_varname_to_ctx_) {
auto &var_name = iter.first;
VLOG(3) << "recv var " << iter.first;
auto recv_functor = distributed::ParameterRecv<float>();
// recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
}
}
void Communicator::Send(const std::string &var_name,
const framework::Scope &scope) {
// push var into send queue by var_name
auto *grad_var = scope.FindVar(var_name);
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
auto tmp_grad_var = std::make_shared<Variable>();
framework::CopyVariable(*grad_var, tmp_grad_var.get());
send_varname_to_queue_[var_name]->Push(tmp_grad_var);
}
void Communicator::Start() {
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <deque>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0.");
}
bool Push(const T& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.push_back(elem);
recv_cv_.notify_one();
return true;
}
bool Push(T&& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.emplace_back(std::move(elem));
recv_cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
recv_cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable recv_cv_;
std::condition_variable send_cv_;
};
class Communicator {
public:
Communicator(
const std::unordered_map<std::string, RpcContext>& send_varname_to_ctx,
const std::unordered_map<std::string, RpcContext>& recv_varname_to_ctx,
Scope* recv_scope)
: send_varname_to_ctx_(send_varname_to_ctx),
recv_varname_to_ctx_(recv_varname_to_ctx),
recv_scope_(recv_scope) {
// get all send information from graph, build vars_to_send
send_scope_.reset(new Scope());
for (auto& iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
}
}
~Communicator() {}
void Start();
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
private:
void SendThread();
void RecvThread();
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
std::unordered_map<std::string, RpcContext> send_varname_to_ctx_;
std::unordered_map<std::string, RpcContext> recv_varname_to_ctx_;
std::unique_ptr<std::thread> send_thread_;
std::unique_ptr<std::thread> recv_thread_;
Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
};
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -39,7 +39,7 @@ using DDim = framework::DDim; ...@@ -39,7 +39,7 @@ using DDim = framework::DDim;
static std::vector<std::vector<int64_t>> SplitIds( static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector, const std::vector<int64_t>& ids_vector,
const std::vector<int64_t>& height_section, framework::Scope* scope) { const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids; std::set<int64_t> all_ids;
for (auto id : ids_vector) { for (auto id : ids_vector) {
all_ids.insert(id); all_ids.insert(id);
...@@ -203,7 +203,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -203,7 +203,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
#endif #endif
} }
auto splited_ids = SplitIds(ids_vector, height_sections, local_scope); auto splited_ids = SplitIds(ids_vector, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope); local_scope);
......
...@@ -73,7 +73,7 @@ void ParameterRecv<T>::operator()(const std::string &var_name, ...@@ -73,7 +73,7 @@ void ParameterRecv<T>::operator()(const std::string &var_name,
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
} }
} else { } else {
PADDLE_THROW("unsupported var type to send!"); PADDLE_THROW("unsupported var type to recv!");
} }
// concat recved tensor into one var // concat recved tensor into one var
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -14,40 +14,20 @@ limitations under the License. */ ...@@ -14,40 +14,20 @@ limitations under the License. */
#pragma once #pragma once
#include <cstdint> #include <string>
#include <cstring>
#include <memory>
#include <typeindex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace operators {
namespace framework { namespace distributed {
class Communicator { struct RpcContext {
public: std::string var_name;
Communicator() {} std::vector<std::string> splited_var_names;
~Communicator() {} std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
// send grad
void send() {}
void receive() {}
void prefetch() {}
void wait() {}
private:
std::unique_ptr<std::thread> communicate_thread_;
}; };
} // namespace framework } // namespace distributed
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -95,7 +95,7 @@ struct MergeAdd { ...@@ -95,7 +95,7 @@ struct MergeAdd {
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor // out = selected_rows_in / tensor
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct UpdateToTensor { struct UpdateToTensor {
void operator()(const DeviceContext& context, const ScatterOps& op, void operator()(const DeviceContext& context, const ScatterOps& op,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册