diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index b92dc59491955ed319a2c2719244be93869a90c5..119e146e078e476b2768a8495ea63e468f952fd2 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -60,7 +60,7 @@ class RequestSend final : public RequestBase { framework::Scope* scope, ReceivedQueue* queue, const platform::DeviceContext* dev_ctx) : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new VariableResponse(false, scope, dev_ctx_)); + request_.reset(new VariableResponse(scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase { executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx) { - request_.reset(new VariableResponse(false, scope, dev_ctx_)); + request_.reset(new VariableResponse(scope, dev_ctx_)); int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index dbfd4e6a86a458694462e082d6c182f05eec7da7..69fcffe9bc34006aef2e5a39227cf6d947e4615f 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const platform::DeviceContext& ctx, const framework::Scope* scope, framework::Variable** var) { - operators::detail::VariableResponse resp(false, scope, &ctx); + operators::detail::VariableResponse resp(scope, &ctx); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); *var = resp.GetVar(); } diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc index fc9f60e3a51567839e901697f3bad027beeaa80a..221d2f4c5b30aef022a5d6b54cd657d1dec1f5a2 100644 --- a/paddle/fluid/operators/detail/serde_test.cc +++ b/paddle/fluid/operators/detail/serde_test.cc @@ -51,7 +51,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ::grpc::ByteBuffer msg; operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); - EXPECT_GT(msg.Length(), 0); + EXPECT_GT(msg.Length(), static_cast(0)); // deserialize std::vector<::grpc::Slice> slices; @@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(false, &scope, &ctx); + operators::detail::VariableResponse resp(&scope, &ctx); EXPECT_EQ(resp.Parse(msg), 0); framework::Variable* var2 = resp.GetVar(); @@ -129,7 +129,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ::grpc::ByteBuffer msg; operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); - EXPECT_GT(msg.Length(), 0); + EXPECT_GT(msg.Length(), static_cast(0)); // deserialize std::vector<::grpc::Slice> slices; @@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { // deserialize zero-copy framework::Scope scope; scope.Var("myvar"); - operators::detail::VariableResponse resp(false, &scope, &ctx); + operators::detail::VariableResponse resp(&scope, &ctx); if (from_type == 0) { EXPECT_EQ(resp.Parse(msg), 0); } else { diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 9185c7670b86fbe4315743c21ab55bd3f490b69e..fbef8d02a4d765052fccf3792ebe0373d46b1ef6 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { - auto* tensor = InitVar()->GetMutable(); + auto* tensor = GetVar()->GetMutable(); tensor->Resize(dims); framework::LoD lod; @@ -150,7 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, const framework::DDim& dims, int length) { - auto* slr = InitVar()->GetMutable(); + auto* slr = GetVar()->GetMutable(); slr->set_height(meta_.slr_height()); auto* tensor = slr->mutable_value(); tensor->Resize(dims); @@ -172,7 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData( bool VariableResponse::CopySelectRowsData( ::google::protobuf::io::CodedInputStream* input, const platform::DeviceContext& ctx, int length) { - auto* slr = InitVar()->GetMutable(); + auto* slr = GetVar()->GetMutable(); slr->mutable_rows()->resize(length / framework::SizeOfType(typeid(int64_t))); // int64 int64_t* rows_data = slr->mutable_rows()->data(); diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index 8a1cab61cedb7fbd4f24333aa7b5b07a7084d173..3018a5c4af876828380ff4c1cbfdaafa8a2057e1 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -36,13 +36,18 @@ namespace detail { class VariableResponse { public: - VariableResponse(bool use_local_scope, const framework::Scope* scope, - const platform::DeviceContext* dev_ctx) - : use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) { - local_scope_ = &scope->NewScope(); + VariableResponse(const framework::Scope* scope, + const platform::DeviceContext* dev_ctx, + bool create_scope = false) + : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { + if (create_scope) { + local_scope_ = &scope->NewScope(); + } } - virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); } + virtual ~VariableResponse() { + if (create_scope_) scope_->DeleteScope(local_scope_); + } // return: // 0:ok. @@ -63,17 +68,10 @@ class VariableResponse { // should call parse first. framework::Variable* GetVar() { - return local_scope_->FindVar(meta_.varname()); - } - - framework::Variable* InitVar() { - if (use_local_scope_) { - bool has_var = (scope_->FindVar(meta_.varname()) != nullptr); - PADDLE_ENFORCE(has_var); + if (create_scope_) { return local_scope_->Var(meta_.varname()); - } else { - return scope_->FindVar(meta_.varname()); } + return scope_->FindVar(meta_.varname()); } private: @@ -89,10 +87,10 @@ class VariableResponse { const framework::DDim& dims, int length); private: - bool use_local_scope_ = false; const framework::Scope* scope_; - framework::Scope* local_scope_ = nullptr; const platform::DeviceContext* dev_ctx_; + bool create_scope_ = false; + framework::Scope* local_scope_ = nullptr; // only Skeleton sendrecv::VariableMessage meta_; };