From 33a004a779e8c4acb19ab13b641cc16d3827a582 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 10 Dec 2018 20:36:49 +0800 Subject: [PATCH] fix numel nce and prefetch --- .../distributed/parameter_prefetch.cc | 10 +++++++-- paddle/fluid/operators/nce_op.h | 21 ++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index 4cdeae81a..aebf6376d 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -114,9 +114,15 @@ static void MergeMultipleVarsIntoOneBySection( id_to_offset[ids_vector[i]].push_back(i); } - auto& id_tensor = scope.FindVar(id_name)->Get(); + auto& id_tensor = scope->FindVar(id_name)->Get(); auto* out_tensor = - scope.FindVar(out_name)->GetMutable(); + scope->FindVar(out_name)->GetMutable(); + + PADDLE_ENFORCE_GT( + out_tensor->numel(), 0, + "When calling this method, the Tensor's numel must larger than zero. " + "Please check Tensor::Resize has been called first."); + auto* out_tensor_data = out_tensor->mutable_data(id_tensor.place()); bool is_on_cpu_place = true; diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 862064be1..99a3baba9 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -166,11 +166,12 @@ class NCEKernel : public framework::OpKernel { std::set st(labels.begin(), labels.end()); labels.assign(st.begin(), st.end()); - auto &local_scope = context.scope().NewScope(); + framework::Scope &local_scope = context.scope().NewScope(); + auto height_sections = context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); - auto *ids = local_scope.Var("Ids@Local"); + auto *ids = local_scope.Var("Ids@Prefetch"); auto *x_tensor = ids->GetMutable(); x_tensor->mutable_data( framework::make_ddim({static_cast(labels.size()), 1}), @@ -179,12 +180,18 @@ class NCEKernel : public framework::OpKernel { std::memcpy(x_tensor->data(), labels.data(), labels.size() * sizeof(int64_t)); - local_scope.Var("Weight@Local"); + std::vector w_dims = paddle::framework::vectorize2int( + context.Input("Weight")->dims()); + w_dims[0] = static_cast(labels.size()); + + auto *w_tensor = local_scope.Var("Weight@Prefetch") + ->GetMutable(); + w_tensor->Resize(framework::make_ddim(w_dims)); #ifdef PADDLE_WITH_DISTRIBUTE - operators::distributed::prefetch("Ids@Local", "Weight@Local", table_names, - epmap, height_sections, context, - &local_scope); + operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch", + table_names, epmap, height_sections, + context, local_scope); #else PADDLE_THROW( "paddle is not compiled with distribute support, can not do " @@ -192,7 +199,7 @@ class NCEKernel : public framework::OpKernel { #endif auto weight_mat = EigenMatrix::From( - (local_scope.Var("Weight@Local")->Get())); + (local_scope.Var("Weight@Prefetch")->Get())); for (int64_t i = 0; i < sample_labels->numel(); ++i) { std::vector::iterator it = std::find(labels.begin(), labels.end(), sample_labels_data[i]); -- GitLab