未验证 提交 d21ab2e2 编写于 作者: 武毅 提交者: GitHub

Merge pull request #9448 from typhoonzero/fix_dist_slr_height

fix dist train selected rows height missing
...@@ -204,7 +204,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -204,7 +204,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
} }
grpc::ChannelArguments args; grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max()); args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
...@@ -59,12 +59,12 @@ message VariableMessage { ...@@ -59,12 +59,12 @@ message VariableMessage {
// lod details: // lod details:
int64 lod_level = 5; int64 lod_level = 5;
repeated LodData lod = 6; repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data // tensor data
bytes serialized = 7; bytes serialized = 8;
// selected_rows data // selected_rows data
bytes rows = 8; bytes rows = 9;
} }
message VoidMessage {} message VoidMessage {}
message TestMessage { int64 test_1 = 1; }
...@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim); e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
} }
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -154,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -154,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e2((char*)buf, 128); ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t // NOTE: rows is of type int64_t
size_t rows_memory_size = size_t rows_memory_size =
slr->rows().capacity() * framework::SizeOfType(typeid(int64_t)); slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size()); slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size()); memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <sys/time.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -35,6 +36,12 @@ namespace detail { ...@@ -35,6 +36,12 @@ namespace detail {
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
// Returns the current wall-clock time in milliseconds since the Unix epoch.
// Used only for coarse-grained timing of RPC / block execution (VLOG output),
// so wall-clock (non-monotonic) time is acceptable here.
static int64_t GetTimestamp() {
  struct timeval tp;
  gettimeofday(&tp, NULL);
  // Widen tv_sec to 64 bits BEFORE multiplying: on platforms where time_t
  // is 32-bit, `tp.tv_sec * 1000` would overflow in time_t arithmetic
  // before the implicit conversion to int64_t in the return statement.
  return static_cast<int64_t>(tp.tv_sec) * 1000 + tp.tv_usec / 1000;
}
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
......
...@@ -40,14 +40,14 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -40,14 +40,14 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer // serialize var to ByteBuffer
framework::Variable var; framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>(); auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows(); auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10})); tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place); tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10; int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7); math::set_constant(ctx, tensor, 32.7);
rows->push_back(3); for (int i = 0; i < 564; ++i) rows->push_back(i);
rows->push_back(10);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
...@@ -64,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -64,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
sendrecv::VariableMessage varmsg; sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp)); EXPECT_TRUE(varmsg.ParseFromString(tmp));
// deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar"); EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1); EXPECT_EQ(varmsg.type(), 1);
...@@ -74,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -74,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) { for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7); EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
} }
EXPECT_EQ(rows_data[0], 3); for (int i = 0; i < 564; ++i) {
EXPECT_EQ(rows_data[1], 10); EXPECT_EQ(rows_data[i], i);
}
// deserialize zero-copy // deserialize zero-copy
// framework::Variable var2; // framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
...@@ -104,8 +107,10 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -104,8 +107,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) { for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
} }
EXPECT_EQ(rows_data2[0], 3); for (int i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[1], 10); EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
} }
void RunTestLodTensor(platform::Place place, int from_type = 0) { void RunTestLodTensor(platform::Place place, int from_type = 0) {
......
...@@ -147,8 +147,13 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -147,8 +147,13 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) { const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname()); auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
tensor->numel(),
length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
ctx.GetPlace(), ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type())); paddle::operators::detail::ToTypeIndex(meta_.data_type()));
...@@ -165,7 +170,8 @@ bool VariableResponse::CopySelectRowsData( ...@@ -165,7 +170,8 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) { const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname()); auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / 8); // int64 slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data(); int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily. // copy rows CPU data, GPU data will be copied lazily.
...@@ -348,6 +354,14 @@ int VariableResponse::Parse(Source* source) { ...@@ -348,6 +354,14 @@ int VariableResponse::Parse(Source* source) {
} }
break; break;
} }
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: { case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) && meta_.type() == sendrecv::LOD_TENSOR) &&
......
...@@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase {
// and this will still work. // and this will still work.
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
double ts = detail::GetTimestamp();
// block0 contains only listen_and_serv op, start run from block1. // block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) { for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back( fs.push_back(
...@@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
} }
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next // sum the input sparse variables which rows is empty at the next
......
...@@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase { ...@@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(2) << "sending " << ins[i] << " to " << epmap[i]; VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
...@@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase { ...@@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(2) << "batch barrier, ep: " << ep; VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册