提交 450be963 编写于 作者: T typhoonzero

fix sparse errors

上级 9a9d67da
......@@ -204,7 +204,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}
grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
......@@ -155,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().capacity() * framework::SizeOfType(typeid(int64_t));
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
......@@ -35,6 +36,12 @@ namespace detail {
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
......
......@@ -43,12 +43,11 @@ void RunSerdeTestSelectedRows(platform::Place place) {
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10;
int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
......@@ -65,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
// deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
......@@ -75,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
for (int i = 0; i < 564; ++i) {
EXPECT_EQ(rows_data[i], i);
}
// deserialize zero-copy
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
......@@ -105,8 +107,9 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
for (int i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
}
......
......@@ -68,8 +68,6 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
VLOG(3) << "copy raw " << size_to_write
<< " bytes, written: " << total_written << ", length: " << length;
memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
......@@ -152,6 +150,10 @@ bool VariableResponse::CopySelectRowsTensorData(
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
tensor->numel(),
length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
......@@ -168,7 +170,8 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / 8); // int64
slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
......
......@@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase {
// and this will still work.
std::vector<std::future<void>> fs;
double ts = detail::GetTimestamp();
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(
......@@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
}
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
......
......@@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
......@@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) {
VLOG(2) << "batch barrier, ep: " << ep;
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册