From 0ffd33d30e2361a0f766b7f39824d3e7156f3453 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Thu, 19 Apr 2018 19:58:27 +0800 Subject: [PATCH] VariableResponse support deserialize var into local scope --- paddle/fluid/framework/scope.cc | 2 +- paddle/fluid/framework/scope.h | 2 +- paddle/fluid/operators/detail/grpc_server.cc | 4 ++-- .../operators/detail/sendrecvop_utils.cc | 2 +- paddle/fluid/operators/detail/serde_test.cc | 4 ++-- .../operators/detail/variable_response.cc | 9 +++---- .../operators/detail/variable_response.h | 24 +++++++++++++++---- paddle/fluid/operators/split_byref_op.h | 2 +- 8 files changed, 31 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 194df3e4a8..f986641141 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -91,7 +91,7 @@ std::vector Scope::LocalVarNames() const { return known_vars; } -void Scope::DeleteScope(Scope* scope) { +void Scope::DeleteScope(Scope* scope) const { std::unique_lock lock(mutex_); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index c8cb70549f..abc82e452d 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -63,7 +63,7 @@ class Scope { /// Find the scope or an ancestor scope that contains the given variable. const Scope* FindScope(const Variable* var) const; - void DeleteScope(Scope* scope); + void DeleteScope(Scope* scope) const; /// Drop all kids scopes belonged to this scope. void DropKids(); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 119e146e07..b92dc59491 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -60,7 +60,7 @@ class RequestSend final : public RequestBase { framework::Scope* scope, ReceivedQueue* queue, const platform::DeviceContext* dev_ctx) : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new VariableResponse(scope, dev_ctx_)); + request_.reset(new VariableResponse(false, scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase { executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx) { - request_.reset(new VariableResponse(scope, dev_ctx_)); + request_.reset(new VariableResponse(false, scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 69fcffe9bc..dbfd4e6a86 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, framework::Variable** var) { - operators::detail::VariableResponse resp(scope, &ctx); + operators::detail::VariableResponse resp(false, scope, &ctx); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); *var = resp.GetVar(); } diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc index cb5f895834..fc9f60e3a5 100644 --- a/paddle/fluid/operators/detail/serde_test.cc +++ b/paddle/fluid/operators/detail/serde_test.cc @@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(&scope, &ctx); + operators::detail::VariableResponse resp(false, &scope, &ctx); EXPECT_EQ(resp.Parse(msg), 0); framework::Variable* var2 = resp.GetVar(); @@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { // deserialize zero-copy framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(&scope, &ctx); + operators::detail::VariableResponse resp(false, &scope, &ctx); if (from_type == 0) { EXPECT_EQ(resp.Parse(msg), 0); } else { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index c9d7fd6d15..9185c7670b 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -114,8 +114,7 @@ bool VariableResponse::CopyLodTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { - auto var = scope_->FindVar(meta_.varname()); - auto* tensor = var->GetMutable(); + auto* tensor = InitVar()->GetMutable(); tensor->Resize(dims); framework::LoD lod; @@ -151,8 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { - auto var = scope_->FindVar(meta_.varname()); - auto* slr = var->GetMutable(); + auto* slr = InitVar()->GetMutable(); slr->set_height(meta_.slr_height()); auto* tensor = slr->mutable_value(); tensor->Resize(dims); @@ -174,8 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData( bool VariableResponse::CopySelectRowsData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, int length) { - auto var = scope_->FindVar(meta_.varname()); - auto* slr = var->GetMutable(); + auto* slr = InitVar()->GetMutable(); slr->mutable_rows()->resize(length / framework::SizeOfType(typeid(int64_t))); // int64 int64_t* rows_data = slr->mutable_rows()->data(); diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index 93b0d3cfb4..8e88836af0 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -36,11 +36,13 @@ namespace detail { class VariableResponse { public: - VariableResponse(const framework::Scope* scope, + VariableResponse(bool use_local_scope, const framework::Scope* scope, const platform::DeviceContext* dev_ctx) - : scope_(scope), dev_ctx_(dev_ctx) {} + : use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) { + local_scope_ = &scope->NewScope(); + } - virtual ~VariableResponse() {} + virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); } // return: // 0:ok. @@ -54,11 +56,23 @@ class VariableResponse { // other: number of error field. int Parse(const ::grpc::ByteBuffer& byte_buffer); + const framework::Scope& GetLocalScope() const { return *local_scope_; } + inline std::string Varname() { return meta_.varname(); } inline std::string OutVarname() { return meta_.out_varname(); } // should call parse first. - framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); } + framework::Variable* GetVar() { + return local_scope_->FindVar(meta_.varname()); + } + + framework::Variable* InitVar() { + if (use_local_scope_) { + return local_scope_->Var(meta_.varname()); + } else { + return scope_->FindVar(meta_.varname()); + } + } private: bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, @@ -73,7 +87,9 @@ class VariableResponse { const framework::DDim& dims, int length); private: + bool use_local_scope_ = false; const framework::Scope* scope_; + framework::Scope* local_scope_ = nullptr; const platform::DeviceContext* dev_ctx_; // only Skeleton sendrecv::VariableMessage meta_; diff --git a/paddle/fluid/operators/split_byref_op.h b/paddle/fluid/operators/split_byref_op.h index a3aad68ea7..fedd7218dd 100644 --- a/paddle/fluid/operators/split_byref_op.h +++ b/paddle/fluid/operators/split_byref_op.h @@ -33,7 +33,7 @@ class SplitByrefOpKernel : public framework::OpKernel { // NOTE: no need to call mutable_data here to allocate memory. auto* out = outs[i]; VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0]; - *out = std::move(in->Slice(row_offset, row_offset + out->dims()[0])); + *out = in->Slice(row_offset, row_offset + out->dims()[0]); row_offset += out->dims()[0]; } } -- GitLab