提交 9a9d67da 编写于 作者: T typhoonzero

Fix missing SelectedRows height in distributed training

上级 eb62a778
......@@ -59,12 +59,12 @@ message VariableMessage {
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 7;
bytes serialized = 8;
// selected_rows data
bytes rows = 8;
bytes rows = 9;
}
// Empty message: declares no fields.
message VoidMessage {}
// Message with a single int64 field, test_1 (field number 1).
message TestMessage { int64 test_1 = 1; }
......@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
......
......@@ -40,6 +40,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
......@@ -106,6 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
EXPECT_EQ(slr2->height(), 1000);
}
void RunTestLodTensor(platform::Place place, int from_type = 0) {
......
......@@ -68,6 +68,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
VLOG(3) << "copy raw " << size_to_write
<< " bytes, written: " << total_written << ", length: " << length;
memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
......@@ -147,6 +149,7 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
......@@ -348,6 +351,14 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册