提交 9a9d67da 编写于 作者: T typhoonzero

fix dist train selected rows height missing

上级 eb62a778
...@@ -59,12 +59,12 @@ message VariableMessage { ...@@ -59,12 +59,12 @@ message VariableMessage {
// lod details: // lod details:
int64 lod_level = 5; int64 lod_level = 5;
repeated LodData lod = 6; repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data // tensor data
bytes serialized = 7; bytes serialized = 8;
// selected_rows data // selected_rows data
bytes rows = 8; bytes rows = 9;
} }
message VoidMessage {} message VoidMessage {}
message TestMessage { int64 test_1 = 1; }
...@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim); e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
} }
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -40,6 +40,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -40,6 +40,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer // serialize var to ByteBuffer
framework::Variable var; framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>(); auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows(); auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10})); tensor->Resize(framework::make_ddim({2, 10}));
...@@ -106,6 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -106,6 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
} }
EXPECT_EQ(rows_data2[0], 3); EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10); EXPECT_EQ(rows_data2[1], 10);
EXPECT_EQ(slr2->height(), 1000);
} }
void RunTestLodTensor(platform::Place place, int from_type = 0) { void RunTestLodTensor(platform::Place place, int from_type = 0) {
......
...@@ -68,6 +68,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -68,6 +68,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
if (total_written + size_to_write > length) { if (total_written + size_to_write > length) {
size_to_write = length - total_written; size_to_write = length - total_written;
} }
VLOG(3) << "copy raw " << size_to_write
<< " bytes, written: " << total_written << ", length: " << length;
memory::Copy(boost::get<platform::CUDAPlace>(place), memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write, reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream()); gpu_dev_ctx.stream());
...@@ -147,6 +149,7 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -147,6 +149,7 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) { const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname()); auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
...@@ -348,6 +351,14 @@ int VariableResponse::Parse(Source* source) { ...@@ -348,6 +351,14 @@ int VariableResponse::Parse(Source* source) {
} }
break; break;
} }
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: { case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) && meta_.type() == sendrecv::LOD_TENSOR) &&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册