未验证 提交 caf9a09d 编写于 作者: Y Yancey 提交者: GitHub

Merge selected rows with dynamic variable count (#8023)

* dynamic send/recv selected rows

* update by comment

* fix by comment
上级 4f4abfa3
...@@ -101,6 +101,9 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -101,6 +101,9 @@ class ListenAndServOp : public framework::OperatorBase {
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag) { while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
...@@ -143,6 +146,9 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -143,6 +146,9 @@ class ListenAndServOp : public framework::OperatorBase {
PADDLE_THROW("Can not find server side var"); PADDLE_THROW("Can not find server side var");
} }
detail::DeserializeFromMessage(v.second, dev_ctx, var); detail::DeserializeFromMessage(v.second, dev_ctx, var);
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
} }
} }
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
...@@ -156,9 +162,19 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -156,9 +162,19 @@ class ListenAndServOp : public framework::OperatorBase {
} catch (std::exception &e) { } catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
// mini-batch.
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
// have any hide logic in the operator.
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(update_param_cnt); rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear(); grads_counter_.clear();
sparse_vars.clear();
} // while(true) } // while(true)
} }
......
...@@ -24,6 +24,22 @@ limitations under the License. */ ...@@ -24,6 +24,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static bool IsVariableInitialized(const framework::Scope& scope,
const std::string& varname) {
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname);
if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().IsInitialized();
} else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().value().IsInitialized();
} else {
PADDLE_THROW(
"Variable type in send side should be in "
"[LodTensor, SelectedRows]");
}
return false;
}
class SendOp : public framework::OperatorBase { class SendOp : public framework::OperatorBase {
public: public:
...@@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase { ...@@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase {
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>(); detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (IsVariableInitialized(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
......
...@@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker) SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input SelectedRows."); AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable(); AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int>({}));
...@@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X."); PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"SplitSelectedRowsOp must has output Out."); "SplitSelectedRowsOp must has output Out.");
std::vector<int> height_sections =
ctx->Attrs().Get<std::vector<int>>("height_sections");
int64_t n = ctx->Outputs("Out").size();
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(n);
// make output dims
for (int64_t i = 0; i < n; ++i) {
auto dims = ctx->GetInputDim("X");
if (height_sections.size()) {
PADDLE_ENFORCE_EQ(
height_sections.size(), static_cast<size_t>(n),
"The size of height section should be the same with height"
" section size.");
dims[0] = height_sections[i];
}
outs_dims.push_back(dims);
}
ctx->SetOutputsDim("Out", outs_dims);
} }
}; };
......
...@@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < outs_rows_idx.size(); ++i) { for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i]; auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(height_sections[i]);
if (rows_idx.size() > 0) { if (rows_idx.size() > 0) {
auto dims = x->GetCompleteDims(); auto dims = x->GetCompleteDims();
dims[0] = rows_idx.size(); dims[0] = rows_idx.size();
......
...@@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0; int64_t offset = 0;
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i); auto &sel_row = get_selected_row(i);
if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height()); PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(context.template device_context<DeviceContext>(), sel_row, functor(context.template device_context<DeviceContext>(), sel_row,
offset, out); offset, out);
......
...@@ -191,6 +191,7 @@ class DistributeTranspiler: ...@@ -191,6 +191,7 @@ class DistributeTranspiler:
for b in param_blocks: for b in param_blocks:
varname, block_id, _ = b.split(":") varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)]) send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same # let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs. # order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints) eplist = split_method(send_inputs, pserver_endpoints)
...@@ -274,6 +275,7 @@ class DistributeTranspiler: ...@@ -274,6 +275,7 @@ class DistributeTranspiler:
name="%s.block%d" % (varname, i), name="%s.block%d" % (varname, i),
psersistable=False, psersistable=False,
dtype=orig_var.dtype, dtype=orig_var.dtype,
type=orig_var.type,
shape=splited_shape) # flattend splited var shape=splited_shape) # flattend splited var
var_mapping[varname].append(var) var_mapping[varname].append(var)
return var_mapping return var_mapping
...@@ -335,6 +337,7 @@ class DistributeTranspiler: ...@@ -335,6 +337,7 @@ class DistributeTranspiler:
name="%s.trainer_%d" % (var.name, i), name="%s.trainer_%d" % (var.name, i),
psersistable=var.persistable, psersistable=var.persistable,
dtype=var.dtype, dtype=var.dtype,
type=var.type,
shape=var.shape) shape=var.shape)
var_list.append(var_each) var_list.append(var_each)
return var_list return var_list
...@@ -561,6 +564,7 @@ class DistributeTranspiler: ...@@ -561,6 +564,7 @@ class DistributeTranspiler:
persistable=True, persistable=True,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
# step6 # step6
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
# step 6.1 # step 6.1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册