Commit 361cb0e0 authored by Qiao Longfei

lookup remote table can compile

Parent 7c3ce295
......@@ -68,6 +68,15 @@ class LookupRemoteTableOpMaker : public framework::OpProtoAndCheckerMaker {
"contains the ids to be looked up in W. "
"The last dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({}));
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>(
"epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
......@@ -98,7 +107,8 @@ or not. And the output only shares the LoD information with input Ids.
namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_remote_table, ops::LookupRemoteTableOp,
ops::EmptyGradOpMaker, ops::LookupRemoteTableOpMaker);
paddle::framework::EmptyGradOpMaker,
ops::LookupRemoteTableOpMaker);
REGISTER_OP_CPU_KERNEL(lookup_remote_table, ops::LookupRemoteTableKernel<float>,
ops::LookupRemoteTableKernel<double>);
......@@ -12,26 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <future> // NOLINT
#include <ostream>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
namespace distributed {
inline size_t GetSectionIndex(int64_t id,
const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) {
if (id < abs_sections[i]) {
return i - 1;
}
}
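For readers of this hunk: abs_sections is assumed to hold the running prefix sum of height_sections (starting at 0), so the loop above picks the shard whose absolute row range contains the id; the cut-off tail of the function presumably returns the last section when no earlier boundary matches. A minimal sketch under that assumption; the helper name ToAbsoluteSection and the sample numbers are illustrative, not taken from this commit:

#include <cstdint>
#include <vector>

// Illustrative helper: per-shard heights -> absolute starting rows.
// height_sections = {4, 6}  =>  abs_sections = {0, 4}
inline std::vector<int64_t> ToAbsoluteSection(
    const std::vector<int64_t>& height_sections) {
  std::vector<int64_t> abs_sections(height_sections.size(), 0);
  for (size_t i = 1; i < height_sections.size(); ++i) {
    abs_sections[i] = abs_sections[i - 1] + height_sections[i - 1];
  }
  return abs_sections;
}
// With abs_sections = {0, 4}: id 2 maps to section 0; id 7 maps to section 1 (rows 4..9).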
......@@ -62,9 +68,10 @@ inline std::vector<std::vector<int64_t>> SplitIds(
std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id);
auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]);
}
return splited_ids;
}
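To make the split above concrete, a small worked trace under assumed values (height_section = {4, 6}, hence abs_sections = {0, 4}; the ids are made up for illustration):

// id 0 -> section 0, local id 0 - 0 = 0
// id 1 -> section 0, local id 1 - 0 = 1
// id 5 -> section 1, local id 5 - 4 = 1
// id 7 -> section 1, local id 7 - 4 = 3
// splited_ids then groups these local ids per section; the exact ordering depends on
// how all_ids is collected in the unshown part of this function.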
inline void SplitIdsIntoMultipleVarsBySection(
......@@ -82,7 +89,7 @@ inline void SplitIdsIntoMultipleVarsBySection(
auto& ids = splited_ids[i];
if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({ids.size(), 1}), place);
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
}
}
......@@ -93,8 +100,8 @@ inline void MergeMultipleVarsIntoOnBySection(
const std::vector<std::string>& out_var_names,
const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size() + 1, "");
const framework::ExecutionContext& context, framework::Scope* scope) {
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size() + 1, "");
auto cpu_place = platform::CPUPlace();
......@@ -106,15 +113,15 @@ inline void MergeMultipleVarsIntoOnBySection(
id_to_offset[id_data[i]].push_back(i);
}
auto& out_tensor = scope->Var(out_name)->Get<framework::LoDTensor>();
auto* out_tensor_data = out_tensor.mutable_data<float>();
auto* out_tensor = scope->Var(out_name)->GetMutable<framework::LoDTensor>();
auto* out_tensor_data = out_tensor->mutable_data<float>(context.GetPlace());
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
auto& prefetch_out_var =
scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.mutable_data<float>();
const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, "");
......@@ -129,63 +136,64 @@ inline void MergeMultipleVarsIntoOnBySection(
for (auto& offset : offsets) {
// should support GPU tensor
memory::Copy(cpu_place, out_tensor_data + offset * row_numel, cpu_place,
out_var_data + i * grad_row_numel,
sizeof(T) * grad_row_numel);
out_var_data + i * row_numel, sizeof(float) * row_numel);
}
}
}
}
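For context on the merge step: id_to_offset maps each original id to every position it occupies in the input Ids tensor, so a single fetched row can be copied into all of those positions. A self-contained sketch of that map with assumed inputs (illustration only, mirroring the loop above; not part of this commit):

#include <cstdint>
#include <unordered_map>
#include <vector>

// Ids data (assumed): {7, 2, 7}  =>  id_to_offset = { 7: {0, 2}, 2: {1} }
// The row fetched for id 7 is later copied into output rows 0 and 2.
inline std::unordered_map<int64_t, std::vector<size_t>> BuildIdToOffset(
    const std::vector<int64_t>& ids) {
  std::unordered_map<int64_t, std::vector<size_t>> id_to_offset;
  for (size_t i = 0; i < ids.size(); ++i) {
    id_to_offset[ids[i]].push_back(i);
  }
  return id_to_offset;
}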
inline void prefetch(const std::string& table_name, const std::string& id_name,
const std::string& out_name,
const std::vector<std::string>& epmap,
const std::vector<int64_t>& height_section,
const framework::Scope& scope,
const platform::Place& place) const {
auto local_scope = scope.NewScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(Attr<int>("trainer_id"));
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (size_t i = 0; i < epmap.size(); ++i) {
in_var_names.push_back(id_name + "@" + epmap[i]);
out_var_names.push_back(out_name + "@" + epmap[i]);
}
auto splited_ids = SplitIds(id_name, height_section, local_scope);
SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section,
splited_ids, local_scope);
// create output var in local scope
for (auto& name : out_var_names) {
local_scope.Var(name)->GetMutable<framework::LoDTensor>();
}
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(local_scope, ins[i])) {
VLOG(30) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
} else {
VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
height_section, plited_ids, scope)
scope.DeleteScope(local_scope);
}
// inline void prefetch(const std::string& table_name, const std::string&
// id_name,
// const std::string& out_name,
// const std::vector<std::string>& epmap,
// const std::vector<int64_t>& height_section,
// const framework::Scope& scope,
// const platform::Place& place) {
// auto& local_scope = scope.NewScope();
//
// platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// auto& ctx = *pool.Get(place);
//
// distributed::RPCClient* rpc_client =
// distributed::RPCClient::GetInstance<RPCCLIENT_T>(Attr<int>("trainer_id"));
//
// std::vector<std::string> in_var_names;
// std::vector<std::string> out_var_names;
// for (size_t i = 0; i < epmap.size(); ++i) {
// in_var_names.push_back(id_name + "@" + epmap[i]);
// out_var_names.push_back(out_name + "@" + epmap[i]);
// }
//
// auto splited_ids = SplitIds(id_name, height_section, &local_scope);
// SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_section,
// splited_ids, &local_scope);
//
// // create output var in local scope
// for (auto& name : out_var_names) {
// local_scope.Var(name)->GetMutable<framework::LoDTensor>();
// }
//
// std::vector<distributed::VarHandlePtr> rets;
// for (size_t i = 0; i < in_var_names.size(); i++) {
// if (NeedSend(local_scope, in_var_names[i])) {
// VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i] << " to
// get "
// << out_var_names[i] << " back";
// rets.push_back(rpc_client->AsyncPrefetchVar(
// epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
// } else {
// VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
// }
// }
// for (size_t i = 0; i < rets.size(); i++) {
// PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
// }
//
// MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
// height_section, splited_ids, &local_scope);
//
// scope.DeleteScope(&local_scope);
//}
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
......@@ -198,54 +206,70 @@ template <typename T>
class LookupRemoteTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
std::string id_name = context.Inputs("Ids").front();
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
std::string out_name = context.Outputs("Out").front();
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
std::string table_name = context.Inputs("W").front();
auto* table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t* ids = const_cast<int64_t*>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto* table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
}
} else if (table_var->IsType<SelectedRows>()) {
const auto& table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto* table = table_t.value().data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto& local_scope = context.scope().NewScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(context.GetPlace());
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (size_t i = 0; i < epmap.size(); ++i) {
in_var_names.push_back(id_name + "@" + epmap[i]);
out_var_names.push_back(out_name + "@" + epmap[i]);
}
auto splited_ids = SplitIds(id_name, height_sections, &local_scope);
SplitIdsIntoMultipleVarsBySection(id_name, in_var_names, height_sections,
splited_ids, &local_scope);
// create output var in local scope
for (auto& name : out_var_names) {
local_scope.Var(name)->GetMutable<framework::LoDTensor>();
}
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(local_scope, in_var_names[i])) {
VLOG(30) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i]));
} else {
VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
MergeMultipleVarsIntoOnBySection(id_name, out_name, out_var_names,
height_sections, splited_ids, context,
&local_scope);
context.scope().DeleteScope(&local_scope);
}
};
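A short recap of the control flow in Compute above, written as comments; the endpoint strings are just the examples used elsewhere in this file:

// 1. SplitIds: map every id in "Ids" to a server section and a section-local offset.
// 2. SplitIdsIntoMultipleVarsBySection: write each section's local ids into a
//    per-endpoint variable such as "Ids@127.0.0.1:6164" inside a temporary local scope.
// 3. AsyncPrefetchVar: request the corresponding rows from each endpoint; the replies
//    arrive in "Out@<endpoint>" variables, and every returned handle is Wait()-ed on.
// 4. MergeMultipleVarsIntoOnBySection: copy the fetched rows into "Out" at the positions
//    of the original ids, then delete the local scope.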
} // namespace distributed
} // namespace operators
} // namespace paddle