diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 14e75e7b7b582d994b83d6c74ad9947135f6c449..3b7d61607301e685e67b5f4bc97fc837471e5722 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -128,7 +128,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     //
     // NOTE: DelayedOps have a lower priority. It will be scheduled after all
     // ready_ops have been performed.
-    if (ready_ops.empty() && allow_op_delay_) {
+    if (ready_ops.empty() && allow_op_delay_ && running_ops_ == 0) {
       run_all_ops(delayed_ops);
     } else {
       run_all_ops(ready_ops);
diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index be14de9f5ed796f9bd914713d55679fbe16ea01c..27ddb675009f0adea6cf8f7349f3b538f8dd2ba7 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -68,9 +68,9 @@ class RequestSend final : public RequestBase {
         queue_(queue),
         responder_(&ctx_) {
     if (sync_mode_) {
-      request_.reset(new VariableResponse(false, scope, dev_ctx_));
+      request_.reset(new VariableResponse(scope, dev_ctx_, false));
     } else {
-      request_.reset(new VariableResponse(true, scope, dev_ctx_));
+      request_.reset(new VariableResponse(scope, dev_ctx_, true));
     }
     int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
     service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
@@ -158,9 +158,9 @@ class RequestPrefetch final : public RequestBase {
         program_(program),
         prefetch_ctx_(prefetch_ctx) {
     if (sync_mode_) {
-      request_.reset(new VariableResponse(false, scope, dev_ctx_));
+      request_.reset(new VariableResponse(scope, dev_ctx_, false));
     } else {
-      request_.reset(new VariableResponse(true, scope, dev_ctx_));
+      request_.reset(new VariableResponse(scope, dev_ctx_, true));
     }
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
     service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index dbfd4e6a86a458694462e082d6c182f05eec7da7..69fcffe9bc34006aef2e5a39227cf6d947e4615f 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
                                const framework::Scope* scope,
                                framework::Variable** var) {
-  operators::detail::VariableResponse resp(false, scope, &ctx);
+  operators::detail::VariableResponse resp(scope, &ctx);
   PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
   *var = resp.GetVar();
 }
diff --git a/paddle/fluid/operators/detail/serde_test.cc b/paddle/fluid/operators/detail/serde_test.cc
index fc9f60e3a51567839e901697f3bad027beeaa80a..221d2f4c5b30aef022a5d6b54cd657d1dec1f5a2 100644
--- a/paddle/fluid/operators/detail/serde_test.cc
+++ b/paddle/fluid/operators/detail/serde_test.cc
@@ -51,7 +51,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {

   ::grpc::ByteBuffer msg;
   operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
-  EXPECT_GT(msg.Length(), 0);
+  EXPECT_GT(msg.Length(), static_cast<size_t>(0));

   // deserialize
   std::vector<::grpc::Slice> slices;
@@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
   framework::Scope scope;
   scope.Var("myvar");
-  operators::detail::VariableResponse resp(false, &scope, &ctx);
+  operators::detail::VariableResponse resp(&scope, &ctx);
   EXPECT_EQ(resp.Parse(msg), 0);

   framework::Variable* var2 = resp.GetVar();
@@ -129,7 +129,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {

   ::grpc::ByteBuffer msg;
   operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
-  EXPECT_GT(msg.Length(), 0);
+  EXPECT_GT(msg.Length(), static_cast<size_t>(0));

   // deserialize
   std::vector<::grpc::Slice> slices;
@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
   // deserialize zero-copy
   framework::Scope scope;
   scope.Var("myvar");
-  operators::detail::VariableResponse resp(false, &scope, &ctx);
+  operators::detail::VariableResponse resp(&scope, &ctx);
   if (from_type == 0) {
     EXPECT_EQ(resp.Parse(msg), 0);
   } else {
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 9185c7670b86fbe4315743c21ab55bd3f490b69e..fbef8d02a4d765052fccf3792ebe0373d46b1ef6 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
-  auto* tensor = InitVar()->GetMutable<framework::LoDTensor>();
+  auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
   tensor->Resize(dims);

   framework::LoD lod;
@@ -150,7 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, const framework::DDim& dims,
     int length) {
-  auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
+  auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
   slr->set_height(meta_.slr_height());
   auto* tensor = slr->mutable_value();
   tensor->Resize(dims);
@@ -172,7 +172,7 @@ bool VariableResponse::CopySelectRowsData(
     ::google::protobuf::io::CodedInputStream* input,
     const platform::DeviceContext& ctx, int length) {
-  auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
+  auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
   slr->mutable_rows()->resize(
       length / framework::SizeOfType(typeid(int64_t)));  // int64
   int64_t* rows_data = slr->mutable_rows()->data();
diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h
index 8a1cab61cedb7fbd4f24333aa7b5b07a7084d173..3018a5c4af876828380ff4c1cbfdaafa8a2057e1 100644
--- a/paddle/fluid/operators/detail/variable_response.h
+++ b/paddle/fluid/operators/detail/variable_response.h
@@ -36,13 +36,18 @@ namespace detail {

 class VariableResponse {
  public:
-  VariableResponse(bool use_local_scope, const framework::Scope* scope,
-                   const platform::DeviceContext* dev_ctx)
-      : use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) {
-    local_scope_ = &scope->NewScope();
+  VariableResponse(const framework::Scope* scope,
+                   const platform::DeviceContext* dev_ctx,
+                   bool create_scope = false)
+      : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
+    if (create_scope) {
+      local_scope_ = &scope->NewScope();
+    }
   }

-  virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); }
+  virtual ~VariableResponse() {
+    if (create_scope_) scope_->DeleteScope(local_scope_);
+  }

   // return:
   // 0:ok.
@@ -63,17 +68,10 @@ class VariableResponse {

   // should call parse first.
   framework::Variable* GetVar() {
-    return local_scope_->FindVar(meta_.varname());
-  }
-
-  framework::Variable* InitVar() {
-    if (use_local_scope_) {
-      bool has_var = (scope_->FindVar(meta_.varname()) != nullptr);
-      PADDLE_ENFORCE(has_var);
+    if (create_scope_) {
       return local_scope_->Var(meta_.varname());
-    } else {
-      return scope_->FindVar(meta_.varname());
     }
+    return scope_->FindVar(meta_.varname());
   }

  private:
@@ -89,10 +87,10 @@ class VariableResponse {
                              const framework::DDim& dims, int length);

  private:
-  bool use_local_scope_ = false;
   const framework::Scope* scope_;
-  framework::Scope* local_scope_ = nullptr;
   const platform::DeviceContext* dev_ctx_;
+  bool create_scope_ = false;
+  framework::Scope* local_scope_ = nullptr;
   // only Skeleton
   sendrecv::VariableMessage meta_;
 };
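
Below is a minimal usage sketch of the refactored VariableResponse API; it is illustrative only and not part of the change. The names ctx and msg stand in for a platform::DeviceContext and a serialized ::grpc::ByteBuffer, as in serde_test.cc above.

  // Default (create_scope = false): no child scope is allocated, so the
  // target variable must already exist in the caller's scope, because
  // GetVar() falls back to scope_->FindVar(meta_.varname()).
  framework::Scope scope;
  scope.Var("myvar");
  operators::detail::VariableResponse resp(&scope, &ctx);
  PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
  framework::Variable* var = resp.GetVar();

  // create_scope = true: the response creates a child scope via NewScope(),
  // GetVar() creates the variable inside it, and the destructor deletes the
  // child scope. This is the async-mode path used by RequestSend/RequestPrefetch.
  operators::detail::VariableResponse async_resp(&scope, &ctx, true);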