提交 361cb0e0 编写于 作者: Q Qiao Longfei

lookup remote table can compile

上级 7c3ce295
...@@ -68,6 +68,15 @@ class LookupRemoteTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,6 +68,15 @@ class LookupRemoteTableOpMaker : public framework::OpProtoAndCheckerMaker {
"contains the ids to be looked up in W. " "contains the ids to be looked up in W. "
"The last dimension size must be 1."); "The last dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W."); AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({}));
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddAttr<int64_t>("padding_idx", AddAttr<int64_t>("padding_idx",
"(int64, default -1) " "(int64, default -1) "
"If the value is -1, it makes no effect to lookup. " "If the value is -1, it makes no effect to lookup. "
...@@ -98,7 +107,8 @@ or not. And the output only shares the LoD information with input Ids. ...@@ -98,7 +107,8 @@ or not. And the output only shares the LoD information with input Ids.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_remote_table, ops::LookupRemoteTableOp, REGISTER_OPERATOR(lookup_remote_table, ops::LookupRemoteTableOp,
ops::EmptyGradOpMaker, ops::LookupRemoteTableOpMaker); paddle::framework::EmptyGradOpMaker,
ops::LookupRemoteTableOpMaker);
REGISTER_OP_CPU_KERNEL(lookup_remote_table, ops::LookupRemoteTableKernel<float>, REGISTER_OP_CPU_KERNEL(lookup_remote_table, ops::LookupRemoteTableKernel<float>,
ops::LookupRemoteTableKernel<double>); ops::LookupRemoteTableKernel<double>);
...@@ -12,26 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,26 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <future> // NOLINT #include <future> // NOLINT
#include <ostream> #include <ostream>
#include <set> #include <set>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" #include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed {
inline size_t GetSectionIndex(int64_t id, inline size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) { const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) { for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) { if (id < abs_sections[i]) {
return i - 1; return i - 1;
} }
} }
...@@ -62,9 +68,10 @@ inline std::vector<std::vector<int64_t>> SplitIds( ...@@ -62,9 +68,10 @@ inline std::vector<std::vector<int64_t>> SplitIds(
std::vector<std::vector<int64_t>> splited_ids; std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1); splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) { for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id); auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]); splited_ids[section_index].push_back(id - abs_sections[section_index]);
} }
return splited_ids;
} }
inline void SplitIdsIntoMultipleVarsBySection( inline void SplitIdsIntoMultipleVarsBySection(
...@@ -82,7 +89,7 @@ inline void SplitIdsIntoMultipleVarsBySection( ...@@ -82,7 +89,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
auto& ids = splited_ids[i]; auto& ids = splited_ids[i];
if (!ids.empty()) { if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>( auto* id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({ids.size(), 1}), place); framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size()); memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
} }
} }
...@@ -93,8 +100,8 @@ inline void MergeMultipleVarsIntoOnBySection( ...@@ -93,8 +100,8 @@ inline void MergeMultipleVarsIntoOnBySection(
const std::vector<std::string>& out_var_names, const std::vector<std::string>& out_var_names,
const std::vector<int64_t>& height_section, const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids, const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) { const framework::ExecutionContext& context, framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, ""); PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size() + 1, "");
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
...@@ -106,15 +113,15 @@ inline void MergeMultipleVarsIntoOnBySection( ...@@ -106,15 +113,15 @@ inline void MergeMultipleVarsIntoOnBySection(
id_to_offset[id_data[i]].push_back(i); id_to_offset[id_data[i]].push_back(i);
} }
auto& out_tensor = scope->Var(out_name)->Get<framework::LoDTensor>(); auto* out_tensor = scope->Var(out_name)->GetMutable<framework::LoDTensor>();
auto* out_tensor_data = out_tensor.mutable_data<float>(); auto* out_tensor_data = out_tensor->mutable_data<float>(context.GetPlace());
for (size_t section_idx = 0; section_idx < out_var_names.size(); for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) { ++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx]; auto& ids_in_this_section = splited_ids[section_idx];
auto& prefetch_out_var = auto& prefetch_out_var =
scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>(); scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.mutable_data<float>(); const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims(); auto& dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, ""); PADDLE_ENFORCE_EQ(dims.size(), 2, "");
...@@ -129,63 +136,64 @@ inline void MergeMultipleVarsIntoOnBySection( ...@@ -129,63 +136,64 @@ inline void MergeMultipleVarsIntoOnBySection(
for (auto& offset : offsets) { for (auto& offset : offsets) {
// should support GPU tensor // should support GPU tensor
memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place, memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place,
out_var_data + i * grad_row_numel, out_var_data + i * row_numel, sizeof(float) * row_numel);
sizeof(T) * grad_row_numel);
} }
} }
} }
} }
inline void prefetch(const std::string& table_name, const std::string& id_name, // inline void prefetch(const std::string& table_name, const std::string&
const std::string& out_name, // id_name,
const std::vector<std::string>& epmap, // const std::string& out_name,
const std::vector<int64_t>& height_section, // const std::vector<std::string>& epmap,
const framework::Scope& scope, // const std::vector<int64_t>& height_section,
const platform::Place& place) const { // const framework::Scope& scope,
auto local_scope = scope.NewScope(); // const platform::Place& place) {
// auto& local_scope = scope.NewScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); //
auto& ctx = *pool.Get(place); // platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client = //
distributed::RPCClient::GetInstance<RPCCLIENT_T>(Attr<int>("trainer_id")); // distributed::RPCClient* rpc_client =
// distributed::RPCClient::GetInstance<RPCCLIENT_T>(Attr<int>("trainer_id"));
std::vector<std::string> in_var_names; //
std::vector<std::string> out_var_names; // std::vector<std::string> in_var_names;
for (size_t i = 0; i < epmap.size(); ++i) { // std::vector<std::string> out_var_names;
in_var_names.push_back(id_name + "@" + epmap[i]); // for (size_t i = 0; i < epmap.size(); ++i) {
out_var_names.push_back(out_name + "@" + epmap[i]); // in_var_names.push_back(id_name + "@" + epmap[i]);
} // out_var_names.push_back(out_name + "@" + epmap[i]);
// }
auto splited_ids = SplitIds(id_name, height_section, local_scope); //
SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section, // auto splited_ids = SplitIds(id_name, height_section, &local_scope);
splited_ids, local_scope); // SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section,
// splited_ids, &local_scope);
// create output var in local scope //
for (auto& name : out_var_names) { // // create output var in local scope
local_scope.Var(name)->GetMutable<framework::LoDTensor>(); // for (auto& name : out_var_names) {
} // local_scope.Var(name)->GetMutable<framework::LoDTensor>();
// }
std::vector<distributed::VarHandlePtr> rets; //
for (size_t i = 0; i < ins.size(); i++) { // std::vector<distributed::VarHandlePtr> rets;
if (NeedSend(local_scope, ins[i])) { // for (size_t i = 0; i < in_var_names.size(); i++) {
VLOG(30) << "sending " << ins[i] << " to " << epmap[i] << " to get " // if (NeedSend(local_scope, in_var_names[i])) {
<< outs[i] << " back"; // VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] << " to
rets.push_back(rpc_client->AsyncPrefetchVar( // get "
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i])); // << out_var_names[i] << " back";
} else { // rets.push_back(rpc_client->AsyncPrefetchVar(
VLOG(30) << "don't send no-initialied variable: " << out_var_names[i]; // epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
} // } else {
} // VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
for (size_t i = 0; i < rets.size(); i++) { // }
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); // }
} // for (size_t i = 0; i < rets.size(); i++) {
// PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names, // }
height_section, plited_ids, scope) //
// MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
scope.DeleteScope(local_scope); // height_section, splited_ids, &local_scope);
} //
// scope.DeleteScope(&local_scope);
//}
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
...@@ -198,54 +206,70 @@ template <typename T> ...@@ -198,54 +206,70 @@ template <typename T>
class LookupRemoteTableKernel : public framework::OpKernel<T> { class LookupRemoteTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor std::string id_name = context.Inputs("Ids").front();
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
std::string out_name = context.Outputs("Out").front();
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
std::string table_name = context.Inputs("W").front();
auto* table_var = context.InputVar("W"); auto* table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t* ids = const_cast<int64_t*>(ids_t->data<int64_t>()); int64_t* ids = const_cast<int64_t*>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel(); int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) { auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto* table_t = context.Input<LoDTensor>("W"); auto height_sections =
int64_t row_number = table_t->dims()[0]; context.Attr<std::vector<int64_t>>("height_sections");
int64_t row_width = table_t->dims()[1];
auto& local_scope = context.scope().NewScope();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace()); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) { distributed::RPCClient* rpc_client =
memset(output + i * row_width, 0, row_width * sizeof(T)); distributed::RPCClient::GetInstance<RPCCLIENT_T>(
} else { context.Attr<int>("trainer_id"));
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); std::vector<std::string> in_var_names;
memcpy(output + i * row_width, table + ids[i] * row_width, std::vector<std::string> out_var_names;
row_width * sizeof(T)); for (size_t i = 0; i < epmap.size(); ++i) {
} in_var_names.push_back(id_name + "@" + epmap[i]);
} out_var_names.push_back(out_name + "@" + epmap[i]);
} else if (table_var->IsType<SelectedRows>()) { }
const auto& table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1]; auto splited_ids = SplitIds(id_name, height_sections, &local_scope);
const auto* table = table_t.value().data<T>(); SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_sections,
auto* output = output_t->mutable_data<T>(context.GetPlace()); splited_ids, &local_scope);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); // create output var in local scope
for (int64_t i = 0; i < ids_numel; ++i) { for (auto& name : out_var_names) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) { local_scope.Var(name)->GetMutable<framework::LoDTensor>();
memset(output + i * row_width, 0, row_width * sizeof(T)); }
} else {
PADDLE_ENFORCE_GE(ids[i], 0); std::vector<distributed::VarHandlePtr> rets;
auto id_index = table_t.Index(ids[i]); for (size_t i = 0; i < in_var_names.size(); i++) {
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); if (NeedSend(local_scope, in_var_names[i])) {
blas.VCOPY(row_width, table + id_index * row_width, VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i]
output + i * row_width); << " to get " << out_var_names[i] << " back";
} rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
} else {
VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
} }
} }
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
height_sections, splited_ids, context,
&local_scope);
context.scope().DeleteScope(&local_scope);
} }
}; };
} // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册