未验证 提交 84680379 编写于 作者: Y Yancey 提交者: GitHub

Fix sparse update memory error for distributed training (#8837)

Fix sparse update memory error for distributed training
上级 124b7501
...@@ -24,15 +24,15 @@ limitations under the License. */ ...@@ -24,15 +24,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static bool IsVariableInitialized(const framework::Scope& scope, static bool NeedSend(const framework::Scope& scope,
const std::string& varname) { const std::string& varname) {
auto* var = scope.FindVar(varname); auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname); varname);
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().IsInitialized(); return var->Get<framework::LoDTensor>().IsInitialized();
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().value().IsInitialized(); return var->Get<framework::SelectedRows>().rows().size() > 0UL;
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"Variable type in send side should be in " "Variable type in send side should be in "
...@@ -67,7 +67,7 @@ class SendOp : public framework::OperatorBase { ...@@ -67,7 +67,7 @@ class SendOp : public framework::OperatorBase {
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>(); detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (IsVariableInitialized(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else { } else {
......
...@@ -39,6 +39,14 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -39,6 +39,14 @@ class SGDOp : public framework::OperatorWithKernel {
// and run time. // and run time.
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Param")->type()),
ctx.GetPlace());
}
}; };
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -47,6 +47,12 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -47,6 +47,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(param, param_out); PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad"); auto* grad = ctx.Input<framework::SelectedRows>("Grad");
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad->rows().size() == 0) {
return;
}
auto in_height = grad->height(); auto in_height = grad->height();
auto out_dims = param_out->dims(); auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]); PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
...@@ -60,13 +66,15 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -60,13 +66,15 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto* in_data = in_value.data<T>(); auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>(); auto* out_data = param_out->data<T>();
auto* lr = learning_rate->data<T>(); auto* lr = learning_rate->data<T>();
for (size_t i = 0; i < in_rows.size(); i++) { for (size_t i = 0; i < in_rows.size(); i++) {
PADDLE_ENFORCE(in_rows[i] < in_height,
"Input rows index should less than height");
for (int64_t j = 0; j < in_row_numel; j++) { for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -= out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j]; lr[0] * in_data[i * in_row_numel + j];
} }
} }
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
} }
......
...@@ -21,15 +21,24 @@ limitations under the License. */ ...@@ -21,15 +21,24 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static int FindOutIdx(int row, const std::vector<int>& height_sections) { static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
int offset = 0; for (size_t i = 1; i < abs_sections.size(); ++i) {
for (size_t i = 0; i < height_sections.size(); ++i) { if (row < abs_sections[i]) {
if (row >= offset && row < (offset + height_sections[i])) { return i - 1;
return i;
} }
offset += height_sections[i];
} }
return -1; return abs_sections.size() - 1;
}
static std::vector<int> ToAbsoluteSection(
const std::vector<int>& height_sections) {
std::vector<int> abs_sections;
abs_sections.resize(height_sections.size());
abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) {
abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
}
return abs_sections;
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -40,16 +49,23 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -40,16 +49,23 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
auto abs_sections = ToAbsoluteSection(height_sections);
auto x_rows = x->rows(); auto x_rows = x->rows();
std::vector<std::vector<int>> outs_rows_idx; std::vector<std::vector<int>> outs_rows_idx;
std::vector<std::vector<int>> outs_dense_idx;
outs_rows_idx.resize(outs.size()); outs_rows_idx.resize(outs.size());
outs_dense_idx.resize(outs.size());
auto row_numel = x->value().numel() / x->value().dims()[0]; auto row_numel = x->value().numel() / x->value().dims()[0];
auto src = x->value().data<T>(); auto src = x->value().data<T>();
// split rows index into output sparse vars
for (size_t i = 0; i < x_rows.size(); ++i) { for (size_t i = 0; i < x_rows.size(); ++i) {
int out_idx = FindOutIdx(x_rows[i], height_sections); int out_idx = FindOutIdx(x_rows[i], abs_sections);
outs_rows_idx[out_idx].push_back(i); outs_rows_idx[out_idx].push_back(x_rows[i]);
outs_dense_idx[out_idx].push_back(i);
} }
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -61,19 +77,20 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -61,19 +77,20 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
dims[0] = rows_idx.size(); dims[0] = rows_idx.size();
outs[i]->mutable_value()->mutable_data<T>(dims, x->place()); outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
for (auto idx : rows_idx) { for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(x_rows[idx]); outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
} }
auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace()); auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
for (size_t j = 0; j < rows_idx.size(); j++) { for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel, memory::Copy(
platform::CPUPlace(), src + rows_idx[j] * row_numel, platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
sizeof(T) * row_numel); src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel, memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(), src + rows_idx[j] * row_numel, platform::CUDAPlace(),
src + outs_dense_idx[i][j] * row_numel,
sizeof(T) * row_numel, stream); sizeof(T) * row_numel, stream);
#else #else
PADDLE_THROW("Paddle is not compiled with GPU"); PADDLE_THROW("Paddle is not compiled with GPU");
......
...@@ -76,10 +76,16 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -76,10 +76,16 @@ class SumOp : public framework::OperatorWithKernel {
static_cast<framework::proto::VarType::Type>(dtype), static_cast<framework::proto::VarType::Type>(dtype),
ctx.device_context()); ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) { } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::OpKernelType( for (auto& var : x_vars) {
framework::ToDataType( auto& value = var->Get<framework::SelectedRows>().value();
x_vars[0]->Get<framework::SelectedRows>().value().type()), if (value.IsInitialized()) {
ctx.device_context()); return framework::OpKernelType(framework::ToDataType(value.type()),
ctx.device_context());
}
}
// if input sparse vars are not initialized, use an default kernel type.
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) { } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
for (auto& x_var : x_vars) { for (auto& x_var : x_vars) {
auto& array = x_var->Get<framework::LoDTensorArray>(); auto& array = x_var->Get<framework::LoDTensorArray>();
......
...@@ -109,6 +109,12 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -109,6 +109,12 @@ class SumKernel : public framework::OpKernel<T> {
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim)); out_value->Resize(framework::make_ddim(in_dim));
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
return;
}
out_value->mutable_data<T>(context.GetPlace()); out_value->mutable_data<T>(context.GetPlace());
math::SelectedRowsAddTo<DeviceContext, T> functor; math::SelectedRowsAddTo<DeviceContext, T> functor;
...@@ -116,7 +122,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -116,7 +122,7 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0; int64_t offset = 0;
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
auto &sel_row = get_selected_row(i); auto &sel_row = get_selected_row(i);
if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) { if (sel_row.rows().size() == 0) {
continue; continue;
} }
PADDLE_ENFORCE_EQ(out->height(), sel_row.height()); PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
......
...@@ -60,8 +60,8 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -60,8 +60,8 @@ class TestSpliteSelectedRows(unittest.TestCase):
# expected output selected rows # expected output selected rows
expected_out0_rows = [0, 4] expected_out0_rows = [0, 4]
expected_out1_rows = [5, 7] expected_out1_rows = [0, 2]
expected_out4_rows = [20] expected_out4_rows = [0]
op = Operator( op = Operator(
"split_selected_rows", "split_selected_rows",
...@@ -101,7 +101,7 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -101,7 +101,7 @@ class TestSpliteSelectedRows(unittest.TestCase):
out0_grad_tensor.set(np_array, place) out0_grad_tensor.set(np_array, place)
out1_grad = scope.var("out1@GRAD").get_selected_rows() out1_grad = scope.var("out1@GRAD").get_selected_rows()
rows1 = [7, 5] rows1 = [2, 0]
out1_grad.set_rows(rows1) out1_grad.set_rows(rows1)
out1_grad.set_height(height) out1_grad.set_height(height)
out1_grad_tensor = out1_grad.get_tensor() out1_grad_tensor = out1_grad.get_tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册