Commit 25f47fc0, authored by qiaolongfei

fix prefetch bugs, optimize code

Parent bf869e45
@@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const {
 }
 std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
-    std::vector<int64_t> keys, framework::Tensor* value) const {
+    const std::vector<int64_t>& keys, framework::Tensor* value) const {
   PADDLE_ENFORCE(value->IsInitialized(),
                  "The value tensor should be initialized.");
   std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
-  int64_t value_width = value_->numel() / value_->dims()[0];
-  PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
-                    "output tensor should have the same shape with table "
-                    "execpt the dims[0].");
-  for (size_t i = 0; i < keys.size(); ++i) {
-    int64_t index = Index(keys[i]);
-    if (index == -1) {
-      non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
-    } else {
-      framework::VisitDataType(
-          framework::ToDataType(value_->type()),
-          TensorCopyVisitor(value, i * value_width, *value_.get(),
-                            index * value_width, value_width));
+  if (keys.empty()) {
+    VLOG(3) << "keys is empty, please check data!";
+  } else {
+    int64_t value_width = value_->numel() / value_->dims()[0];
+    PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
+                      "output tensor should have the same shape with table "
+                      "except the dims[0].");
+    for (size_t i = 0; i < keys.size(); ++i) {
+      int64_t index = Index(keys[i]);
+      if (index == -1) {
+        non_keys_pair.push_back(
+            std::make_pair(keys[i], static_cast<int64_t>(i)));
+      } else {
+        framework::VisitDataType(
+            framework::ToDataType(value_->type()),
+            TensorCopyVisitor(value, i * value_width, *value_.get(),
+                              index * value_width, value_width));
+      }
     }
   }
   return non_keys_pair;
...
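The revised `SelectedRows::Get` now takes the keys by const reference, tolerates an empty key list (logging instead of dividing by `dims()[0]` of an unused width), and still reports every key it cannot find as a (key, position-in-output) pair. The standalone sketch below mirrors that contract with a plain `std::map` standing in for the row index and a flat `std::vector<float>` standing in for the value tensor; it is only an illustration of the lookup pattern under those assumptions, not PaddlePaddle code, and `ToyTable` and its members are invented names.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

// Toy stand-in for SelectedRows: rows_ maps a sparse key to a dense row index.
struct ToyTable {
  std::map<int64_t, int64_t> rows_;
  std::vector<float> data_;    // row-major, value_width_ floats per row
  int64_t value_width_ = 2;

  // Mirrors the contract of SelectedRows::Get: copy found rows into *out
  // (one row per key, in key order) and return the (key, position) pairs
  // that were not found. An empty key list is reported, not an error.
  std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
                                               std::vector<float>* out) const {
    std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
    if (keys.empty()) {
      std::cerr << "keys is empty, please check data!\n";
      return non_keys_pair;
    }
    out->assign(keys.size() * value_width_, 0.f);
    for (size_t i = 0; i < keys.size(); ++i) {
      auto it = rows_.find(keys[i]);
      if (it == rows_.end()) {
        non_keys_pair.emplace_back(keys[i], static_cast<int64_t>(i));
      } else {
        std::copy(data_.begin() + it->second * value_width_,
                  data_.begin() + (it->second + 1) * value_width_,
                  out->begin() + i * value_width_);
      }
    }
    return non_keys_pair;
  }
};

int main() {
  ToyTable table;
  table.rows_ = {{7, 0}, {42, 1}};
  table.data_ = {1.f, 2.f, 3.f, 4.f};

  std::vector<float> out;
  auto missing = table.Get({42, 9, 7}, &out);  // key 9 is not in the table
  for (auto& p : missing) {
    std::cout << "missing key " << p.first << " at position " << p.second << "\n";
  }
  return 0;
}
```

The caller can then fetch the missing keys from elsewhere (e.g. a remote table) and scatter them into the output at the reported positions, which is the pattern the prefetch path relies on.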
@@ -82,7 +82,7 @@ class SelectedRows {
    * @return a list of pair which contains the non-exists key and the index in
    * the value
    */
-  std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
+  std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
                                                framework::Tensor* value) const;
   /*
...
@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase {
         program_(program),
         prefetch_ctx_(prefetch_ctx),
         req_id_(req_id) {
-    if (sync_mode_) {
-      request_.reset(new VariableResponse(scope, dev_ctx_, false));
-    } else {
-      request_.reset(new VariableResponse(scope, dev_ctx_, true));
-    }
+    // prefetch always create a new sub scope
+    request_.reset(new VariableResponse(scope, dev_ctx_, true));
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -198,10 +195,10 @@
     std::string var_name = request_->OutVarname();
     VLOG(3) << "RequestPrefetch " << var_name;
     auto var_desc = program_->Block(0).FindVar(var_name);
-    framework::Scope* local_scope = &scope_->NewScope();
+    framework::Scope* local_scope = request_->GetMutableLocalScope();
     auto* var = local_scope->FindVar(var_name);
     InitializeVariable(var, var_desc->GetType());
-    executor_->RunPreparedContext(prefetch_ctx_, scope_);
+    executor_->RunPreparedContext(prefetch_ctx_, local_scope);
     SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
...
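The two hunks above make RequestPrefetch always deserialize into a per-request sub scope and then run the prepared prefetch program in that same local scope (previously the executor ran against the shared server scope, so the output variable initialized in the local scope was not the one the program wrote to, and concurrent requests could interfere). The sketch below is a simplified model of that scope hierarchy, not Paddle's framework::Scope; ToyScope and the variable names are invented for illustration. It shows why a per-request child scope isolates each request's output while shared parameters stay reachable through the parent chain.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified model of a scope hierarchy: each request gets its own child
// scope; lookup falls back to the parent so shared state stays visible,
// while anything the request writes stays local to it.
class ToyScope {
 public:
  explicit ToyScope(ToyScope* parent = nullptr) : parent_(parent) {}

  // Create (or reuse) a variable in *this* scope.
  std::string* Var(const std::string& name) { return &vars_[name]; }

  // Look up a variable here first, then walk up the parent chain.
  std::string* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    if (it != vars_.end()) return &it->second;
    return parent_ ? parent_->FindVar(name) : nullptr;
  }

 private:
  ToyScope* parent_;
  std::unordered_map<std::string, std::string> vars_;
};

int main() {
  ToyScope global;                        // shared by all requests
  *global.Var("embedding_table") = "W";   // shared parameter

  ToyScope request_a(&global);            // per-request sub scope
  ToyScope request_b(&global);

  // Each request writes its prefetch output into its own sub scope.
  *request_a.Var("prefetch_out") = "rows for request A";
  *request_b.Var("prefetch_out") = "rows for request B";

  // Both still see the shared table through the parent chain ...
  std::cout << *request_a.FindVar("embedding_table") << "\n";
  // ... but the parent scope is untouched, so requests cannot clobber
  // each other's output.
  std::cout << (global.FindVar("prefetch_out") == nullptr
                    ? "no prefetch_out in global scope"
                    : "leaked!")
            << "\n";
  return 0;
}
```

Running the executor on the same local scope where the output variable was initialized also means the serialized reply is read from exactly where the program wrote it, which is the bug the commit title refers to.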