From e5fe8935fb8aedd8b7ecd63136bba56e02dbd96c Mon Sep 17 00:00:00 2001 From: Yancey Date: Fri, 5 Jan 2018 13:41:34 +0800 Subject: [PATCH] send_recv variables (#7161) * send_recv variable * delete unused logs * fix ci failed * update * resize tensor before tensor copy * add selectedrows unit test * check rows --- paddle/framework/lod_tensor.cc | 7 +- paddle/framework/lod_tensor.h | 3 +- paddle/framework/lod_tensor_test.cc | 2 +- paddle/framework/selected_rows.cc | 6 +- paddle/framework/selected_rows.h | 3 +- paddle/framework/selected_rows_test.cc | 4 +- paddle/framework/tensor_util.h | 57 ++++++--- paddle/framework/tensor_util_test.cc | 6 +- paddle/framework/var_type.h | 4 +- paddle/operators/detail/recv_impl.cc | 19 +-- paddle/operators/detail/send_impl.cc | 18 +-- paddle/operators/detail/send_recv.proto | 18 +-- paddle/operators/detail/send_recv_impl.h | 57 ++++++++- paddle/operators/load_op.cc | 2 +- paddle/operators/recv_op.cc | 13 +- paddle/operators/send_recv_op_test.cc | 148 +++++++++++++++++------ 16 files changed, 249 insertions(+), 118 deletions(-) diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 92b3d7fccd..8f6944c241 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -217,9 +217,10 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor, SerializeToStream(os, static_cast(tensor), dev_ctx); } -void DeserializeFromStream(std::istream &is, LoDTensor *tensor) { +void DeserializeFromStream(std::istream &is, LoDTensor *tensor, + const platform::DeviceContext &dev_ctx) { { - // the 1st field, unit32_t version for SelectedRows + // the 1st field, unit32_t version for LoDTensor uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); @@ -240,7 +241,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) { } } // the 3st filed, Tensor - DeserializeFromStream(is, static_cast(tensor)); + DeserializeFromStream(is, static_cast(tensor), dev_ctx); } } // namespace framework diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 147db3ab08..d0b6befffe 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -208,7 +208,8 @@ void AppendLoD(LoD* lod, const LoD& lod_length); */ void SerializeToStream(std::ostream& os, const LoDTensor& tensor, const platform::DeviceContext& dev_ctx); -void DeserializeFromStream(std::istream& is, LoDTensor* tensor); +void DeserializeFromStream(std::istream& is, LoDTensor* tensor, + const platform::DeviceContext& dev_ctx); } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index 0747c8db53..0868c1f6e6 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -132,7 +132,7 @@ TEST_F(LoDTensorTester, SerializeAndDeserialize) { std::ostringstream oss; SerializeToStream(oss, lod_tensor_, cpu_ctx); std::istringstream iss(oss.str()); - DeserializeFromStream(iss, &dst_tensor); + DeserializeFromStream(iss, &dst_tensor, cpu_ctx); float* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < kLodTensorSize; ++i) { EXPECT_EQ(dst_ptr[i], i); diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc index 82adfa7123..3b3e60177a 100644 --- a/paddle/framework/selected_rows.cc +++ b/paddle/framework/selected_rows.cc @@ -37,8 +37,8 @@ void SerializeToStream(std::ostream& os, const SelectedRows& 
selected_rows, SerializeToStream(os, selected_rows.value(), dev_ctx); } -void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) { - auto tensor = *selected_rows->mutable_value(); +void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, + const platform::DeviceContext& dev_ctx) { { // the 1st field, unit32_t version for SelectedRows uint32_t version; @@ -62,7 +62,7 @@ void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) { selected_rows->set_height(height); } // the 4st field, tensor which contains the data - DeserializeFromStream(is, &tensor); + DeserializeFromStream(is, selected_rows->mutable_value(), dev_ctx); } } // namespace framework diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 699e392688..30d3dfc1e8 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -66,7 +66,8 @@ class SelectedRows { */ void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, const platform::DeviceContext& dev_ctx); -void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows); +void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows, + const platform::DeviceContext& dev_ctx); } // namespace framework } // namespace paddle diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc index 75487c4010..8ff3fb6a97 100644 --- a/paddle/framework/selected_rows_test.cc +++ b/paddle/framework/selected_rows_test.cc @@ -51,10 +51,12 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { SerializeToStream(oss, *selected_rows_, cpu_ctx); std::istringstream iss(oss.str()); - DeserializeFromStream(iss, &dst_tensor); + DeserializeFromStream(iss, &dst_tensor, cpu_ctx); ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows()); ASSERT_EQ(selected_rows_->height(), dst_tensor.height()); + ASSERT_EQ(selected_rows_->value().dims(), dst_tensor.value().dims()); + ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); } } // namespace framework diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 6a21f8db1e..5ac13cba4d 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -270,7 +270,23 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor, } } -inline void DeserializeFromStream(std::istream& is, Tensor* tensor) { +struct DeserializedDataFunctor { + DeserializedDataFunctor(void** buf, Tensor* tensor, + const platform::Place& place) + : buf_(buf), tensor_(tensor), place_(place) {} + + template + void operator()() { + *buf_ = tensor_->mutable_data(place_); + } + + void** buf_; + Tensor* tensor_; + platform::Place place_; +}; + +inline void DeserializeFromStream(std::istream& is, Tensor* tensor, + const platform::DeviceContext& dev_ctx) { uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); @@ -289,27 +305,28 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor) { dims.reserve(static_cast(desc.dims().size())); std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); tensor->Resize(framework::make_ddim(dims)); - void* buf; - platform::Place cpu = platform::CPUPlace(); - // TODO(Yancey1989): use VisiterDataType instead of DataType switch - switch (desc.data_type()) { - case proto::FP32: - buf = tensor->mutable_data(cpu); - break; - case proto::FP64: - buf = tensor->mutable_data(cpu); - break; - case proto::INT32: - buf = 
tensor->mutable_data(cpu); - break; - case proto::INT64: - buf = tensor->mutable_data(cpu); - break; - default: - PADDLE_THROW("DataType %d not supported", desc.data_type()); + auto ctx = platform::CPUDeviceContext(); + if (platform::is_gpu_place(dev_ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + Tensor cpu_tensor; + cpu_tensor.Resize(framework::make_ddim(dims)); + framework::VisitDataType( + desc.data_type(), + DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace())); + is.read(static_cast(buf), cpu_tensor.memory_size()); + auto cpu_place = new platform::CPUPlace(); + framework::CopyFrom(cpu_tensor, *cpu_place, dev_ctx, tensor); + delete cpu_place; +#else + PADDLE_THROW("Unexpected branch"); +#endif + } else { + framework::VisitDataType( + desc.data_type(), + DeserializedDataFunctor(&buf, tensor, ctx.GetPlace())); + is.read(static_cast(buf), tensor->memory_size()); } - is.read(static_cast(buf), tensor->memory_size()); } } diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc index 0dc5166fca..15cd2bd09c 100644 --- a/paddle/framework/tensor_util_test.cc +++ b/paddle/framework/tensor_util_test.cc @@ -270,11 +270,12 @@ TEST(Tensor, SerializeAndDeserialize) { SerializeToStream(oss, src_tensor, cpu_ctx); std::istringstream iss(oss.str()); - DeserializeFromStream(iss, &dst_tensor); + DeserializeFromStream(iss, &dst_tensor, cpu_ctx); int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < 5; ++i) { ASSERT_EQ(dst_ptr[i], array[i]); } + ASSERT_EQ(dst_tensor.dims(), src_tensor.dims()); delete place; } #ifdef PADDLE_WITH_CUDA @@ -292,13 +293,12 @@ TEST(Tensor, SerializeAndDeserialize) { SerializeToStream(oss, gpu_tensor, gpu_ctx); std::istringstream iss(oss.str()); - DeserializeFromStream(iss, &dst_tensor); + DeserializeFromStream(iss, &dst_tensor, gpu_ctx); int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < 6; ++i) { ASSERT_EQ(dst_ptr[i], array[i]); } - delete gpu_place; } #endif diff --git a/paddle/framework/var_type.h b/paddle/framework/var_type.h index 0e6ea8dc69..5b7a08a087 100644 --- a/paddle/framework/var_type.h +++ b/paddle/framework/var_type.h @@ -17,6 +17,8 @@ limitations under the License. */ #include "paddle/framework/lod_rank_table.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor_array.h" +#include "paddle/framework/selected_rows.h" +#include "paddle/framework/variable.h" namespace paddle { namespace framework { @@ -35,7 +37,7 @@ inline proto::VarDesc::VarType ToVarType(std::type_index type) { } template -inline void VisitVarType(const Variable& var, Visitor visitor) { +inline void VisitVarType(const framework::Variable& var, Visitor visitor) { switch (ToVarType(var.Type())) { case proto::VarDesc_VarType_LOD_TENSOR: visitor(var.Get()); diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc index b746f9df46..319404e56a 100644 --- a/paddle/operators/detail/recv_impl.cc +++ b/paddle/operators/detail/recv_impl.cc @@ -21,14 +21,9 @@ namespace detail { Status SendRecvServerImpl::SendVariable(ServerContext *context, const VariableMessage *in_var, VoidMessage *out_var) { - // TODO(typhoonzero): support different variable types. 
- std::istringstream iss(in_var->serialized()); - framework::LoDTensor t; - framework::DeserializeFromStream(iss, &t); - TensorWithName tensor_with_name = - std::make_pair(in_var->varname(), std::move(t)); - - var_recv_queue_.Push(std::move(tensor_with_name)); + MessageWithName msg_with_name = + std::make_pair(in_var->varname(), std::move(*in_var)); + var_recv_queue_.Push(std::move(msg_with_name)); return Status::OK; } @@ -37,14 +32,8 @@ Status SendRecvServerImpl::GetVariable(ServerContext *context, VariableMessage *out_var) { std::string get_var_name = in_var->varname(); auto *var = scope_->FindVar(get_var_name); - auto tensor = var->Get(); - std::ostringstream oss; - framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext()); - std::string *varname = out_var->mutable_varname(); - *varname = get_var_name; - std::string *serialized = out_var->mutable_serialized(); - *serialized = oss.str(); + SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var); return Status::OK; } diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc index a812fcf39b..ae85cf2cec 100644 --- a/paddle/operators/detail/send_impl.cc +++ b/paddle/operators/detail/send_impl.cc @@ -27,14 +27,8 @@ bool RPCClient::SendVariable(const framework::Scope& scope, auto ctx = platform::CPUDeviceContext(); auto* var = scope.FindVar(inname); PADDLE_ENFORCE(var); - // TODO(typhoonzero): support SelectedRows - PADDLE_ENFORCE(var->IsType(), - "Only support LoDTensor, %s has wrong type", inname); - const framework::LoDTensor& tensor = var->Get(); - std::ostringstream oss; - framework::SerializeToStream(oss, tensor, ctx); - msg.set_varname(inname); - msg.set_serialized(oss.str()); + SerializeToMessage(inname, var, ctx, &msg); + Status status = stub_->SendVariable(&context, msg, &out_msg); if (!status.ok()) { LOG(ERROR) << "gRPC error: " << status.error_message(); @@ -50,19 +44,15 @@ bool RPCClient::GetVariable(const framework::Scope& scope, call_msg.set_varname(outname); auto ctx = platform::CPUDeviceContext(); Status status = stub_->GetVariable(&context, call_msg, &ret_msg); + auto* outvar = scope.FindVar(outname); if (!status.ok()) { LOG(ERROR) << "gRPC error: " << status.error_message(); return false; } std::istringstream iss(ret_msg.serialized()); + DeserializeFromMessage(ret_msg, ctx, outvar); - framework::LoDTensor ret_tensor; - framework::DeserializeFromStream(iss, &ret_tensor); - auto* outvar = scope.FindVar(outname); - framework::LoDTensor* out_tensor = outvar->GetMutable(); - // FIXME(typhoonzero): do not copy. - framework::CopyFrom(ret_tensor, ctx.GetPlace(), ctx, out_tensor); return true; } diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index 95c8e70898..f141c755ce 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -1,7 +1,6 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under +the Apache License, Version 2.0 (the "License"); you may not use this file +except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -13,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ syntax = "proto3"; - package sendrecv; service SendRecvService { @@ -29,12 +27,18 @@ service SendRecvService { // VariableMessage is serialized paddle variable message. // It can be: -// Tensor // LoDTensor // SelectedRows +enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; +} + message VariableMessage { string varname = 1; - bytes serialized = 2; + // TODO(Yancey1989): reference framework::proto::VarDesc::VarType + VarType type = 2; + bytes serialized = 3; } message VoidMessage {} diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h index 47f730f7ae..1fe54f1f05 100644 --- a/paddle/operators/detail/send_recv_impl.h +++ b/paddle/operators/detail/send_recv_impl.h @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once -#include "paddle/framework/data_type.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/scope.h" #include "paddle/framework/selected_rows.h" +#include "paddle/framework/var_type.h" #include "paddle/operators/detail/simple_block_queue.h" #include "paddle/operators/detail/send_recv.grpc.pb.h" @@ -44,7 +44,7 @@ namespace paddle { namespace operators { namespace detail { -typedef std::pair TensorWithName; +typedef std::pair MessageWithName; class SendRecvServerImpl final : public SendRecvService::Service { public: @@ -60,13 +60,13 @@ class SendRecvServerImpl final : public SendRecvService::Service { void Done(); void SetScope(framework::Scope *scope) { scope_ = scope; }; - const TensorWithName Get() { return this->var_recv_queue_.Pop(); } + const MessageWithName Get() { return this->var_recv_queue_.Pop(); } - void Push(const TensorWithName &msg) { this->var_recv_queue_.Push(msg); } + void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } private: // received variable from RPC, operators fetch variable from this queue. 
- SimpleBlockQueue var_recv_queue_; + SimpleBlockQueue var_recv_queue_; framework::Scope *scope_; // condition of the sub program std::mutex mutex_; @@ -89,6 +89,53 @@ class RPCClient { std::unique_ptr stub_; }; +inline void SerializeToMessage(const std::string &name, + const framework::Variable *var, + const platform::DeviceContext &ctx, + VariableMessage *msg) { + msg->set_varname(name); + std::ostringstream oss; + switch (framework::ToVarType(var->Type())) { + case framework::proto::VarDesc_VarType_LOD_TENSOR: + msg->set_type(sendrecv::VarType::LOD_TENSOR); + framework::SerializeToStream(oss, var->Get(), ctx); + break; + case framework::proto::VarDesc_VarType_SELECTED_ROWS: + msg->set_type(sendrecv::VarType::SELECTED_ROWS); + framework::SerializeToStream(oss, var->Get(), + ctx); + break; + default: { + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + } + msg->set_serialized(oss.str()); +} + +inline void DeserializeFromMessage(const VariableMessage &msg, + const platform::DeviceContext &ctx, + framework::Variable *var) { + using namespace paddle::framework::proto; + std::istringstream iss(msg.serialized()); + switch (msg.type()) { + case sendrecv::VarType::LOD_TENSOR: + DeserializeFromStream(iss, var->GetMutable(), ctx); + break; + case sendrecv::VarType::SELECTED_ROWS: { + DeserializeFromStream(iss, var->GetMutable(), + ctx); + break; + } + default: { + PADDLE_THROW("Deserialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + } +} + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/operators/load_op.cc b/paddle/operators/load_op.cc index 08b972a233..7f551f101f 100644 --- a/paddle/operators/load_op.cc +++ b/paddle/operators/load_op.cc @@ -38,10 +38,10 @@ class LoadOp : public framework::OperatorBase { out_var_name); auto *tensor = out_var->GetMutable(); - DeserializeFromStream(fin, tensor); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); + DeserializeFromStream(fin, tensor, dev_ctx); if (platform::is_gpu_place(place)) { // copy CPU to GPU diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 322f8571cf..82fceb3da7 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -60,7 +60,7 @@ class RecvOp : public framework::OperatorBase { } void Stop() override { - detail::TensorWithName term_msg; + detail::MessageWithName term_msg; term_msg.first = LISTEN_TERMINATE_MESSAGE; rpc_service_->Push(term_msg); rpc_server_->Shutdown(); @@ -94,7 +94,7 @@ class RecvOp : public framework::OperatorBase { // the gradient arrives, just add suffix 0~n then average the gradient. for (size_t i = 0; i < param_count * trainer_count; ++i) { // blocking get one var from client. 
- const detail::TensorWithName &v = rpc_service_->Get(); + const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { exit_flag = true; @@ -121,11 +121,10 @@ class RecvOp : public framework::OperatorBase { } auto *var = recv_scope.Var(grad_var_name); - auto *tensor = var->GetMutable(); - // FIXME(typhoonzero): do not copy - platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); - auto &dev_ctx = *pool.Borrow(dev_place); - framework::CopyFrom(v.second, dev_place, dev_ctx, tensor); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + detail::DeserializeFromMessage(v.second, dev_ctx, var); } if (exit_flag) { break; diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index 108e2dec6b..fa94424bf9 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -20,22 +20,27 @@ limitations under the License. */ #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/program_desc.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" #include "paddle/string/printf.h" USE_NO_KERNEL_OP(send); USE_NO_KERNEL_OP(recv); USE_OP(sum); +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + // global for simplicity. -std::unique_ptr recv_op; +std::unique_ptr recv_op; -void InitTensorsInScope(paddle::framework::Scope &scope, - paddle::platform::CPUPlace &place) { - paddle::platform::CPUDeviceContext ctx(place); +void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { + p::CPUDeviceContext ctx(place); for (int i = 0; i < 2; ++i) { auto var_name = paddle::string::Sprintf("x%d", i); auto var = scope.Var(var_name); - auto tensor = var->GetMutable(); + auto tensor = var->GetMutable(); tensor->Resize({10, 10}); float *expect = tensor->mutable_data(place); for (int64_t i = 0; i < tensor->numel(); ++i) { @@ -44,21 +49,53 @@ void InitTensorsInScope(paddle::framework::Scope &scope, } auto out_var = scope.Var("Out"); - auto out_tensor = out_var->GetMutable(); + auto out_tensor = out_var->GetMutable(); out_tensor->Resize({10, 10}); out_tensor->mutable_data(place); // allocate } -void AddOp(const std::string &type, - const paddle::framework::VariableNameMap &inputs, - const paddle::framework::VariableNameMap &outputs, - paddle::framework::AttributeMap attrs, - paddle::framework::BlockDesc *block) { +void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) { + p::CPUDeviceContext ctx(place); + int64_t height = 10; + int64_t row_numel = 10; + m::SetConstant set_one; + // init x0 + std::vector rows0{0, 4, 7}; + auto x0_var = scope.Var("x0"); + auto x0 = x0_var->GetMutable(); + x0->set_rows(rows0); + x0->set_height(height); + auto x0_value = x0->mutable_value(); + x0_value->mutable_data( + f::make_ddim({static_cast(rows0.size()), row_numel}), place); + set_one(ctx, x0_value, 1.0); + + // init x1 + std::vector rows1{2, 9}; + auto x1_var = scope.Var("x1"); + auto x1 = x1_var->GetMutable(); + x1->set_rows(rows1); + x1->set_height(height); + auto x1_value = x1->mutable_value(); + x1_value->mutable_data( + f::make_ddim({static_cast(rows1.size()), row_numel}), place); + set_one(ctx, x1_value, 1.0); + + auto out_var = scope.Var("Out"); + auto out = out_var->GetMutable(); + auto out_value = out->mutable_value(); + 
out->set_height(height); + out_value->mutable_data(f::make_ddim({5, 10}), place); +} + +void AddOp(const std::string &type, const f::VariableNameMap &inputs, + const f::VariableNameMap &outputs, f::AttributeMap attrs, + f::BlockDesc *block) { // insert output for (auto kv : outputs) { for (auto v : kv.second) { auto var = block->Var(v); - var->SetDataType(paddle::framework::proto::DataType::FP32); + var->SetDataType(f::proto::DataType::FP32); } } @@ -74,58 +111,99 @@ void AddOp(const std::string &type, op->SetAttrMap(attrs); } -void StartServerNet() { - paddle::framework::Scope scope; - paddle::platform::CPUPlace place; - InitTensorsInScope(scope, place); +void StartServerNet(bool is_sparse) { + f::Scope scope; + p::CPUPlace place; + if (is_sparse) { + InitSelectedRowsInScope(scope, place); + } else { + InitTensorsInScope(scope, place); + } // sub program run in recv_op, for simple test we use sum - paddle::framework::ProgramDesc program; - paddle::framework::BlockDesc *block = program.MutableBlock(0); + f::ProgramDesc program; + f::BlockDesc *block = program.MutableBlock(0); // X for server side tensors, RX for received tensers, must be of same shape. - AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"x0"}}}, {}, block); + AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, block); - paddle::framework::AttributeMap attrs; + f::AttributeMap attrs; attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); - attrs.insert({"ParamList", std::vector({"x0"})}); + attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); std::string program_proto; PADDLE_ENFORCE(program.Proto()->SerializeToString(&program_proto)); attrs.insert({"OptimizeProgram", program_proto}); - recv_op = paddle::framework::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, - {}, attrs); + recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs); recv_op->Run(scope, place); } -TEST(SendRecvOp, CPU) { - std::thread server_thread(StartServerNet); - sleep(5); // wait server to start +TEST(SendRecvOp, CPUDense) { + std::thread server_thread(StartServerNet, false); + sleep(3); // wait server to start // local net - paddle::framework::Scope scope; - paddle::platform::CPUPlace place; + f::Scope scope; + p::CPUPlace place; InitTensorsInScope(scope, place); - paddle::framework::AttributeMap attrs; + f::AttributeMap attrs; attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); - auto send_op = paddle::framework::OpRegistry::CreateOp( - "send", {{"X", {"x1"}}}, {{"Out", {"x0"}}}, attrs); + auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}}, attrs); send_op->Run(scope, place); auto in_var = scope.Var("x1"); - auto tensor = in_var->GetMutable(); + auto tensor = in_var->GetMutable(); float *expected = tensor->data(); - auto out_var = scope.Var("x0"); - auto target = out_var->GetMutable(); + auto out_var = scope.Var("Out"); + auto target = out_var->GetMutable(); // x1 * 2 == x0 EXPECT_NE(target->memory_size(), size_t(0)); float *actual = target->data(); for (int64_t i = 0; i < target->numel(); ++i) { EXPECT_EQ(expected[i] * 2, actual[i]); } + recv_op->Stop(); + server_thread.join(); + recv_op.reset(nullptr); +} +TEST(SendRecvOp, CPUSparse) { + std::thread server_thread(StartServerNet, true); + sleep(3); // wait server to start + // local net + f::Scope scope; + p::CPUPlace place; + p::CPUDeviceContext ctx(place); + InitSelectedRowsInScope(scope, place); + f::AttributeMap attrs; + 
attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); + attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); + auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}}, + {{"Out", {"Out"}}}, attrs); + send_op->Run(scope, place); + + auto x0 = scope.Var("x0")->GetMutable(); + auto x1 = scope.Var("x1")->GetMutable(); + auto out = scope.Var("Out")->GetMutable(); + auto actual = out->mutable_value(); + + std::unique_ptr expect{new f::SelectedRows()}; + auto expect_value = expect->mutable_value(); + expect_value->mutable_data(f::make_ddim({5, 10}), place); + + m::SelectedRowsAdd add_functor; + add_functor(ctx, *x0, *x1, expect.get()); + + EXPECT_EQ(actual->numel(), expect_value->numel()); + EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size()); + + for (int64_t i = 0; i < expect_value->numel(); ++i) { + EXPECT_EQ(expect_value->mutable_data(place)[i], + actual->mutable_data(place)[i]); + } recv_op->Stop(); server_thread.join(); - // recv_op.reset(); + recv_op.reset(); } -- GitLab