From 9a9d67dac28c362b6b2e86ffeec7c68fa1704d01 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 28 Mar 2018 16:47:06 +0800 Subject: [PATCH] fix dist train selected rows height missing --- paddle/fluid/operators/detail/send_recv.proto | 8 ++++---- paddle/fluid/operators/detail/sendrecvop_utils.cc | 1 + paddle/fluid/operators/detail/test_serde.cc | 2 ++ paddle/fluid/operators/detail/variable_response.cc | 11 +++++++++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 598aaa4c51..2d33f026e4 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -59,12 +59,12 @@ message VariableMessage { // lod details: int64 lod_level = 5; repeated LodData lod = 6; + // selected_rows height, aka. original dim0 + int64 slr_height = 7; // tensor data - bytes serialized = 7; + bytes serialized = 8; // selected_rows data - bytes rows = 8; + bytes rows = 9; } message VoidMessage {} - -message TestMessage { int64 test_1 = 1; } diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index d7bbf79c50..f318f8ac28 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, e.WriteUint64(VarMsg::kDimsFieldNumber, dim); } e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); + e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height()); auto* tensor = slr->mutable_value(); if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc index e646c894d1..e9e2dc84ad 100644 --- a/paddle/fluid/operators/detail/test_serde.cc +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -40,6 +40,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { // serialize var to ByteBuffer framework::Variable var; auto* slr = var.GetMutable(); + slr->set_height(1000); auto* tensor = slr->mutable_value(); auto* rows = slr->mutable_rows(); tensor->Resize(framework::make_ddim({2, 10})); @@ -106,6 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { } EXPECT_EQ(rows_data2[0], 3); EXPECT_EQ(rows_data2[1], 10); + EXPECT_EQ(slr2->height(), 1000); } void RunTestLodTensor(platform::Place place, int from_type = 0) { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index bdda570343..862fd26b54 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -68,6 +68,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, if (total_written + size_to_write > length) { size_to_write = length - total_written; } + VLOG(3) << "copy raw " << size_to_write + << " bytes, written: " << total_written << ", length: " << length; memory::Copy(boost::get(place), reinterpret_cast(p), cpu, data, size_to_write, gpu_dev_ctx.stream()); @@ -147,6 +149,7 @@ bool VariableResponse::CopySelectRowsTensorData( const platform::DeviceContext& ctx, framework::DDim& dims, int length) { auto var = scope_->FindVar(meta_.varname()); auto* slr = var->GetMutable(); + slr->set_height(meta_.slr_height()); auto* tensor = slr->mutable_value(); tensor->Resize(dims); void* tensor_data = tensor->mutable_data( @@ -348,6 +351,14 @@ int VariableResponse::Parse(Source* source) { } break; } + case sendrecv::VariableMessage::kSlrHeightFieldNumber: { + uint64_t v = 0; + if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) { + return tag; + } + meta_.set_slr_height(static_cast(v)); + break; + } case sendrecv::VariableMessage::kSerializedFieldNumber: { PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || meta_.type() == sendrecv::LOD_TENSOR) && -- GitLab