diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 194df3e4a8b50700e2be01ce5ebca83b92501fb8..f9866411417ece784aab860c6f707b1a1fcd8528 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -91,7 +91,7 @@ std::vector Scope::LocalVarNames() const { return known_vars; } -void Scope::DeleteScope(Scope* scope) { +void Scope::DeleteScope(Scope* scope) const { std::unique_lock lock(mutex_); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index c8cb70549f1d131b66fa7c6eeb35f3b7151a9e7f..abc82e452d732638a2f7315022074850f299a7ea 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -63,7 +63,7 @@ class Scope { /// Find the scope or an ancestor scope that contains the given variable. const Scope* FindScope(const Variable* var) const; - void DeleteScope(Scope* scope); + void DeleteScope(Scope* scope) const; /// Drop all kids scopes belonged to this scope. void DropKids(); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 119e146e078e476b2768a8495ea63e468f952fd2..b92dc59491955ed319a2c2719244be93869a90c5 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -60,7 +60,7 @@ class RequestSend final : public RequestBase { framework::Scope* scope, ReceivedQueue* queue, const platform::DeviceContext* dev_ctx) : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new VariableResponse(scope, dev_ctx_)); + request_.reset(new VariableResponse(false, scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase { executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx) { - request_.reset(new VariableResponse(scope, dev_ctx_)); + request_.reset(new VariableResponse(false, scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 69fcffe9bc34006aef2e5a39227cf6d947e4615f..dbfd4e6a86a458694462e082d6c182f05eec7da7 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, framework::Variable** var) { - operators::detail::VariableResponse resp(scope, &ctx); + operators::detail::VariableResponse resp(false, scope, &ctx); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); *var = resp.GetVar(); } diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc index cb5f89583436b059ac4d6509dac9f2e3868561aa..fc9f60e3a51567839e901697f3bad027beeaa80a 100644 --- a/paddle/fluid/operators/detail/serde_test.cc +++ b/paddle/fluid/operators/detail/serde_test.cc @@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(&scope, &ctx); + operators::detail::VariableResponse resp(false, &scope, &ctx); EXPECT_EQ(resp.Parse(msg), 0); framework::Variable* var2 = resp.GetVar(); @@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { // deserialize zero-copy framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(&scope, &ctx); + operators::detail::VariableResponse resp(false, &scope, &ctx); if (from_type == 0) { EXPECT_EQ(resp.Parse(msg), 0); } else { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index c9d7fd6d1581f6f4182e9e3e0d633c13a3c336a5..9185c7670b86fbe4315743c21ab55bd3f490b69e 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -114,8 +114,7 @@ bool VariableResponse::CopyLodTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { - auto var = scope_->FindVar(meta_.varname()); - auto* tensor = var->GetMutable(); + auto* tensor = InitVar()->GetMutable(); tensor->Resize(dims); framework::LoD lod; @@ -151,8 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { - auto var = scope_->FindVar(meta_.varname()); - auto* slr = var->GetMutable(); + auto* slr = InitVar()->GetMutable(); slr->set_height(meta_.slr_height()); auto* tensor = slr->mutable_value(); tensor->Resize(dims); @@ -174,8 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData( bool VariableResponse::CopySelectRowsData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, int length) { - auto var = scope_->FindVar(meta_.varname()); - auto* slr = var->GetMutable(); + auto* slr = InitVar()->GetMutable(); slr->mutable_rows()->resize(length / framework::SizeOfType(typeid(int64_t))); // int64 int64_t* rows_data = slr->mutable_rows()->data(); diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index 93b0d3cfb4f7d7f336414361773f872d7b259482..8e88836af059150d91324232df720c80e4f684c1 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -36,11 +36,13 @@ namespace detail { class VariableResponse { public: - VariableResponse(const framework::Scope* scope, + VariableResponse(bool use_local_scope, const framework::Scope* scope, const platform::DeviceContext* dev_ctx) - : scope_(scope), dev_ctx_(dev_ctx) {} + : use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) { + local_scope_ = &scope->NewScope(); + } - virtual ~VariableResponse() {} + virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); } // return: // 0:ok. @@ -54,11 +56,23 @@ class VariableResponse { // other: number of error field. int Parse(const ::grpc::ByteBuffer& byte_buffer); + const framework::Scope& GetLocalScope() const { return *local_scope_; } + inline std::string Varname() { return meta_.varname(); } inline std::string OutVarname() { return meta_.out_varname(); } // should call parse first. - framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); } + framework::Variable* GetVar() { + return local_scope_->FindVar(meta_.varname()); + } + + framework::Variable* InitVar() { + if (use_local_scope_) { + return local_scope_->Var(meta_.varname()); + } else { + return scope_->FindVar(meta_.varname()); + } + } private: bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, @@ -73,7 +87,9 @@ class VariableResponse { const framework::DDim& dims, int length); private: + bool use_local_scope_ = false; const framework::Scope* scope_; + framework::Scope* local_scope_ = nullptr; const platform::DeviceContext* dev_ctx_; // only Skeleton sendrecv::VariableMessage meta_; diff --git a/paddle/fluid/operators/split_byref_op.h b/paddle/fluid/operators/split_byref_op.h index a3aad68ea736e223d3917607cca17f5cccfef630..fedd7218dd6cc9481e94a92a3820cafbe4157bd0 100644 --- a/paddle/fluid/operators/split_byref_op.h +++ b/paddle/fluid/operators/split_byref_op.h @@ -33,7 +33,7 @@ class SplitByrefOpKernel : public framework::OpKernel { // NOTE: no need to call mutable_data here to allocate memory. auto* out = outs[i]; VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0]; - *out = std::move(in->Slice(row_offset, row_offset + out->dims()[0])); + *out = in->Slice(row_offset, row_offset + out->dims()[0]); row_offset += out->dims()[0]; } }