提交 25f47fc0 编写于 作者: Q qiaolongfei

fix prefetch bugs, optimize code

上级 bf869e45
...@@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const { ...@@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const {
} }
std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get( std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
std::vector<int64_t> keys, framework::Tensor* value) const { const std::vector<int64_t>& keys, framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(), PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized."); "The value tensor should be initialized.");
std::vector<std::pair<int64_t, int64_t>> non_keys_pair; std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
int64_t value_width = value_->numel() / value_->dims()[0]; if (keys.empty()) {
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], VLOG(3) << "keys is empty, please check data!";
"output tensor should have the same shape with table " } else {
"execpt the dims[0]."); int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
for (size_t i = 0; i < keys.size(); ++i) { "output tensor should have the same shape with table "
int64_t index = Index(keys[i]); "except the dims[0].");
if (index == -1) {
non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i))); for (size_t i = 0; i < keys.size(); ++i) {
} else { int64_t index = Index(keys[i]);
framework::VisitDataType( if (index == -1) {
framework::ToDataType(value_->type()), non_keys_pair.push_back(
TensorCopyVisitor(value, i * value_width, *value_.get(), std::make_pair(keys[i], static_cast<int64_t>(i)));
index * value_width, value_width)); } else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
} }
} }
return non_keys_pair; return non_keys_pair;
......
...@@ -82,7 +82,7 @@ class SelectedRows { ...@@ -82,7 +82,7 @@ class SelectedRows {
* @return a list of pairs, each containing a non-existent key and its index in * @return a list of pairs, each containing a non-existent key and its index in
* the value * the value
*/ */
std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys, std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
framework::Tensor* value) const; framework::Tensor* value) const;
/* /*
......
...@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase { ...@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase {
program_(program), program_(program),
prefetch_ctx_(prefetch_ctx), prefetch_ctx_(prefetch_ctx),
req_id_(req_id) { req_id_(req_id) {
if (sync_mode_) { // prefetch always creates a new sub scope
request_.reset(new VariableResponse(scope, dev_ctx_, false)); request_.reset(new VariableResponse(scope, dev_ctx_, true));
} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
...@@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase { ...@@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase {
std::string var_name = request_->OutVarname(); std::string var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << var_name; VLOG(3) << "RequestPrefetch " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name); auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope(); framework::Scope* local_scope = request_->GetMutableLocalScope();
auto* var = local_scope->FindVar(var_name); auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType()); InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_); executor_->RunPreparedContext(prefetch_ctx_, local_scope);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册