未验证 提交 d21ab2e2 编写于 作者: 武毅 提交者: GitHub

Merge pull request #9448 from typhoonzero/fix_dist_slr_height

fix dist train selected rows height missing
......@@ -204,7 +204,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}
grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
......@@ -59,12 +59,12 @@ message VariableMessage {
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 7;
bytes serialized = 8;
// selected_rows data
bytes rows = 8;
bytes rows = 9;
}
message VoidMessage {}
message TestMessage { int64 test_1 = 1; }
......@@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
......@@ -154,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().capacity() * framework::SizeOfType(typeid(int64_t));
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
......@@ -35,6 +36,12 @@ namespace detail {
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
......
......@@ -40,14 +40,14 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10;
int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
......@@ -64,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
// deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
......@@ -74,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
for (int i = 0; i < 564; ++i) {
EXPECT_EQ(rows_data[i], i);
}
// deserialize zero-copy
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
......@@ -104,8 +107,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
for (int i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
}
void RunTestLodTensor(platform::Place place, int from_type = 0) {
......
......@@ -147,8 +147,13 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
tensor->numel(),
length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
......@@ -165,7 +170,8 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / 8); // int64
slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
......@@ -348,6 +354,14 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
......
......@@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase {
// and this will still work.
std::vector<std::future<void>> fs;
double ts = detail::GetTimestamp();
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(
......@@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
}
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
......
......@@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
......@@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) {
VLOG(2) << "batch barrier, ep: " << ep;
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册