提交 9861a92f 编写于 作者: Q Qiao Longfei

change the return type of NewTempScope to unique ptr test=develop

上级 fb6cc3a1
......@@ -59,7 +59,9 @@ Scope& Scope::NewScope() const {
return *child;
}
Scope* Scope::NewTmpScope() const { return new Scope(this); }
std::unique_ptr<Scope> Scope::NewTmpScope() const {
return std::unique_ptr<Scope>(new Scope(this));
}
Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK
......
......@@ -54,9 +54,7 @@ class Scope {
/// Create a sub-scope for current scope but do not record it in the kids to
/// avoid performance problems.
/// Note!!! You should delete the result pointer yourself to avoid memory
/// leak!
Scope* NewTmpScope() const;
std::unique_ptr<Scope> NewTmpScope() const;
/// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable.
......
......@@ -160,7 +160,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
framework::Scope* local_scope = scope.NewTmpScope();
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace());
......@@ -206,7 +206,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
auto splited_ids = SplitIds(ids_vector, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope);
local_scope.get());
// create output var in local scope
for (auto& name : out_var_names) {
......@@ -215,12 +215,12 @@ void prefetch(const std::string& id_name, const std::string& out_name,
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope, in_var_names[i])) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, *local_scope, in_var_names[i], out_var_names[i],
table_names[i]));
epmap[i], cpu_ctx, *local_scope.get(), in_var_names[i],
out_var_names[i], table_names[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
......@@ -232,8 +232,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids,
context, local_scope, &actual_ctx);
delete local_scope;
context, local_scope.get(), &actual_ctx);
}
}; // namespace distributed
......
......@@ -42,7 +42,7 @@ template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in";
framework::Scope *local_scope = scope.NewTmpScope();
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
......@@ -64,7 +64,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
recved_tensors.push_back(t);
VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope, recv_var_name,
*local_scope.get(), recv_var_name,
recv_var_name));
}
for (size_t i = 0; i < rets.size(); i++) {
......@@ -93,7 +93,6 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
}
delete local_scope;
VLOG(3) << "ParameterRecv out";
}
......
......@@ -40,7 +40,7 @@ using DDim = framework::DDim;
template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope, bool sync) {
framework::Scope *local_scope = scope.NewTmpScope();
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
......@@ -150,10 +150,10 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
auto &endpoint = rpc_ctx.epmap[i];
if (NeedSend(*local_scope, send_var_name)) {
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope,
send_var_name));
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
......@@ -165,8 +165,6 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
}
}
delete local_scope;
}
template struct ParameterSend<float>;
......
......@@ -60,7 +60,7 @@ class VariableResponse {
bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) {
local_scope_ = scope->NewTmpScope();
local_scope_ = scope->NewTmpScope().release();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册