From 25f47fc0afd3cbfe821fcbc82d978ff243ae29b4 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 28 May 2018 15:46:08 +0800 Subject: [PATCH] fix prefetch bugs, optimize code --- paddle/fluid/framework/selected_rows.cc | 35 +++++++++++--------- paddle/fluid/framework/selected_rows.h | 2 +- paddle/fluid/operators/detail/grpc_server.cc | 11 +++--- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 56cf6693c..b4168f389 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const { } std::vector> SelectedRows::Get( - std::vector keys, framework::Tensor* value) const { + const std::vector& keys, framework::Tensor* value) const { PADDLE_ENFORCE(value->IsInitialized(), "The value tensor should be initialized."); std::vector> non_keys_pair; - int64_t value_width = value_->numel() / value_->dims()[0]; - PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], - "output tensor should have the same shape with table " - "execpt the dims[0]."); - - for (size_t i = 0; i < keys.size(); ++i) { - int64_t index = Index(keys[i]); - if (index == -1) { - non_keys_pair.push_back(std::make_pair(keys[i], static_cast(i))); - } else { - framework::VisitDataType( - framework::ToDataType(value_->type()), - TensorCopyVisitor(value, i * value_width, *value_.get(), - index * value_width, value_width)); + if (keys.empty()) { + VLOG(3) << "keys is empty, please check data!"; + } else { + int64_t value_width = value_->numel() / value_->dims()[0]; + PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], + "output tensor should have the same shape with table " + "except the dims[0]."); + + for (size_t i = 0; i < keys.size(); ++i) { + int64_t index = Index(keys[i]); + if (index == -1) { + non_keys_pair.push_back( + std::make_pair(keys[i], static_cast(i))); + } else { + framework::VisitDataType( + framework::ToDataType(value_->type()), + TensorCopyVisitor(value, i * value_width, *value_.get(), + index * value_width, value_width)); + } } } return non_keys_pair; diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index c27c927ee..c80b05eed 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -82,7 +82,7 @@ class SelectedRows { * @return a list of pair which contains the non-exists key and the index in * the value */ - std::vector> Get(std::vector keys, + std::vector> Get(const std::vector& keys, framework::Tensor* value) const; /* diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 58faead2b..361cc24b5 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase { program_(program), prefetch_ctx_(prefetch_ctx), req_id_(req_id) { - if (sync_mode_) { - request_.reset(new VariableResponse(scope, dev_ctx_, false)); - } else { - request_.reset(new VariableResponse(scope, dev_ctx_, true)); - } + // prefetch always create a new sub scope + request_.reset(new VariableResponse(scope, dev_ctx_, true)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, @@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase { std::string var_name = request_->OutVarname(); VLOG(3) << "RequestPrefetch " << var_name; auto var_desc = program_->Block(0).FindVar(var_name); - framework::Scope* local_scope = &scope_->NewScope(); + framework::Scope* local_scope = request_->GetMutableLocalScope(); auto* var = local_scope->FindVar(var_name); InitializeVariable(var, var_desc->GetType()); - executor_->RunPreparedContext(prefetch_ctx_, scope_); + executor_->RunPreparedContext(prefetch_ctx_, local_scope); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); -- GitLab