提交 8023c6d7 编写于 作者: Y Yancey1989

Create sub socpe when it is necessary

上级 d89a3068
...@@ -60,7 +60,7 @@ class RequestSend final : public RequestBase { ...@@ -60,7 +60,7 @@ class RequestSend final : public RequestBase {
framework::Scope* scope, ReceivedQueue* queue, framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new VariableResponse(false, scope, dev_ctx_)); request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this); cq_, cq_, this);
...@@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase {
executor_(executor), executor_(executor),
program_(program), program_(program),
prefetch_ctx_(prefetch_ctx) { prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(false, scope, dev_ctx_)); request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this); cq_, cq_, this);
......
...@@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, ...@@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable** var) { framework::Variable** var) {
operators::detail::VariableResponse resp(false, scope, &ctx); operators::detail::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar(); *var = resp.GetVar();
} }
......
...@@ -51,7 +51,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -51,7 +51,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), 0); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
std::vector<::grpc::Slice> slices; std::vector<::grpc::Slice> slices;
...@@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(false, &scope, &ctx); operators::detail::VariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar(); framework::Variable* var2 = resp.GetVar();
...@@ -129,7 +129,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -129,7 +129,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), 0); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
std::vector<::grpc::Slice> slices; std::vector<::grpc::Slice> slices;
...@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy // deserialize zero-copy
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(false, &scope, &ctx); operators::detail::VariableResponse resp(&scope, &ctx);
if (from_type == 0) { if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
} else { } else {
......
...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData( ...@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input, ::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims, const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) { int length) {
auto* tensor = InitVar()->GetMutable<framework::LoDTensor>(); auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
tensor->Resize(dims); tensor->Resize(dims);
framework::LoD lod; framework::LoD lod;
...@@ -150,7 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -150,7 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input, ::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const framework::DDim& dims, const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) { int length) {
auto* slr = InitVar()->GetMutable<framework::SelectedRows>(); auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
...@@ -172,7 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData( ...@@ -172,7 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData(
bool VariableResponse::CopySelectRowsData( bool VariableResponse::CopySelectRowsData(
::google::protobuf::io::CodedInputStream* input, ::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length) { const platform::DeviceContext& ctx, int length) {
auto* slr = InitVar()->GetMutable<framework::SelectedRows>(); auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64 framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data(); int64_t* rows_data = slr->mutable_rows()->data();
......
...@@ -36,13 +36,18 @@ namespace detail { ...@@ -36,13 +36,18 @@ namespace detail {
class VariableResponse { class VariableResponse {
public: public:
VariableResponse(bool use_local_scope, const framework::Scope* scope, VariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx) const platform::DeviceContext* dev_ctx,
: use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) { bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) {
local_scope_ = &scope->NewScope(); local_scope_ = &scope->NewScope();
} }
}
virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); } virtual ~VariableResponse() {
if (create_scope_) scope_->DeleteScope(local_scope_);
}
// return: // return:
// 0:ok. // 0:ok.
...@@ -63,17 +68,10 @@ class VariableResponse { ...@@ -63,17 +68,10 @@ class VariableResponse {
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
return local_scope_->FindVar(meta_.varname()); if (create_scope_) {
}
framework::Variable* InitVar() {
if (use_local_scope_) {
bool has_var = (scope_->FindVar(meta_.varname()) != nullptr);
PADDLE_ENFORCE(has_var);
return local_scope_->Var(meta_.varname()); return local_scope_->Var(meta_.varname());
} else {
return scope_->FindVar(meta_.varname());
} }
return scope_->FindVar(meta_.varname());
} }
private: private:
...@@ -89,10 +87,10 @@ class VariableResponse { ...@@ -89,10 +87,10 @@ class VariableResponse {
const framework::DDim& dims, int length); const framework::DDim& dims, int length);
private: private:
bool use_local_scope_ = false;
const framework::Scope* scope_; const framework::Scope* scope_;
framework::Scope* local_scope_ = nullptr;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
bool create_scope_ = false;
framework::Scope* local_scope_ = nullptr;
// only Skeleton // only Skeleton
sendrecv::VariableMessage meta_; sendrecv::VariableMessage meta_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册