Commit 33a004a7 authored by tangwei12

fix numel nce and prefetch

Parent 57557f67
@@ -114,9 +114,15 @@ static void MergeMultipleVarsIntoOneBySection(
     id_to_offset[ids_vector[i]].push_back(i);
   }
 
-  auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
-  auto* out_tensor =
-      scope.FindVar(out_name)->GetMutable<framework::LoDTensor>();
+  auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
+  auto* out_tensor =
+      scope->FindVar(out_name)->GetMutable<framework::LoDTensor>();
+
+  PADDLE_ENFORCE_GT(
+      out_tensor->numel(), 0,
+      "When calling this method, the Tensor's numel must larger than zero. "
+      "Please check Tensor::Resize has been called first.");
+
   auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
 
   bool is_on_cpu_place = true;
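Note: the added PADDLE_ENFORCE_GT makes the contract of MergeMultipleVarsIntoOneBySection explicit: the output variable must already be resized to a non-empty shape before this routine asks for its buffer, since mutable_data<float>() allocates according to the tensor's current dims. A minimal caller-side sketch of that contract (the names num_rows and row_width are illustrative, not taken from this commit):

  // Hypothetical caller: give the output tensor a real shape first.
  auto *out_tensor =
      scope->Var(out_name)->GetMutable<framework::LoDTensor>();
  out_tensor->Resize(framework::make_ddim({num_rows, row_width}));  // numel > 0
  // The merge routine can now call out_tensor->mutable_data<float>(place)
  // without tripping the new check.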
@@ -166,11 +166,12 @@ class NCEKernel : public framework::OpKernel<T> {
       std::set<T> st(labels.begin(), labels.end());
       labels.assign(st.begin(), st.end());
 
-      auto &local_scope = context.scope().NewScope();
+      framework::Scope &local_scope = context.scope().NewScope();
+
       auto height_sections = context.Attr<std::vector<int>>("height_sections");
       auto table_names = context.Attr<std::vector<std::string>>("table_names");
 
-      auto *ids = local_scope.Var("Ids@Local");
+      auto *ids = local_scope.Var("Ids@Prefetch");
       auto *x_tensor = ids->GetMutable<framework::LoDTensor>();
       x_tensor->mutable_data<int64_t>(
           framework::make_ddim({static_cast<int64_t>(labels.size()), 1}),
@@ -179,12 +180,18 @@ class NCEKernel : public framework::OpKernel<T> {
       std::memcpy(x_tensor->data<int64_t>(), labels.data(),
                   labels.size() * sizeof(int64_t));
 
-      local_scope.Var("Weight@Local");
+      std::vector<int> w_dims = paddle::framework::vectorize2int(
+          context.Input<Tensor>("Weight")->dims());
+      w_dims[0] = static_cast<int>(labels.size());
+
+      auto *w_tensor = local_scope.Var("Weight@Prefetch")
+                           ->GetMutable<framework::LoDTensor>();
+      w_tensor->Resize(framework::make_ddim(w_dims));
 
 #ifdef PADDLE_WITH_DISTRIBUTE
-      operators::distributed::prefetch("Ids@Local", "Weight@Local", table_names,
-                                       epmap, height_sections, context,
-                                       &local_scope);
+      operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch",
+                                       table_names, epmap, height_sections,
+                                       context, local_scope);
 #else
       PADDLE_THROW(
           "paddle is not compiled with distribute support, can not do "
@@ -192,7 +199,7 @@ class NCEKernel : public framework::OpKernel<T> {
 #endif
 
       auto weight_mat = EigenMatrix<T>::From(
-          (local_scope.Var("Weight@Local")->Get<framework::LoDTensor>()));
+          (local_scope.Var("Weight@Prefetch")->Get<framework::LoDTensor>()));
       for (int64_t i = 0; i < sample_labels->numel(); ++i) {
         std::vector<int64_t>::iterator it =
             std::find(labels.begin(), labels.end(), sample_labels_data[i]);
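Note: the NCE kernel change is the caller-side half of the same fix. The local-scope variables are renamed from *@Local to *@Prefetch, and Weight@Prefetch is now created and resized (its first dimension set to the number of distinct labels) before operators::distributed::prefetch runs, so that by the time the fetched rows are written into it, its numel is already positive. Roughly, the resulting flow under the shapes used in this diff:

  // Ids@Prefetch:    [labels.size(), 1]          int64 ids to fetch remotely
  // Weight@Prefetch: [labels.size(), w_dims[1]]  rows filled in by prefetch
  w_tensor->Resize(framework::make_ddim(w_dims));  // satisfies the numel check
  operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch",
                                   table_names, epmap, height_sections,
                                   context, local_scope);
  // Afterwards Weight@Prefetch is read back via EigenMatrix<T>::From(...).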