Unverified commit ab953bae, authored by Qiao Longfei, committed by GitHub

Merge pull request #10973 from jacquesqiao/fix-prefetch

Fix and optimize async distribute lookup table
......@@ -121,19 +121,23 @@ bool SelectedRows::HasKey(int64_t key) const {
}
std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
std::vector<int64_t> keys, framework::Tensor* value) const {
const std::vector<int64_t>& keys, framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
if (keys.empty()) {
VLOG(3) << "keys is empty, please check data!";
} else {
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
"execpt the dims[0].");
"except the dims[0].");
for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]);
if (index == -1) {
non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
non_keys_pair.push_back(
std::make_pair(keys[i], static_cast<int64_t>(i)));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
......@@ -141,6 +145,7 @@ std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
index * value_width, value_width));
}
}
}
return non_keys_pair;
}
......
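The new `Get` contract: `keys` is taken by const reference, an empty key list is tolerated, every hit is copied into `value`, and every miss comes back paired with its position in `keys` so the caller can initialize it separately. A minimal Python sketch of that contract (the names are illustrative, not the C++ API):

```python
import numpy as np

def table_get(rows, table_value, keys, out_value):
    """Model of SelectedRows::Get: copy hit rows, report misses."""
    non_keys_pair = []  # (missing key, position in `keys`)
    if not keys:
        print("keys is empty, please check data!")
        return non_keys_pair
    # output tensor should have the same shape with table except dims[0]
    assert table_value.shape[1] == out_value.shape[1]
    for i, key in enumerate(keys):
        if key in rows:
            out_value[i] = table_value[rows.index(key)]
        else:
            non_keys_pair.append((key, i))
    return non_keys_pair

# Usage: key 7 is absent, so it is returned as a (key, index) pair.
rows = [1, 3, 5]
table = np.arange(6, dtype=np.float32).reshape(3, 2)
out = np.zeros((2, 2), dtype=np.float32)
print(table_get(rows, table, [3, 7], out))  # -> [(7, 1)]
```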
......@@ -82,7 +82,7 @@ class SelectedRows {
* @return a list of pairs, each holding a non-existing key and its index in
* the value
*/
std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
framework::Tensor* value) const;
/*
......
......@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase {
program_(program),
prefetch_ctx_(prefetch_ctx),
req_id_(req_id) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
// prefetch always creates a new sub-scope
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
......@@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase {
std::string var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
framework::Scope* local_scope = request_->GetMutableLocalScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_);
executor_->RunPreparedContext(prefetch_ctx_, local_scope);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
......
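The sync/async branch is gone: a prefetch request now always deserializes into a fresh sub-scope, and `RunPreparedContext` runs in that local scope instead of the shared `scope_`. A toy model of why the per-request sub-scope matters under concurrent prefetches (this `Scope` is a stand-in, not the framework's):

```python
import threading

class Scope(object):
    """Toy scope: a dict of variables; child scopes isolate requests."""
    def __init__(self):
        self.vars = {}

    def new_scope(self):
        return Scope()

shared_scope = Scope()
table = {i: i * 10 for i in range(100)}  # stand-in for the parameter table

def handle_prefetch(ids, var_name="embedding@Prefetch"):
    # One sub-scope per request: concurrent prefetches that both write
    # `var_name` no longer clobber each other in the shared scope.
    local = shared_scope.new_scope()
    local.vars[var_name] = [table[i] for i in ids]
    return local.vars[var_name]

threads = [threading.Thread(target=handle_prefetch, args=([i, i + 1],))
           for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```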
......@@ -207,6 +207,7 @@ static void AsyncUpdateThread(
while (!exit_flag) {
const detail::ReceivedMessage v = queue->Pop();
auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
......
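For context, this VLOG sits in the async consumer loop: the update thread pops `(var_name, message)` pairs off a queue and applies each gradient as soon as it arrives, with no barrier across trainers. A rough Python sketch of that loop (the sentinel and names are assumptions, not the server's actual types):

```python
import queue
import threading

msgs = queue.Queue()
EXIT = ("", None)  # sentinel standing in for exit_flag

state = {}
def apply_update(name, grad):
    state[name] = state.get(name, 0.0) - 0.1 * grad  # toy SGD step

def async_update_thread():
    while True:
        recv_var_name, payload = msgs.get()
        if payload is None:
            break
        print("async update", recv_var_name)  # the VLOG(4) above
        apply_update(recv_var_name, payload)

t = threading.Thread(target=async_update_thread)
t.start()
msgs.put(("table@GRAD.pserver_0", 1.0))
msgs.put(EXIT)
t.join()
```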
......@@ -127,7 +127,7 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(-1.0f);
AddAttr<float>("max",
"(float, default 1.0) "
"Maximun value of uniform random")
"Maximum value of uniform random")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
......
......@@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
return;
}
size_t param_row_width = param.value().numel() / param.rows().size();
size_t grad_row_width = grad.value().numel() / grad.rows().size();
auto param_row_width = param.value().dims()[1];
auto grad_row_width = grad.value().dims()[1];
VLOG(4) << " param rows: " << param.rows().size()
<< " param memory rows: " << param.value().dims()[0]
<< " grad rows: " << grad.rows().size()
<< " grad memory rows: " << grad.value().dims()[0];
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
"param_row should have the same size with grad_row");
......
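The switch from `numel() / rows().size()` to `dims()[1]` matters because a SelectedRows value buffer may hold more in-memory rows than `rows()` has ids (exactly the "rows" vs. "memory rows" distinction the new VLOG prints), in which case the division no longer yields the row width. Illustrative arithmetic:

```python
import numpy as np

rows = [0, 4, 7]             # 3 logical rows
value = np.zeros((8, 16))    # 8 in-memory rows, width 16

old_width = value.size // len(rows)  # 128 // 3 == 42, not a row width
new_width = value.shape[1]           # 16, the actual row width

print(old_width, new_width)
```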
......@@ -797,7 +797,7 @@ class Block(object):
Rename variable in vars and ops' inputs and outputs
"""
if not self.has_var(name):
raise ValueError("var %s is not in current" % name)
raise ValueError("var %s is not in current block" % name)
v = self.var(name)
if type(v) == Parameter:
var_type = "Parameter"
......@@ -843,6 +843,7 @@ class Block(object):
self.vars[new_name] = var
del self.vars[name]
self.sync_with_cpp()
return var
def remove_var(self, name):
self.sync_with_cpp()
......
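`rename_var` now returns the renamed `Variable`, which the transpiler relies on further down when it renames the table gradient and keeps using the handle. A toy stand-in showing the new calling convention (this `Block` is a sketch, not fluid's):

```python
class Block(object):
    """Minimal model of Block.rename_var's new return value."""
    def __init__(self):
        self.vars = {"table@GRAD": object()}

    def rename_var(self, name, new_name):
        var = self.vars.pop(name)
        self.vars[new_name] = var
        return var  # previously the caller had to re-look it up by new_name

block = Block()
grad_var = block.rename_var("table@GRAD", "table@GRAD.pserver_0")
assert grad_var is block.vars["table@GRAD.pserver_0"]
```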
......@@ -273,7 +273,8 @@ class DistributeTranspiler:
if param_grad[0].name == self.table_name
][0]
table_grad_var = self.table_param_grad[1]
self.table_grad_list = [
if self.sync_mode:
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index),
......@@ -282,6 +283,15 @@ class DistributeTranspiler:
dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints))
]
else:
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.pserver_%d" % (table_grad_var.name, index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
for index in range(len(self.pserver_endpoints))
]
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
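The two branches differ only in the naming scheme: sync mode keeps one split per trainer per pserver (`%s.trainer_%d.pserver_%d`) so the pserver can sum one slice from each trainer, while async mode shares a single per-pserver name (`%s.pserver_%d`) because gradients are applied as they arrive and are never summed across trainers. A sketch of the resulting names:

```python
def table_grad_names(grad_name, trainer_id, num_pservers, sync_mode):
    # Mirrors the two branches above.
    if sync_mode:
        return ["%s.trainer_%d.pserver_%d" % (grad_name, trainer_id, i)
                for i in range(num_pservers)]
    return ["%s.pserver_%d" % (grad_name, i) for i in range(num_pservers)]

print(table_grad_names("table@GRAD", 1, 2, True))
# ['table@GRAD.trainer_1.pserver_0', 'table@GRAD.trainer_1.pserver_1']
print(table_grad_names("table@GRAD", 1, 2, False))
# ['table@GRAD.pserver_0', 'table@GRAD.pserver_1']
```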
......@@ -400,7 +410,8 @@ class DistributeTranspiler:
attrs={"axis": 0})
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, eplist)
self._replace_lookup_table_op_with_prefetch(program,
pserver_endpoints)
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
def get_trainer_program(self):
......@@ -537,7 +548,7 @@ class DistributeTranspiler:
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx)
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block)
......@@ -621,7 +632,8 @@ class DistributeTranspiler:
return s_prog
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, eplist):
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None
self.prefetch_output_vars = None
......@@ -670,7 +682,7 @@ class DistributeTranspiler:
inputs={'X': self.prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars},
attrs={
"epmap": eplist,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......@@ -707,11 +719,11 @@ class DistributeTranspiler:
inputs={
'Ids': [program.global_block().vars[table_grad_name]]
},
outputs={"Out": self.table_grad_list})
outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op(
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
inputs={'X': self.trainer_side_table_grad_list},
outputs={},
attrs={
"sync_send": True,
......@@ -750,16 +762,7 @@ class DistributeTranspiler:
return prefetch_block
def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx):
def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=persistable)
pre_block_idx, grad_to_block_id):
# STEP: create table optimize block
# create table param and grad var in pserver program
origin_param_var = self.origin_program.global_block().vars[
......@@ -770,11 +773,11 @@ class DistributeTranspiler:
dtype=origin_param_var.dtype,
type=core.VarDesc.VarType.SELECTED_ROWS,
persistable=True)
grad_var = _clone_var(
pserver_program.global_block(),
# parameter must be selected rows
param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
grad_var = pserver_program.global_block().clone_variable(
self.origin_program.global_block().vars[grad_var_name(
self.table_name)],
persistable=False)
self.table_name)])
# create table optimize block in pserver program
table_opt_op = [
......@@ -788,7 +791,7 @@ class DistributeTranspiler:
if self.sync_mode:
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_side_table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
......@@ -798,11 +801,21 @@ class DistributeTranspiler:
for index in range(self.trainer_num)
]
# append sum op for table_grad_list
# append sum op for pserver_side_table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
inputs={"X": pserver_side_table_grad_list},
outputs={"Out": [grad_var]})
else:
# in async mode the table gradient also needs to be split among the parameter servers
origin_grad_name = grad_var.name
splited_grad_name = self.trainer_side_table_grad_list[
pserver_index].name
if not splited_grad_name.startswith(origin_grad_name):
    raise ValueError("split grad name: " + splited_grad_name +
                     " does not start with origin grad name: " +
                     origin_grad_name)
grad_var = pserver_program.global_block().rename_var(
origin_grad_name, splited_grad_name)
lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]]
......@@ -818,6 +831,9 @@ class DistributeTranspiler:
outputs=outputs,
attrs=table_opt_op.attrs)
# add table parameter gradient and it's block id to grad_to_block_id
grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
return table_opt_block
# ====================== private transpiler functions =====================
......
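The table optimize block now registers itself in `grad_to_block_id` with the same `name:block_idx` convention the other optimize blocks use, which is what lets listen_and_serv route an incoming table gradient to this block in async mode. A sketch of that mapping (the dispatcher side here is assumed, for illustration only):

```python
# Entries appended by the transpiler: "grad var name" + ":" + "block index".
grad_to_block_id = ["fc_w@GRAD:1", "fc_b@GRAD:2"]
grad_to_block_id.append("table@GRAD.pserver_0" + ":" + str(3))

# A pserver-side dispatcher would parse them back into a routing table.
routing = dict(entry.rsplit(":", 1) for entry in grad_to_block_id)
assert routing["table@GRAD.pserver_0"] == "3"
```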