Unverified commit ab953bae, authored by Qiao Longfei, committed by GitHub

Merge pull request #10973 from jacquesqiao/fix-prefetch

Fix and optimize async distribute lookup table
@@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const {
 }
 
 std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
-    std::vector<int64_t> keys, framework::Tensor* value) const {
+    const std::vector<int64_t>& keys, framework::Tensor* value) const {
   PADDLE_ENFORCE(value->IsInitialized(),
                  "The value tensor should be initialized.");
   std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
-  int64_t value_width = value_->numel() / value_->dims()[0];
-  PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
-                    "output tensor should have the same shape with table "
-                    "execpt the dims[0].");
-
-  for (size_t i = 0; i < keys.size(); ++i) {
-    int64_t index = Index(keys[i]);
-    if (index == -1) {
-      non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
-    } else {
-      framework::VisitDataType(
-          framework::ToDataType(value_->type()),
-          TensorCopyVisitor(value, i * value_width, *value_.get(),
-                            index * value_width, value_width));
+  if (keys.empty()) {
+    VLOG(3) << "keys is empty, please check data!";
+  } else {
+    int64_t value_width = value_->numel() / value_->dims()[0];
+    PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
+                      "output tensor should have the same shape with table "
+                      "except the dims[0].");
+
+    for (size_t i = 0; i < keys.size(); ++i) {
+      int64_t index = Index(keys[i]);
+      if (index == -1) {
+        non_keys_pair.push_back(
+            std::make_pair(keys[i], static_cast<int64_t>(i)));
+      } else {
+        framework::VisitDataType(
+            framework::ToDataType(value_->type()),
+            TensorCopyVisitor(value, i * value_width, *value_.get(),
+                              index * value_width, value_width));
+      }
     }
   }
   return non_keys_pair;
......
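For reference, the reworked SelectedRows::Get above takes the keys by const reference, short-circuits on an empty key list, copies the rows of the keys it finds into the output tensor, and returns (key, position) pairs for the keys it cannot find. A rough pure-Python model of that control flow (illustration only, not the Paddle C++ API; names and data below are made up):

# Illustrative model of the Get() control flow above; plain Python, not Paddle.
def selected_rows_get(rows, value, keys, out):
    """rows: tracked row ids; value: the table as a list of equal-width rows;
    keys: ids to look up; out: preallocated list of output rows to fill."""
    non_keys_pair = []                      # (missing key, position in keys)
    if not keys:
        print("keys is empty, please check data!")
        return non_keys_pair
    assert len(out[0]) == len(value[0])     # same row width as the table
    index = {row_id: i for i, row_id in enumerate(rows)}
    for i, key in enumerate(keys):
        if key not in index:                # mirrors Index(keys[i]) == -1
            non_keys_pair.append((key, i))
        else:                               # mirrors the TensorCopyVisitor copy
            out[i] = list(value[index[key]])
    return non_keys_pair

# Example: key 7 is absent from the table, so it is reported with position 1.
table_rows, table = [3, 5], [[0.1, 0.2], [0.3, 0.4]]
out = [[0.0, 0.0] for _ in range(3)]
print(selected_rows_get(table_rows, table, [5, 7, 3], out))  # [(7, 1)]
print(out)  # [[0.3, 0.4], [0.0, 0.0], [0.1, 0.2]]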
@@ -82,7 +82,7 @@ class SelectedRows {
    * @return a list of pair which contains the non-exists key and the index in
    * the value
    */
-  std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
+  std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
                                                framework::Tensor* value) const;
 
   /*
......
@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase {
         program_(program),
         prefetch_ctx_(prefetch_ctx),
         req_id_(req_id) {
-    if (sync_mode_) {
-      request_.reset(new VariableResponse(scope, dev_ctx_, false));
-    } else {
-      request_.reset(new VariableResponse(scope, dev_ctx_, true));
-    }
+    // prefetch always create a new sub scope
+    request_.reset(new VariableResponse(scope, dev_ctx_, true));
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
@@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase {
     std::string var_name = request_->OutVarname();
     VLOG(3) << "RequestPrefetch " << var_name;
     auto var_desc = program_->Block(0).FindVar(var_name);
-    framework::Scope* local_scope = &scope_->NewScope();
+    framework::Scope* local_scope = request_->GetMutableLocalScope();
     auto* var = local_scope->FindVar(var_name);
     InitializeVariable(var, var_desc->GetType());
-    executor_->RunPreparedContext(prefetch_ctx_, scope_);
+    executor_->RunPreparedContext(prefetch_ctx_, local_scope);
     SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
......
@@ -207,6 +207,7 @@ static void AsyncUpdateThread(
   while (!exit_flag) {
     const detail::ReceivedMessage v = queue->Pop();
     auto recv_var_name = v.first;
+    VLOG(4) << "async update " << recv_var_name;
     auto var = v.second->GetVar();
     if (var == nullptr) {
       LOG(ERROR) << "Can not find server side var: " << recv_var_name;
......
@@ -127,7 +127,7 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(-1.0f);
     AddAttr<float>("max",
                    "(float, default 1.0) "
-                   "Maximun value of uniform random")
+                   "Maximum value of uniform random")
         .SetDefault(1.0f);
     AddAttr<int>("seed",
                  "(int, default 0) "
......
@@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
         return;
       }
-      size_t param_row_width = param.value().numel() / param.rows().size();
-      size_t grad_row_width = grad.value().numel() / grad.rows().size();
+      auto param_row_width = param.value().dims()[1];
+      auto grad_row_width = grad.value().dims()[1];
+      VLOG(4) << " param rows: " << param.rows().size()
+              << " param memory rows: " << param.value().dims()[0]
+              << " grad rows: " << grad.rows().size()
+              << " grad memory rows: " << grad.value().dims()[0];
       PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
                         "param_row should have the same size with grad_row");
......
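The SGD hunk above now derives the row width from value().dims()[1] instead of dividing numel() by rows().size(). The new VLOG prints the tracked row count and the in-memory row count separately, which suggests the two can differ for a selected-rows table; when they do, the old division no longer yields the row width. A numeric sketch of that effect, with made-up sizes:

# Made-up numbers, only to show how the two formulas can disagree.
row_width = 8            # value.dims()[1]: width of one row
memory_rows = 100        # value.dims()[0]: rows actually allocated in memory
numel = memory_rows * row_width
rows = list(range(60))   # rows() currently tracked: fewer than memory_rows

old_width = numel // len(rows)   # 800 // 60 = 13 -> not the real row width
new_width = row_width            # 8              -> dims()[1] is always right
print(old_width, new_width)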
@@ -797,7 +797,7 @@ class Block(object):
         Rename variable in vars and ops' inputs and outputs
         """
         if not self.has_var(name):
-            raise ValueError("var %s is not in current" % name)
+            raise ValueError("var %s is not in current block" % name)
         v = self.var(name)
         if type(v) == Parameter:
             var_type = "Parameter"
@@ -843,6 +843,7 @@ class Block(object):
         self.vars[new_name] = var
         del self.vars[name]
         self.sync_with_cpp()
+        return var
 
     def remove_var(self, name):
         self.sync_with_cpp()
......
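The framework.py change above makes Block.rename_var return the renamed variable, which the transpiler relies on further down when it rebinds grad_var to the per-pserver gradient name. A plain-Python stand-in for that pattern (ToyBlock, ToyVar, and the variable names are made-up stand-ins, not the fluid classes):

# Plain-Python stand-in for the pattern; not the fluid Block class.
class ToyVar(object):
    def __init__(self, name):
        self.name = name

class ToyBlock(object):
    def __init__(self):
        self.vars = {}

    def rename_var(self, name, new_name):
        if name not in self.vars:
            raise ValueError("var %s is not in current block" % name)
        var = self.vars.pop(name)
        var.name = new_name
        self.vars[new_name] = var
        return var              # new: hand the renamed var back to the caller

block = ToyBlock()
block.vars["table@GRAD"] = ToyVar("table@GRAD")
grad_var = block.rename_var("table@GRAD", "table@GRAD.pserver_0")
print(grad_var.name)            # table@GRAD.pserver_0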
@@ -273,15 +273,25 @@ class DistributeTranspiler:
                 if param_grad[0].name == self.table_name
             ][0]
             table_grad_var = self.table_param_grad[1]
-            self.table_grad_list = [
-                program.global_block().create_var(
-                    name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
-                    type=table_grad_var.type,
-                    shape=table_grad_var.shape,
-                    dtype=table_grad_var.dtype)
-                for index in range(len(self.pserver_endpoints))
-            ]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
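The hunk above names the trainer-side table-gradient variables per (trainer, pserver) pair in sync mode and per pserver only in async mode. A small sketch of the resulting names, using made-up values (the gradient variable name here is hypothetical):

# Naming scheme from the hunk above, with made-up sizes for illustration.
table_grad_name = "table@GRAD"   # hypothetical gradient variable name
trainer_id, num_pservers = 0, 2

sync_names = [
    "%s.trainer_%d.pserver_%d" % (table_grad_name, trainer_id, index)
    for index in range(num_pservers)
]
async_names = [
    "%s.pserver_%d" % (table_grad_name, index)
    for index in range(num_pservers)
]
print(sync_names)   # ['table@GRAD.trainer_0.pserver_0', 'table@GRAD.trainer_0.pserver_1']
print(async_names)  # ['table@GRAD.pserver_0', 'table@GRAD.pserver_1']

The table-optimize-block hunk further down mirrors this split: in sync mode the pserver sums the per-trainer copies before the update, while in async mode it simply renames its gradient variable to the per-pserver split name.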
@@ -400,7 +410,8 @@ class DistributeTranspiler:
                 attrs={"axis": 0})
 
         if self.has_distributed_lookup_table:
-            self._replace_lookup_table_op_with_prefetch(program, eplist)
+            self._replace_lookup_table_op_with_prefetch(program,
+                                                        pserver_endpoints)
             self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
 
     def get_trainer_program(self):
@@ -537,7 +548,7 @@ class DistributeTranspiler:
         if self.has_distributed_lookup_table:
             pserver_index = self.pserver_endpoints.index(endpoint)
             table_opt_block = self._create_table_optimize_block(
-                pserver_index, pserver_program, pre_block_idx)
+                pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
             prefetch_block = self._create_prefetch_block(
                 pserver_index, pserver_program, table_opt_block)
@@ -621,7 +632,8 @@ class DistributeTranspiler:
         return s_prog
 
     # transpiler function for dis lookup_table
-    def _replace_lookup_table_op_with_prefetch(self, program, eplist):
+    def _replace_lookup_table_op_with_prefetch(self, program,
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -670,7 +682,7 @@ class DistributeTranspiler:
             inputs={'X': self.prefetch_input_vars},
             outputs={"Out": self.prefetch_output_vars},
             attrs={
-                "epmap": eplist,
+                "epmap": pserver_endpoints,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
@@ -707,11 +719,11 @@ class DistributeTranspiler:
            inputs={
                'Ids': [program.global_block().vars[table_grad_name]]
            },
-            outputs={"Out": self.table_grad_list})
+            outputs={"Out": self.trainer_side_table_grad_list})
         program.global_block().insert_op(
             index=op_index + 2,
             type="send_vars",
-            inputs={'X': self.table_grad_list},
+            inputs={'X': self.trainer_side_table_grad_list},
             outputs={},
             attrs={
                 "sync_send": True,
@@ -750,16 +762,7 @@ class DistributeTranspiler:
         return prefetch_block
 
     def _create_table_optimize_block(self, pserver_index, pserver_program,
-                                     pre_block_idx):
-        def _clone_var(block, var, persistable=True):
-            assert isinstance(var, Variable)
-            return block.create_var(
-                name=var.name,
-                shape=var.shape,
-                dtype=var.dtype,
-                type=var.type,
-                persistable=persistable)
-
+                                     pre_block_idx, grad_to_block_id):
         # STEP: create table optimize block
         # create table param and grad var in pserver program
         origin_param_var = self.origin_program.global_block().vars[
@@ -770,11 +773,11 @@ class DistributeTranspiler:
             dtype=origin_param_var.dtype,
             type=core.VarDesc.VarType.SELECTED_ROWS,
             persistable=True)
-        grad_var = _clone_var(
-            pserver_program.global_block(),
-            self.origin_program.global_block().vars[grad_var_name(
-                self.table_name)],
-            persistable=False)
+        # parameter must be selected rows
+        param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
+        grad_var = pserver_program.global_block().clone_variable(
+            self.origin_program.global_block().vars[grad_var_name(
+                self.table_name)])
 
         # create table optimize block in pserver program
         table_opt_op = [
@@ -788,7 +791,7 @@ class DistributeTranspiler:
         if self.sync_mode:
             # create grad vars in pserver program
             table_grad_var = self.table_param_grad[1]
-            table_grad_list = [
+            pserver_side_table_grad_list = [
                 pserver_program.global_block().create_var(
                     name="%s.trainer_%d.pserver_%d" %
                     (table_grad_var.name, index, pserver_index),
@@ -798,11 +801,21 @@ class DistributeTranspiler:
                 for index in range(self.trainer_num)
             ]
-            # append sum op for table_grad_list
+            # append sum op for pserver_side_table_grad_list
             table_opt_block.append_op(
                 type="sum",
-                inputs={"X": table_grad_list},
+                inputs={"X": pserver_side_table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async_mode, for table gradient, it also need to be splited to each parameter server
+            origin_grad_name = grad_var.name
+            splited_grad_name = self.trainer_side_table_grad_list[
+                pserver_index].name
+            if not splited_grad_name.startswith(origin_grad_name):
+                raise ValueError("origin_grad_var: " + splited_grad_name +
+                                 " grad_var:" + grad_var.name)
+            grad_var = pserver_program.global_block().rename_var(
+                origin_grad_name, splited_grad_name)
 
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
@@ -818,6 +831,9 @@ class DistributeTranspiler:
             outputs=outputs,
             attrs=table_opt_op.attrs)
 
+        # add table parameter gradient and it's block id to grad_to_block_id
+        grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
+
         return table_opt_block
 
     # ====================== private transpiler functions =====================
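The last hunk registers the table gradient, under its possibly renamed variable name, together with the index of the table optimize block; presumably the pserver side uses grad_to_block_id to route an incoming gradient to the right optimize block. Following the string format in the diff, a registered entry would look like this (the variable name and block index below are made up):

# Entry format from the hunk above: "<grad var name>:<optimize block idx>".
grad_var_name = "table@GRAD.pserver_0"   # async mode: the renamed split grad
table_opt_block_idx = 3                  # made-up block index
grad_to_block_id = []
grad_to_block_id.append(grad_var_name + ":" + str(table_opt_block_idx))
print(grad_to_block_id)                  # ['table@GRAD.pserver_0:3']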
......