diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 56cf6693caf4529d6e157e6e9a0d5c27d05ee0c3..b4168f38949c7fcb057ec8c5c562d0529a6d9e48 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const { } std::vector> SelectedRows::Get( - std::vector keys, framework::Tensor* value) const { + const std::vector& keys, framework::Tensor* value) const { PADDLE_ENFORCE(value->IsInitialized(), "The value tensor should be initialized."); std::vector> non_keys_pair; - int64_t value_width = value_->numel() / value_->dims()[0]; - PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], - "output tensor should have the same shape with table " - "execpt the dims[0]."); - - for (size_t i = 0; i < keys.size(); ++i) { - int64_t index = Index(keys[i]); - if (index == -1) { - non_keys_pair.push_back(std::make_pair(keys[i], static_cast(i))); - } else { - framework::VisitDataType( - framework::ToDataType(value_->type()), - TensorCopyVisitor(value, i * value_width, *value_.get(), - index * value_width, value_width)); + if (keys.empty()) { + VLOG(3) << "keys is empty, please check data!"; + } else { + int64_t value_width = value_->numel() / value_->dims()[0]; + PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], + "output tensor should have the same shape with table " + "except the dims[0]."); + + for (size_t i = 0; i < keys.size(); ++i) { + int64_t index = Index(keys[i]); + if (index == -1) { + non_keys_pair.push_back( + std::make_pair(keys[i], static_cast(i))); + } else { + framework::VisitDataType( + framework::ToDataType(value_->type()), + TensorCopyVisitor(value, i * value_width, *value_.get(), + index * value_width, value_width)); + } } } return non_keys_pair; diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index c27c927ee751c4392840bfb71f4814991b23a8c9..c80b05eed9b1c50325316057a8afc26d5d52e82c 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -82,7 +82,7 @@ class SelectedRows { * @return a list of pair which contains the non-exists key and the index in * the value */ - std::vector> Get(std::vector keys, + std::vector> Get(const std::vector& keys, framework::Tensor* value) const; /* diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 58faead2bdf9a89749e08207d964836bbf5cb68e..361cc24b5ba11e2654f1282327730befaeca9f55 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase { program_(program), prefetch_ctx_(prefetch_ctx), req_id_(req_id) { - if (sync_mode_) { - request_.reset(new VariableResponse(scope, dev_ctx_, false)); - } else { - request_.reset(new VariableResponse(scope, dev_ctx_, true)); - } + // prefetch always create a new sub scope + request_.reset(new VariableResponse(scope, dev_ctx_, true)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary( method_id, &ctx_, request_.get(), &responder_, cq_, cq_, @@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase { std::string var_name = request_->OutVarname(); VLOG(3) << "RequestPrefetch " << var_name; auto var_desc = program_->Block(0).FindVar(var_name); - framework::Scope* local_scope = &scope_->NewScope(); + framework::Scope* local_scope = request_->GetMutableLocalScope(); auto* var = local_scope->FindVar(var_name); InitializeVariable(var, var_desc->GetType()); - executor_->RunPreparedContext(prefetch_ctx_, scope_); + executor_->RunPreparedContext(prefetch_ctx_, local_scope); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);