提交 9861a92f 编写于 作者: Q Qiao Longfei

change the return type of NewTempScope to unique ptr test=develop

上级 fb6cc3a1
...@@ -59,7 +59,9 @@ Scope& Scope::NewScope() const { ...@@ -59,7 +59,9 @@ Scope& Scope::NewScope() const {
return *child; return *child;
} }
Scope* Scope::NewTmpScope() const { return new Scope(this); } std::unique_ptr<Scope> Scope::NewTmpScope() const {
return std::unique_ptr<Scope>(new Scope(this));
}
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
......
...@@ -54,9 +54,7 @@ class Scope { ...@@ -54,9 +54,7 @@ class Scope {
/// Create a sub-scope for current scope but do not record it in the kids to /// Create a sub-scope for current scope but do not record it in the kids to
/// avoid performance problems. /// avoid performance problems.
/// Note!!! You should delete the result pointer yourself to avoid memory std::unique_ptr<Scope> NewTmpScope() const;
/// leak!
Scope* NewTmpScope() const;
/// Create a variable with given name if it doesn't exist. /// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
......
...@@ -160,7 +160,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -160,7 +160,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<int64_t>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope) { const framework::Scope& scope) {
framework::Scope* local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace()); auto& cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -206,7 +206,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -206,7 +206,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
auto splited_ids = SplitIds(ids_vector, height_sections); auto splited_ids = SplitIds(ids_vector, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope); local_scope.get());
// create output var in local scope // create output var in local scope
for (auto& name : out_var_names) { for (auto& name : out_var_names) {
...@@ -215,12 +215,12 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -215,12 +215,12 @@ void prefetch(const std::string& id_name, const std::string& out_name,
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) { for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope, in_var_names[i])) { if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i] VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back"; << " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar( rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, *local_scope, in_var_names[i], out_var_names[i], epmap[i], cpu_ctx, *local_scope.get(), in_var_names[i],
table_names[i])); out_var_names[i], table_names[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
} }
...@@ -232,8 +232,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -232,8 +232,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids, out_var_names, height_sections, splited_ids,
context, local_scope, &actual_ctx); context, local_scope.get(), &actual_ctx);
delete local_scope;
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -42,7 +42,7 @@ template <typename T> ...@@ -42,7 +42,7 @@ template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) { const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in"; VLOG(3) << "ParameterRecv in";
framework::Scope *local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto &cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -64,7 +64,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -64,7 +64,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
recved_tensors.push_back(t); recved_tensors.push_back(t);
VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope, recv_var_name, *local_scope.get(), recv_var_name,
recv_var_name)); recv_var_name));
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
...@@ -93,7 +93,6 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, ...@@ -93,7 +93,6 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel()); PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
} }
delete local_scope;
VLOG(3) << "ParameterRecv out"; VLOG(3) << "ParameterRecv out";
} }
......
...@@ -40,7 +40,7 @@ using DDim = framework::DDim; ...@@ -40,7 +40,7 @@ using DDim = framework::DDim;
template <typename T> template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope, bool sync) { const framework::Scope &scope, bool sync) {
framework::Scope *local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto &cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -150,10 +150,10 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -150,10 +150,10 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i]; auto &send_var_name = rpc_ctx.splited_var_names[i];
auto &endpoint = rpc_ctx.epmap[i]; auto &endpoint = rpc_ctx.epmap[i];
if (NeedSend(*local_scope, send_var_name)) { if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint; VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(endpoint, cpu_ctx, *local_scope, rets.push_back(rpc_client->AsyncSendVar(
send_var_name)); endpoint, cpu_ctx, *local_scope.get(), send_var_name));
} else { } else {
VLOG(3) << "don't send non-initialized variable: " VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i]; << rpc_ctx.splited_var_names[i];
...@@ -165,8 +165,6 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -165,8 +165,6 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
} }
} }
delete local_scope;
} }
template struct ParameterSend<float>; template struct ParameterSend<float>;
......
...@@ -60,7 +60,7 @@ class VariableResponse { ...@@ -60,7 +60,7 @@ class VariableResponse {
bool create_scope = false) bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) { if (create_scope) {
local_scope_ = scope->NewTmpScope(); local_scope_ = scope->NewTmpScope().release();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册