提交 97b143fb 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-cpu-broadcast

test=develop
...@@ -21,12 +21,13 @@ function(CheckCompilerCXX11Flag) ...@@ -21,12 +21,13 @@ function(CheckCompilerCXX11Flag)
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif() endif()
endif() endif()
endif() endif()
endfunction() endfunction()
CheckCompilerCXX11Flag() CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
# safe_set_flag # safe_set_flag
# #
# Set a compile flag only if compiler is support # Set a compile flag only if compiler is support
......
...@@ -158,18 +158,19 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor) ...@@ -158,18 +158,19 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
if(WITH_DISTRIBUTE) if(WITH_NGRAPH)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog set(NGRAPH_EXE_DEPS ngraph_engine)
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper) else()
set(NGRAPH_EXE_DEPS)
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") if(WITH_DISTRIBUTE)
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
if (WITH_NGRAPH) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper ngraph_engine)
else ()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif()
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
......
...@@ -54,8 +54,6 @@ cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph grap ...@@ -54,8 +54,6 @@ cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph grap
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass) cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info) cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
...@@ -67,13 +65,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he ...@@ -67,13 +65,11 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass inplace_op_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
if (WITH_GPU) if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif() endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph) cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
...@@ -209,8 +209,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -209,8 +209,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
new std::vector<OpDesc *>(main_program.Block(0).AllOps()); new std::vector<OpDesc *>(main_program.Block(0).AllOps());
graph->Set<const std::vector<OpDesc *>>(kAllOpDescs, graph->Set<const std::vector<OpDesc *>>(kAllOpDescs,
all_op_descs); // take ownership all_op_descs); // take ownership
graph->Set<GraphNodePool>(kGraphNodePool,
new GraphNodePool); // take ownership
pass->Erase(kAllOpDescs); pass->Erase(kAllOpDescs);
pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, all_op_descs); pass->SetNotOwned<const std::vector<OpDesc *>>(kAllOpDescs, all_op_descs);
......
...@@ -77,9 +77,6 @@ struct BuildStrategy { ...@@ -77,9 +77,6 @@ struct BuildStrategy {
bool fuse_relu_depthwise_conv_{false}; bool fuse_relu_depthwise_conv_{false};
bool memory_optimize_{false}; bool memory_optimize_{false};
bool memory_early_delete_{false};
// TODO(dzhwinter): // TODO(dzhwinter):
// make enable_inplace, memory_optimize_ // make enable_inplace, memory_optimize_
// memory_early_delete_ true by default // memory_early_delete_ true by default
......
...@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var, ...@@ -171,16 +171,15 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
} }
} }
const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var, const NodeSwapQueue InplacePass::TryInplaceModifyVar(
const std::string& cache_var, const std::string& var, const std::string& cache_var, const size_t& idx,
const size_t& idx, ir::Graph* graph) const {
ir::Graph* graph) const {
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
var_nodes_[var].at(0)->Var() != nullptr); var_nodes_[var].at(0)->Var() != nullptr);
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
var_desc->SetName(cache_var); var_desc->SetName(cache_var);
SSANodePair swap_nodes; NodeSwapQueue swap_nodes;
for (size_t i = idx; i < view_.AllOps().size(); ++i) { for (size_t i = idx; i < view_.AllOps().size(); ++i) {
auto* op = view_.AllOps()[i]; auto* op = view_.AllOps()[i];
...@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var, ...@@ -230,7 +229,7 @@ const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
return swap_nodes; return swap_nodes;
} }
void InplacePass::CommitModify(const SSANodePair& swap_nodes, void InplacePass::CommitModify(const NodeSwapQueue& swap_nodes,
ir::Graph* graph) const { ir::Graph* graph) const {
for (auto& pair : swap_nodes) { for (auto& pair : swap_nodes) {
auto *node = pair.first, *cache_node = pair.second; auto *node = pair.first, *cache_node = pair.second;
...@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes, ...@@ -245,7 +244,7 @@ void InplacePass::CommitModify(const SSANodePair& swap_nodes,
} }
} }
void InplacePass::WithdrawModify(const SSANodePair& nodes, void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
ir::Graph* graph) const { ir::Graph* graph) const {
for (auto& pair : nodes) { for (auto& pair : nodes) {
auto *node = pair.first, *cache_node = pair.second; auto *node = pair.first, *cache_node = pair.second;
......
...@@ -56,7 +56,8 @@ class GraphView { ...@@ -56,7 +56,8 @@ class GraphView {
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_; std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
}; };
typedef std::vector<std::pair<ir::Node*, ir::Node*>> SSANodePair; // swap pairs in sequence
typedef std::vector<std::pair<ir::Node*, ir::Node*>> NodeSwapQueue;
class InplacePass : public ir::Pass { class InplacePass : public ir::Pass {
public: public:
InplacePass(); InplacePass();
...@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass { ...@@ -68,14 +69,14 @@ class InplacePass : public ir::Pass {
void InitSSAGraphNodes() const; void InitSSAGraphNodes() const;
private: private:
const SSANodePair TryInplaceModifyVar(const std::string& var, const NodeSwapQueue TryInplaceModifyVar(const std::string& var,
const std::string& cache_var, const std::string& cache_var,
const size_t& idx, const size_t& idx,
ir::Graph* graph) const; ir::Graph* graph) const;
void CommitModify(const SSANodePair&, ir::Graph* graph) const; void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const;
void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const; void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const;
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var, void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
const size_t& idx) const; const size_t& idx) const;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_early_delete_pass.h"
#include <queue>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) {
std::queue<VarHandleBase*> queue;
queue.push(var_in);
do {
auto* var = queue.front();
queue.pop();
for (auto* op : var->PendingOps()) {
auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place()) {
return compute_op;
}
for (auto* out_var : op->Outputs()) {
queue.push(out_var);
}
}
} while (!queue.empty());
return nullptr;
}
std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto& graph_pool = Get<GraphNodePool>(kGraphNodePool);
auto& gcs = Get<GarbageCollectorMap>(kGarbageCollector);
std::unordered_map<std::string, std::unordered_set<OpDesc*>> unlived_vars;
unlived_vars.reserve(graph_pool.size());
for (auto& pair : graph_pool) {
unlived_vars.insert(std::make_pair(pair.first, pair.second));
}
auto compare_and_insert_early_delete_op = [&](
OpHandleBase* op, const std::vector<VarHandleBase*>& vars) {
if (unlived_vars.empty()) return;
// unlived vars can be deleted after the last used op has finished.
auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
const auto& places = Get<std::vector<platform::Place>>(kAllPlaces);
for (auto& var : vars) {
auto* var_handle = dynamic_cast<VarHandle*>(var);
auto var_name = var->Node()->Name();
auto& var_place = var_handle->place();
if (unlived_vars.count(var_name) == 0) continue;
if (!unlived_vars[var_name].empty()) {
if (compute_op != nullptr &&
unlived_vars[var_name].count(compute_op->Node()->Op()) != 0) {
unlived_vars[var_name].erase(compute_op->Node()->Op());
}
continue;
}
if (var_handle == nullptr || !var_handle->Node()->IsVar() ||
var_handle->Node()->IsCtrlVar())
continue;
// shameless copyed from reference count pass.
if (compute_op == nullptr) {
// use next computation op scope
compute_op = FindNextComputationOpHandle(var_handle);
}
auto* early_delete_node =
graph->CreateEmptyNode("early_delete", ir::Node::Type::kOperation);
GarbageCollector* gc = gcs.at(places[compute_op->GetScopeIdx()]).get();
auto* early_delete_handle = new EarlyDeleteOpHandle(
early_delete_node, compute_op->GetScope(), var_place, {var_name}, gc);
if (compute_op->Outputs().empty()) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
early_delete_handle->AddInput(compute_op->Outputs().front());
VLOG(5) << "Add early delete op " << var_name << " to Operator"
<< compute_op->Name();
}
};
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto& op : all_ops) {
compare_and_insert_early_delete_op(op, op->Inputs());
compare_and_insert_early_delete_op(op, op->Outputs());
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(memory_early_delete_pass,
paddle::framework::details::MemoryEarlyDeletePass)
.RequireGraphAttr(paddle::framework::details::kGraphNodePool)
.RequireGraphAttr(paddle::framework::details::kGarbageCollector);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/early_delete_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class MemoryEarlyDeletePass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -13,17 +13,108 @@ ...@@ -13,17 +13,108 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <deque>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
using paddle::framework::VarDesc;
size_t NodeSizeInBytes(const VarDesc& node) { std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kAllOpDescs),
"Graph has no attribute of kAllOpDescs.");
// 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kAllOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
std::deque<ir::Node*> ops;
FilterVariables(nodes, [&](ir::Node* op) {
if (op->IsOp() && op->Op() != nullptr) {
ops.emplace_back(op);
}
});
std::unordered_map<ir::Node*, size_t> op_deps;
std::list<ir::Node*> ready_ops;
std::unordered_map<ir::Node*, std::unordered_set<ir::Node*>> pending_ops;
for (auto* op : ops) {
std::unordered_set<ir::Node*> preceding_op;
for (auto* in : op->inputs) {
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp());
preceding_op.emplace(in->inputs[0]);
pending_ops[in->inputs[0]].emplace(op);
}
op_deps[op] = preceding_op.size();
if (preceding_op.empty()) {
ready_ops.emplace_back(op);
}
}
// 3. generated op list based desc order and the topology order
std::vector<ir::Node*> ret;
std::list<OpDesc*> op_descs_list(op_descs.begin(), op_descs.end());
auto update_by_found_node = [&](ir::Node* found_node) {
for (auto* pending_op : pending_ops[found_node]) {
if (--op_deps[pending_op] == 0) {
ready_ops.emplace_back(pending_op);
}
}
ready_ops.remove(found_node);
ret.emplace_back(found_node);
};
while (!ready_ops.empty()) {
bool all_of_ready_op_unmatched = true;
for (auto it = op_descs_list.begin(); it != op_descs_list.end();) {
auto op_desc = *it;
ir::Node* found_node = nullptr;
for (auto* op : ready_ops) {
if (IsSameDesc(op->Op(), op_desc)) {
found_node = op;
break;
}
}
// 3.1 op desc deleted by other pass
if (found_node == nullptr) {
++it;
continue;
} else {
all_of_ready_op_unmatched = false;
it = op_descs_list.erase(it);
}
update_by_found_node(found_node);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
std::list<ir::Node*> prev_ready_ops(ready_ops);
if (all_of_ready_op_unmatched) {
for (auto op : prev_ready_ops) {
update_by_found_node(op);
}
}
}
PADDLE_ENFORCE(std::all_of(
op_deps.begin(), op_deps.end(),
[&](const std::pair<ir::Node*, size_t>& p) { return p.second == 0; }));
return ret;
}
size_t NodeSize(const VarDesc& node) {
auto shape = node.GetShape(); auto shape = node.GetShape();
int size = int size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
...@@ -31,9 +122,9 @@ size_t NodeSizeInBytes(const VarDesc& node) { ...@@ -31,9 +122,9 @@ size_t NodeSizeInBytes(const VarDesc& node) {
return type_size * std::abs(size); return type_size * std::abs(size);
} }
size_t NodeSizeInBytes(ir::Node* n) { size_t NodeSize(ir::Node* n) {
auto* desc = FindVarDescInBlock(n); auto* desc = FindVarDescInBlock(n);
return NodeSizeInBytes(*desc); return NodeSize(*desc);
} }
std::string DebugStringImpl(VarDesc* var) { std::string DebugStringImpl(VarDesc* var) {
...@@ -59,7 +150,6 @@ std::string DebugStringImpl(VarDesc* var) { ...@@ -59,7 +150,6 @@ std::string DebugStringImpl(VarDesc* var) {
std::string DebugString(ir::Node* var) { std::string DebugString(ir::Node* var) {
return DebugStringImpl(FindVarDescInBlock(var)); return DebugStringImpl(FindVarDescInBlock(var));
} }
// return DebugString(var->Var()); }
// NOTE(dzh): based ir node, if a large node has been reused // NOTE(dzh): based ir node, if a large node has been reused
// by a small size node, then next time it appear in pool, it will // by a small size node, then next time it appear in pool, it will
...@@ -80,18 +170,17 @@ struct NodeComparator { ...@@ -80,18 +170,17 @@ struct NodeComparator {
auto rhs_shape = rhs_desc->GetShape(); auto rhs_shape = rhs_desc->GetShape();
if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) || if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
(lhs_shape[0] != -1 && rhs_shape[0] != -1)) { (lhs_shape[0] != -1 && rhs_shape[0] != -1)) {
return NodeSizeInBytes(lhs) <= NodeSizeInBytes(rhs); return NodeSize(lhs) <= NodeSize(rhs);
} else { } else {
return false; return false;
} }
} }
}; };
void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) { void OrderedSet::Insert(ir::Node* var) {
PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar()); PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
PADDLE_ENFORCE(op->IsOp());
if (mark_table_.count(var->Name()) != 0) { if (mark_table_.count(var->Name()) != 0) {
mark_table_[var->Name()]->second.insert(op); mark_table_[var->Name()]->emplace_back(var);
return; return;
} }
...@@ -99,14 +188,15 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) { ...@@ -99,14 +188,15 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
auto var_shape = var_desc->GetShape(); auto var_shape = var_desc->GetShape();
int batch_size = static_cast<int>(var_shape[0]); int batch_size = static_cast<int>(var_shape[0]);
NodeComparator compare_node; NodeComparator functor;
Iter it = nodes_.begin(); Iter it = nodes_.begin();
while (it != nodes_.end()) { while (it != nodes_.end()) {
auto* cache_desc = FindVarDescInBlock(it->first); auto& prev = it->front();
auto* cache_desc = FindVarDescInBlock(prev);
int cache_batch_size = cache_desc->GetShape()[0]; int cache_batch_size = cache_desc->GetShape()[0];
if ((cache_batch_size == -1 && batch_size == -1) || if ((cache_batch_size == -1 && batch_size == -1) ||
(cache_batch_size != -1 && batch_size != -1)) { (cache_batch_size != -1 && batch_size != -1)) {
if (compare_node(it->first, var)) { if (functor(prev, var)) {
++it; ++it;
} else { } else {
break; break;
...@@ -118,62 +208,80 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) { ...@@ -118,62 +208,80 @@ void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) {
} }
} }
it = it = nodes_.insert(it, {var});
nodes_.insert(it, std::make_pair(var, std::unordered_set<ir::Node*>{op}));
mark_table_[var->Name()] = it; mark_table_[var->Name()] = it;
} }
int OrderedNodeList::GetIndex(ir::Node* var) { int OrderedSet::GetNodeIndexInPool(ir::Node* var) {
return std::distance(nodes_.begin(), mark_table_[var->Name()]); return std::distance(nodes_.begin(), mark_table_[var->Name()]);
} }
ir::Node* OrderedNodeList::NodeMatch(ir::Node* var) const { ir::Node* OrderedSet::FindBestFitNode(ir::Node* var) const {
ir::Node* found_node = nullptr; ir::Node* found_node = nullptr;
NodeComparator compare_node; NodeComparator functor;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) { for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
if (compare_node(var, it->first)) { auto& candidate = it->front();
found_node = it->first; if (functor(var, candidate)) {
found_node = candidate;
break; break;
} }
} }
return found_node; return found_node;
} }
void OrderedNodeList::Erase(ir::Node* var) { Erase(var->Name()); } bool OrderedSet::Has(ir::Node* var) const {
if (mark_table_.count(var->Name())) {
auto& node_in_samename = mark_table_.at(var->Name());
auto iter =
std::find_if(node_in_samename->begin(), node_in_samename->end(),
[&](ir::Node* n) { return n->Name() == var->Name(); });
return iter != node_in_samename->end();
}
return false;
}
void OrderedNodeList::Erase(const std::string& var) { void OrderedSet::Erase(ir::Node* var) {
PADDLE_ENFORCE(mark_table_.count(var)); PADDLE_ENFORCE(mark_table_.count(var->Name()));
nodes_.erase(mark_table_[var]); nodes_.erase(mark_table_[var->Name()]);
mark_table_.erase(var); mark_table_.erase(var->Name());
} }
std::string OrderedNodeList::ToString() const { std::string OrderedSet::ToString() const {
std::stringstream ss; std::stringstream ss;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) { for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
ss << DebugString(it->first) << " "; for (auto& node : *it) {
ss << DebugString(node) << " ";
}
} }
return ss.str(); return ss.str();
} }
bool NodeCanReused(ir::Node* node) { bool NodeCanReused(ir::Node* node) {
// valid the node is a var node
if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false; if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false;
// auto* desc = node->Var();
bool flag = NodeCanReused(*node->Var()); bool flag = true;
// op output force generated in cpu, can not be reused.
for (auto* op : node->inputs) { for (auto* op : node->inputs) {
if (op->Op()->HasAttr("force_cpu")) { if (op->Op()->HasAttr("force_cpu")) {
// op output force generated in cpu, can not be reused.
flag &= framework::AttrReader(op->Op()->GetAttrMap()) flag &= framework::AttrReader(op->Op()->GetAttrMap())
.Get<bool>("force_cpu") == 0; .Get<bool>("force_cpu") == 0;
} }
} }
// var desc validation.
flag &= NodeCanReused(*node->Var());
return flag; return flag;
} }
bool NodeCanReused(const VarDesc& node) { bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType(); auto type = node.GetType();
if (node.Persistable() || type != proto::VarType::LOD_TENSOR || if (!(type == proto::VarType::LOD_TENSOR ||
node.GetShape().empty()) { type == proto::VarType::SELECTED_ROWS ||
type == proto::VarType::LOD_TENSOR_ARRAY)) {
return false;
}
if (node.Persistable() || node.GetShape().empty()) {
return false; return false;
} }
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
...@@ -193,6 +301,174 @@ bool OpHasSubBlock(OpDesc* desc) { ...@@ -193,6 +301,174 @@ bool OpHasSubBlock(OpDesc* desc) {
return false; return false;
} }
ControlFlowGraph::ControlFlowGraph(const ir::Graph& graph) {
ops_ = SortOpLikeDescOrder(graph);
ConnectNodes();
}
void ControlFlowGraph::BuildCFGGraph() {
// FIXME(dzh): same effect with ConnectNodes, but use the control
// link to build dependency graph, it goes wrong in transformer.
for (ir::Node* op : ops_) {
for (auto& input_var : op->inputs) {
if (!input_var->inputs.empty()) {
PADDLE_ENFORCE(
input_var->inputs.size() == 1 && input_var->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
auto* pred_op = input_var->inputs[0];
if (pred_op->Op() != nullptr) {
predecessors_[op].insert(pred_op);
successors_[pred_op].insert(op);
}
}
if (input_var->IsVar() && !input_var->IsCtrlVar()) {
uses_[op].insert(input_var->Name());
}
}
for (auto& output_var : op->outputs) {
// output var may be used by many op
for (auto* succ_op : output_var->outputs) {
if (succ_op->Op() != nullptr) {
successors_[op].insert(succ_op);
predecessors_[succ_op].insert(op);
}
}
if (output_var->IsVar() && !output_var->IsCtrlVar()) {
defs_[op].insert(output_var->Name());
}
}
}
}
void ControlFlowGraph::ConnectNodes() {
for (size_t i = 0; i < ops_.size(); ++i) {
auto& op = ops_[i];
try {
auto& next_op = ops_.at(i + 1);
successors_[op].insert(next_op);
predecessors_[next_op].insert(op);
} catch (...) {
// do nothing
}
FilterVariables(op->inputs,
[&](ir::Node* var) { uses_[op].emplace(var->Name()); });
FilterVariables(op->outputs,
[&](ir::Node* var) { defs_[op].emplace(var->Name()); });
}
}
void ControlFlowGraph::LiveVariableAnalysis() {
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std::list<ir::Node*> work_list(ops_.rbegin(), ops_.rend());
while (!work_list.empty()) {
ir::Node* op = work_list.front();
work_list.pop_front();
// get the live_in calculated before. Empty if first.
auto prev_live_in = std::move(live_in_[op]);
for (auto& s : successors_[op]) {
for (auto& var : live_in_[s]) {
live_out_[op].insert(var);
}
}
for (auto& var : uses_[op]) {
live_in_[op].insert(var);
}
for (auto& var : live_out_[op]) {
live_in_[op].insert(var);
}
for (auto& var : defs_[op]) {
live_in_[op].erase(var);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if (live_in_[op] != prev_live_in) {
for (auto& pre : predecessors_[op]) {
work_list.push_back(pre);
}
}
}
}
void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node,
int begin_idx) {
// update graph from begin idx to the end
for (size_t i = begin_idx; i != ops_.size(); ++i) {
auto* op = ops_[i];
if (uses_[op].find(old_node) != uses_[op].end()) {
uses_[op].erase(old_node);
uses_[op].insert(new_node);
}
if (defs_[op].find(old_node) != defs_[op].end()) {
defs_[op].erase(old_node);
defs_[op].insert(new_node);
}
if (live_in_[op].find(old_node) != live_in_[op].end()) {
live_in_[op].erase(old_node);
live_in_[op].insert(new_node);
}
if (live_out_[op].find(old_node) != live_out_[op].end()) {
live_out_[op].erase(old_node);
live_out_[op].insert(new_node);
}
}
}
const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const {
auto it = live_in_.find(op);
PADDLE_ENFORCE(
it != live_in_.end(),
string::Sprintf("Expect %s in live_in, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const {
auto it = live_out_.find(op);
PADDLE_ENFORCE(
it != live_out_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string> ControlFlowGraph::Use(ir::Node* op) const {
auto it = uses_.find(op);
PADDLE_ENFORCE(
it != uses_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
return it->second;
}
const std::vector<ir::Node*> ControlFlowGraph::Ops() const { return ops_; }
std::vector<ir::Node*>& ControlFlowGraph::Ops() { return ops_; }
ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
ir::Node* op) const {
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
ir::Node* found_node = nullptr;
for (auto* node : ops_) {
if (node == op) break;
for (auto& output : node->outputs) {
if (output->Name() == name) {
found_node = output;
}
}
}
return found_node;
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
#include <list> #include <list>
#include <map>
#include <set>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -27,41 +29,41 @@ namespace paddle { ...@@ -27,41 +29,41 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kFetchedVars[] = "fetched_vars"; constexpr char kAllOpDescs[] = "all_op_descs";
constexpr char kGraphNodePool[] = "graph_node_pool";
// NOTE(dzh): Variable and the operators use the var. std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// for early delete pass.
// Because analysis var pass build base on ir::Node, which maybe released
// or modified between passes, so we use OpDesc* to mark ops.
using GraphNodePool = std::vector<
std::pair<std::string /*var node*/, std::unordered_set<OpDesc*> /* ops */>>;
// NOTE(dzh): by default, it sort node in ascend order(by node bytes size). // NOTE(dzh): A ordered set for node reuse in memory optimize.
// in fluid, -1 means the batch_size is determined in runtime. // the orderedset sort node in ascend order(by node bytes size).
// the node batch_size equal -1 always ranking in the front than the node not. // in fluid, -1 means the batch_size, which is determined in runtime.
// So the reuse happens between nodes who's batch_size both are -1
// simultaneously or not.
//
// sort rule:
// rule 0 : smaller node ranking in front.
// rule 1 : batch_size equal -1 ranking in the front than the node not.
//
// For example, // For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], .. // node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
// O(1) insert, delete
class OrderedNodeList {
public:
using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>;
using Iter = typename std::list<NodePair>::iterator;
using ConstIter = typename std::list<NodePair>::const_iterator;
void Insert(ir::Node* var, ir::Node* op); class OrderedSet {
public:
// nodes with same name exists in pool.
using NodeVector = std::vector<ir::Node*>;
using Iter = typename std::list<NodeVector>::iterator;
using ConstIter = typename std::list<NodeVector>::const_iterator;
void Insert(ir::Node* var);
void Erase(ir::Node* var); void Erase(ir::Node* var);
bool Has(ir::Node* var) const;
void Erase(const std::string& var); void Clear() {
mark_table_.clear();
bool Has(ir::Node* var) { return mark_table_.count(var->Name()); } nodes_.clear();
}
bool Has(const std::string& var) { return mark_table_.count(var); } // find the bestfit shape node block with var.
ir::Node* FindBestFitNode(ir::Node* var) const;
ir::Node* NodeMatch(ir::Node* var) const;
// map store non-const iterator, can not promise const // map store non-const iterator, can not promise const
int GetIndex(ir::Node* var); int GetNodeIndexInPool(ir::Node* var);
// pool all node to string // pool all node to string
std::string ToString() const; std::string ToString() const;
...@@ -69,18 +71,54 @@ class OrderedNodeList { ...@@ -69,18 +71,54 @@ class OrderedNodeList {
Iter end() { return nodes_.end(); } Iter end() { return nodes_.end(); }
ConstIter begin() const { return nodes_.begin(); } ConstIter begin() const { return nodes_.begin(); }
ConstIter end() const { return nodes_.end(); } ConstIter end() const { return nodes_.end(); }
size_t size() const { return nodes_.size(); }
void Clear() { size_t size() const { return nodes_.size(); }
mark_table_.clear();
nodes_.clear();
}
private: private:
// for searching. // for searching.
std::unordered_map<std::string, Iter> mark_table_; std::unordered_map<std::string, Iter> mark_table_;
// node swap pairs. var -> ops dep var // node pool
std::list<NodePair> nodes_; std::list<NodeVector> nodes_;
};
class ControlFlowGraph {
public:
ControlFlowGraph() = default;
// IR Graph
explicit ControlFlowGraph(const ir::Graph& graph);
void LiveVariableAnalysis();
void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx);
const std::set<std::string> LiveIn(ir::Node* op) const;
const std::set<std::string> LiveOut(ir::Node* op) const;
const std::set<std::string> Use(ir::Node* op) const;
const std::vector<ir::Node*> Ops() const;
std::vector<ir::Node*>& Ops();
// for ssa-graph nodes
ir::Node* GetNodeByName(const std::string& name, ir::Node* op) const;
private:
void BuildCFGGraph();
void ConnectNodes();
using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
// successors ops use the output variables.
NodeListMap successors_;
// predecessors ops generated input variables.
NodeListMap predecessors_;
// variables lived before run current op.
VarSetMap live_in_;
// variables lived after run current op.
VarSetMap live_out_;
VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs
std::vector<ir::Node*> ops_; // op sequence by topology sort
}; };
// valid a tensor can be reuse or not // valid a tensor can be reuse or not
...@@ -93,15 +131,24 @@ bool NodeCanReused(const VarDesc& node); ...@@ -93,15 +131,24 @@ bool NodeCanReused(const VarDesc& node);
bool OpHasSubBlock(OpDesc* desc); bool OpHasSubBlock(OpDesc* desc);
// node memory size in bytes // node memory size in bytes
size_t NodeSizeInBytes(ir::Node* n); size_t NodeSize(ir::Node* n);
// node memory size in bytes // node memory size in bytes
size_t NodeSizeInBytes(const VarDesc&); size_t NodeSize(const VarDesc&);
std::string DebugString(ir::Node* var); std::string DebugString(ir::Node* var);
// NOTE(dzhwinter)
// after node reuse, the replaced node shape is
// different with its VarDesc. So need to find the
// correct VarDesc in Block.
VarDesc* FindVarDescInBlock(ir::Node* n); VarDesc* FindVarDescInBlock(ir::Node* n);
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
template <typename Container, typename Callback> template <typename Container, typename Callback>
class FilterVariableImpl { class FilterVariableImpl {
public: public:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <iterator>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
...@@ -22,13 +23,19 @@ ...@@ -22,13 +23,19 @@
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
TEST(OrderedNodeList, Normal) { TEST(OrderedSet, Normal) {
OrderedNodeList pool; OrderedSet pool;
std::vector<std::unique_ptr<ir::Node>> nodes; std::vector<std::unique_ptr<ir::Node>> nodes;
// clang-format off // clang-format off
...@@ -56,8 +63,15 @@ TEST(OrderedNodeList, Normal) { ...@@ -56,8 +63,15 @@ TEST(OrderedNodeList, Normal) {
nodes.emplace_back(std::move(node)); nodes.emplace_back(std::move(node));
} }
// Insert
for (auto& node : nodes) { for (auto& node : nodes) {
pool.Insert(node.get(), op.get()); pool.Insert(node.get());
}
// Has/size
ASSERT_EQ(pool.size(), shapes.size());
for (auto& node : nodes) {
ASSERT_TRUE(pool.Has(node.get()));
} }
// assert its order and interface. // assert its order and interface.
...@@ -66,14 +80,14 @@ TEST(OrderedNodeList, Normal) { ...@@ -66,14 +80,14 @@ TEST(OrderedNodeList, Normal) {
std::cout << pool.ToString() << std::endl; std::cout << pool.ToString() << std::endl;
ASSERT_EQ(pool.size(), static_cast<size_t>(COUNT - 1)); ASSERT_EQ(pool.size(), static_cast<size_t>(COUNT - 1));
ASSERT_EQ(pool.GetIndex(nodes.back().get()), 0); ASSERT_EQ(pool.GetNodeIndexInPool(nodes.back().get()), 0);
{ {
auto v1 = block_desc->Var("11"); auto v1 = block_desc->Var("11");
v1->SetShape({-1, 256, 56, 56}); v1->SetShape({-1, 256, 56, 56});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v1); std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v1);
node1->inputs.emplace_back(op.get()); node1->inputs.emplace_back(op.get());
auto* cache = pool.NodeMatch(node1.get()); auto* cache = pool.FindBestFitNode(node1.get());
ASSERT_EQ(cache, nullptr); ASSERT_EQ(cache, nullptr);
} }
{ {
...@@ -81,16 +95,401 @@ TEST(OrderedNodeList, Normal) { ...@@ -81,16 +95,401 @@ TEST(OrderedNodeList, Normal) {
v2->SetShape({-1, 2, 5}); v2->SetShape({-1, 2, 5});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v2); std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v2);
node1->inputs.emplace_back(op.get()); node1->inputs.emplace_back(op.get());
auto* cache = pool.NodeMatch(node1.get()); auto* cache = pool.FindBestFitNode(node1.get());
ASSERT_EQ(pool.GetIndex(cache), 2); // match 6:[-1,2,5] ASSERT_EQ(pool.GetNodeIndexInPool(cache), 2); // match 6:[-1,2,5]
} }
{ {
auto v3 = block_desc->Var("13"); auto v3 = block_desc->Var("13");
v3->SetShape({2, 5}); v3->SetShape({2, 5});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v3); std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v3);
node1->inputs.emplace_back(op.get()); node1->inputs.emplace_back(op.get());
auto* cache = pool.NodeMatch(node1.get()); auto* cache = pool.FindBestFitNode(node1.get());
ASSERT_EQ(pool.GetIndex(cache), 5); // match 4:[5,2] ASSERT_EQ(pool.GetNodeIndexInPool(cache), 5); // match 4:[5,2]
}
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace paddle {
namespace framework {
namespace details {
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(proto::VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"d"});
op->SetOutput("Out", {"e"});
}
return prog;
}
TEST(CFGGraph, IRGraph) {
// prepare ir graph
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
ControlFlowGraph cfg(graph);
cfg.LiveVariableAnalysis();
// test assign op
ASSERT_TRUE((std::set<std::string>{"a"} == cfg.LiveIn(cfg.Ops()[0])));
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveOut(cfg.Ops()[0])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveIn(cfg.Ops()[1])));
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveOut(cfg.Ops()[1])));
// test sum op
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveIn(cfg.Ops()[2])));
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveOut(cfg.Ops()[2])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveIn(cfg.Ops()[3])));
ASSERT_TRUE((std::set<std::string>{} == cfg.LiveOut(cfg.Ops()[3])));
}
// 1. normal test
TEST(SortOpLikeDescOrder, NormalTest) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = SortOpLikeDescOrder(graph);
auto op_descs = prog.Block(0).AllOps();
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 2. remove some op_desc
TEST(SortOpLikeDescOrder, RemoveOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = graph.Nodes();
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->outputs.back()->Name() == "e") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
graph.RemoveNode(found_node);
graph.RemoveNode(e);
// other node keeps the same order
auto remain_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < remain_nodes.size(); ++i) {
auto node = remain_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 3. add some op_desc
TEST(SortOpLikeDescOrder, AddOpDesc) {
auto prog = FillProgramDesc();
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto op_descs = prog.Block(0).AllOps();
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
op_descs.insert(op_descs.begin() + 4, op);
auto nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 4. add and delete some op_desc
TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// remove sum node
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
auto nodes = graph.Nodes();
for (auto node : nodes) {
if (node->Name() == "sum") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* c = find_node_in_graph("c");
ir::Node* e = find_node_in_graph("e");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
std::remove(c->outputs.begin(), c->outputs.end(), found_node);
ir::Node* pending_op = found_node->outputs[0]->outputs[0];
graph.RemoveNode(e);
graph.RemoveNode(pending_op);
graph.RemoveNode(found_node);
}
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.insert(op_descs.begin() + 2, op);
// check the order
auto mynodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < mynodes.size(); ++i) {
auto node = mynodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 5. add and replace some op_desc inplace.
TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
auto op_descs = prog.Block(0).AllOps();
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.emplace_back(op);
// replace op_desc inplace
auto nodes = graph.Nodes();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->Op() && node->Name() == "assign") {
if (node->outputs.size() == 1 && node->outputs[0]->Name() == "e") {
found_node = node;
break;
}
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* e = find_node_in_graph("e");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
std::remove(e->inputs.begin(), e->inputs.end(), found_node);
graph.RemoveNode(found_node);
}
op_descs.erase(op_descs.begin() + 3);
auto replace_op = prog.MutableBlock(0)->AppendOp();
replace_op->SetType("sum");
replace_op->SetInput("X", {"d", "d1"});
replace_op->SetOutput("Out", {"e"});
{
ir::Node* sum2 = graph.CreateOpNode(replace_op);
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
ir::Node* d1 = find_node_in_graph("d1");
sum2->inputs.emplace_back(d);
sum2->inputs.emplace_back(d1);
sum2->outputs.emplace_back(e);
e->inputs.emplace_back(sum2);
d->outputs.emplace_back(sum2);
d1->outputs.emplace_back(sum2);
}
op_descs.emplace_back(replace_op);
// compare op order
auto graph_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < graph_nodes.size(); ++i) {
auto node = graph_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
} }
} }
......
...@@ -43,11 +43,6 @@ namespace paddle { ...@@ -43,11 +43,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto nodes = graph->Nodes(); auto nodes = graph->Nodes();
...@@ -77,7 +72,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -77,7 +72,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 || if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 ||
skip_set_.count(var->Name())) skip_set_.count(var->Name()))
continue; continue;
ir::Node* cache = pool_.NodeMatch(var); ir::Node* cache = pool_.FindBestFitNode(var);
if (var->Name() == FLAGS_memory_optimize_debug) { if (var->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "start match var " << DebugString(var) << " of op " VLOG(3) << "start match var " << DebugString(var) << " of op "
...@@ -95,11 +90,12 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -95,11 +90,12 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
<< "replace it again. Skip this candidate."; << "replace it again. Skip this candidate.";
continue; continue;
int node_idx_in_pool = pool_.GetIndex(cache); int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
VLOG(3) << string::Sprintf( VLOG(3) << string::Sprintf(
"!!! %s, %s => %s, cache idx %d, pool size %d", "!!! %s, %s => %s, cache idx %d, pool size %d",
std::to_string(reuse_id++), DebugString(var), DebugString(cache), std::to_string(reuse_id++), DebugString(var), DebugString(cache),
node_idx_in_pool, static_cast<int>(pool_.size())); node_idx_in_pool, static_cast<int>(pool_.size()));
// update CFG Graph on the fly. // update CFG Graph on the fly.
// reused var maybe re-fill into the pool // reused var maybe re-fill into the pool
cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx); cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx);
...@@ -112,6 +108,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -112,6 +108,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
pool_.Erase(cache); pool_.Erase(cache);
} }
// fill the pool // fill the pool
std::unordered_set<std::string> unlived_vars; std::unordered_set<std::string> unlived_vars;
for (auto var : cfg_->LiveIn(op)) { for (auto var : cfg_->LiveIn(op)) {
...@@ -120,36 +117,15 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -120,36 +117,15 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
} }
} }
for (auto var : unlived_vars) { for (auto var : unlived_vars) {
ir::Node* var_node = cfg_->GetNodeFromVarName(var, op); ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (NodeCanReused(var_node) && !pool_.Has(var_node)) { if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node, op); pool_.Insert(var_node);
} }
} }
} }
} }
graph->ResolveHazard(var_nodes_); graph->ResolveHazard(var_nodes_);
// For early delete pass. use GraphNodePool load the unlived vars.
// 1. find all deps op for each unlived var in memory pool.
for (auto& op : graph->Nodes()) {
for (auto& var : op->inputs) {
if (pool_.Has(var)) {
pool_.Insert(var, op);
}
}
}
// 2. convert ir node based memory pool to graph node
// because Node* maybe released bettwen passes.
auto& graph_pool = graph->Get<GraphNodePool>(kGraphNodePool);
for (auto it = pool_.begin(); it != pool_.end(); ++it) {
std::unordered_set<OpDesc*> descs;
for (auto& op : it->second) {
PADDLE_ENFORCE(op->IsOp());
descs.insert(op->Op());
}
graph_pool.push_back(std::make_pair(it->first->Name(), descs));
}
return graph; return graph;
} }
...@@ -198,12 +174,12 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const { ...@@ -198,12 +174,12 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
PADDLE_ENFORCE(sub_op != nullptr); PADDLE_ENFORCE(sub_op != nullptr);
for (auto* var : sub_op->outputs) { for (auto* var : sub_op->outputs) {
if (NodeCanReused(var)) { if (NodeCanReused(var)) {
ir::Node* cache = pool_.NodeMatch(var); ir::Node* cache = pool_.FindBestFitNode(var);
if (cache != nullptr) { if (cache != nullptr) {
if (var->Var()->GetDataType() != cache->Var()->GetDataType()) { if (var->Var()->GetDataType() != cache->Var()->GetDataType()) {
continue; continue;
} }
int node_idx_in_pool = pool_.GetIndex(cache); int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
VLOG(3) << string::Sprintf( VLOG(3) << string::Sprintf(
"!!! %s, %s => %s, cache idx %d, pool size %d", "!!! %s, %s => %s, cache idx %d, pool size %d",
std::to_string(sub_reuse_id++), DebugString(var), std::to_string(sub_reuse_id++), DebugString(var),
...@@ -342,267 +318,10 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -342,267 +318,10 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
var_nodes_.at(var).clear(); var_nodes_.at(var).clear();
} }
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kAllOpDescs),
"Graph has no attribute of kAllOpDescs.");
// 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kAllOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
std::deque<ir::Node*> ops;
FilterVariables(nodes, [&](ir::Node* op) {
if (op->IsOp() && op->Op() != nullptr) {
ops.emplace_back(op);
}
});
std::unordered_map<ir::Node*, size_t> op_deps;
std::list<ir::Node*> ready_ops;
std::unordered_map<ir::Node*, std::unordered_set<ir::Node*>> pending_ops;
for (auto* op : ops) {
std::unordered_set<ir::Node*> preceding_op;
for (auto* in : op->inputs) {
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp());
preceding_op.emplace(in->inputs[0]);
pending_ops[in->inputs[0]].emplace(op);
}
op_deps[op] = preceding_op.size();
if (preceding_op.empty()) {
ready_ops.emplace_back(op);
}
}
// 3. generated op list based desc order and the topology order
std::vector<ir::Node*> ret;
std::list<OpDesc*> op_descs_list(op_descs.begin(), op_descs.end());
auto update_by_found_node = [&](ir::Node* found_node) {
for (auto* pending_op : pending_ops[found_node]) {
if (--op_deps[pending_op] == 0) {
ready_ops.emplace_back(pending_op);
}
}
ready_ops.remove(found_node);
ret.emplace_back(found_node);
};
while (!ready_ops.empty()) {
bool all_of_ready_op_unmatched = true;
for (auto it = op_descs_list.begin(); it != op_descs_list.end();) {
auto op_desc = *it;
ir::Node* found_node = nullptr;
for (auto* op : ready_ops) {
if (IsSameDesc(op->Op(), op_desc)) {
found_node = op;
break;
}
}
// 3.1 op desc deleted by other pass
if (found_node == nullptr) {
++it;
continue;
} else {
all_of_ready_op_unmatched = false;
it = op_descs_list.erase(it);
}
update_by_found_node(found_node);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
std::list<ir::Node*> prev_ready_ops(ready_ops);
if (all_of_ready_op_unmatched) {
for (auto op : prev_ready_ops) {
update_by_found_node(op);
}
}
}
PADDLE_ENFORCE(std::all_of(
op_deps.begin(), op_deps.end(),
[&](const std::pair<ir::Node*, size_t>& p) { return p.second == 0; }));
return ret;
}
ControlFlowGraph::ControlFlowGraph(const ir::Graph& graph) {
ops_ = SortOpLikeDescOrder(graph);
ConnectNodes();
}
void ControlFlowGraph::BuildCFGGraph() {
// FIXME(dzh): same effect with ConnectNodes, but use the control
// link to build dependency graph, it goes wrong in transformer.
for (ir::Node* op : ops_) {
for (auto& input_var : op->inputs) {
if (!input_var->inputs.empty()) {
PADDLE_ENFORCE(
input_var->inputs.size() == 1 && input_var->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
auto* pred_op = input_var->inputs[0];
if (pred_op->Op() != nullptr) {
predecessors_[op].insert(pred_op);
successors_[pred_op].insert(op);
}
}
if (input_var->IsVar() && !input_var->IsCtrlVar()) {
uses_[op].insert(input_var->Name());
}
}
for (auto& output_var : op->outputs) {
// output var may be used by many op
for (auto* succ_op : output_var->outputs) {
if (succ_op->Op() != nullptr) {
successors_[op].insert(succ_op);
predecessors_[succ_op].insert(op);
}
}
if (output_var->IsVar() && !output_var->IsCtrlVar()) {
defs_[op].insert(output_var->Name());
}
}
}
}
void ControlFlowGraph::ConnectNodes() {
for (size_t i = 0; i < ops_.size(); ++i) {
auto& op = ops_[i];
try {
auto& next_op = ops_.at(i + 1);
successors_[op].insert(next_op);
predecessors_[next_op].insert(op);
} catch (...) {
// do nothing
}
FilterVariables(op->inputs,
[&](ir::Node* var) { uses_[op].emplace(var->Name()); });
FilterVariables(op->outputs,
[&](ir::Node* var) { defs_[op].emplace(var->Name()); });
}
}
void ControlFlowGraph::LiveVariableAnalysis() {
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std::list<ir::Node*> work_list(ops_.rbegin(), ops_.rend());
while (!work_list.empty()) {
ir::Node* op = work_list.front();
work_list.pop_front();
// get the live_in calculated before. Empty if first.
auto prev_live_in = std::move(live_in_[op]);
for (auto& s : successors_[op]) {
for (auto& var : live_in_[s]) {
live_out_[op].insert(var);
}
}
for (auto& var : uses_[op]) {
live_in_[op].insert(var);
}
for (auto& var : live_out_[op]) {
live_in_[op].insert(var);
}
for (auto& var : defs_[op]) {
live_in_[op].erase(var);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if (live_in_[op] != prev_live_in) {
for (auto& pre : predecessors_[op]) {
work_list.push_back(pre);
}
}
}
}
void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node,
int begin_idx) {
// update graph from begin idx to the end
for (size_t i = begin_idx; i != ops_.size(); ++i) {
auto* op = ops_[i];
if (uses_[op].find(old_node) != uses_[op].end()) {
uses_[op].erase(old_node);
uses_[op].insert(new_node);
}
if (defs_[op].find(old_node) != defs_[op].end()) {
defs_[op].erase(old_node);
defs_[op].insert(new_node);
}
if (live_in_[op].find(old_node) != live_in_[op].end()) {
live_in_[op].erase(old_node);
live_in_[op].insert(new_node);
}
if (live_out_[op].find(old_node) != live_out_[op].end()) {
live_out_[op].erase(old_node);
live_out_[op].insert(new_node);
}
}
}
const std::set<std::string> ControlFlowGraph::LiveIn(ir::Node* op) const {
auto it = live_in_.find(op);
PADDLE_ENFORCE(
it != live_in_.end(),
string::Sprintf("Expect %s in live_in, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string> ControlFlowGraph::LiveOut(ir::Node* op) const {
auto it = live_out_.find(op);
PADDLE_ENFORCE(
it != live_out_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string> ControlFlowGraph::Use(ir::Node* op) const {
auto it = uses_.find(op);
PADDLE_ENFORCE(
it != uses_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
return it->second;
}
const std::vector<ir::Node*> ControlFlowGraph::Ops() const { return ops_; }
std::vector<ir::Node*>& ControlFlowGraph::Ops() { return ops_; }
ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name,
ir::Node* op) const {
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
ir::Node* found_node = nullptr;
for (auto* node : ops_) {
if (node == op) break;
for (auto& output : node->outputs) {
if (output->Name() == name) {
found_node = output;
}
}
}
return found_node;
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(memory_optimize_pass, REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass) paddle::framework::details::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kGraphNodePool)
.RequireGraphAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kAllOpDescs);
...@@ -32,20 +32,15 @@ ...@@ -32,20 +32,15 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
class ControlFlowGraph;
class MemoryOptimizePass : public ir::Pass { class MemoryOptimizePass : public ir::Pass {
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
private:
// fill the variable map(var_nodes) by version. // fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const; void InitSSAGraphNodes() const;
private:
// update program descs // update program descs
void RenameVarInGraphDesc(const std::string& var, void RenameVarInGraphDesc(const std::string& var,
const std::string& cache_var, size_t idx) const; const std::string& cache_var, size_t idx) const;
...@@ -62,7 +57,7 @@ class MemoryOptimizePass : public ir::Pass { ...@@ -62,7 +57,7 @@ class MemoryOptimizePass : public ir::Pass {
private: private:
// Reuse Node Pool, Owned. // Reuse Node Pool, Owned.
mutable OrderedNodeList pool_; mutable OrderedSet pool_;
// controlflow Graph // controlflow Graph
mutable std::unique_ptr<ControlFlowGraph> cfg_; mutable std::unique_ptr<ControlFlowGraph> cfg_;
// skip set // skip set
...@@ -71,45 +66,6 @@ class MemoryOptimizePass : public ir::Pass { ...@@ -71,45 +66,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_; mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
}; };
class ControlFlowGraph {
public:
ControlFlowGraph() = default;
// For IR Graph in parallelexecutor
explicit ControlFlowGraph(const ir::Graph& graph);
void LiveVariableAnalysis();
void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx);
const std::set<std::string> LiveIn(ir::Node* op) const;
const std::set<std::string> LiveOut(ir::Node* op) const;
const std::set<std::string> Use(ir::Node* op) const;
const std::vector<ir::Node*> Ops() const;
std::vector<ir::Node*>& Ops();
// for ssa-graph nodes
ir::Node* GetNodeFromVarName(const std::string& name, ir::Node* op) const;
private:
void BuildCFGGraph();
void ConnectNodes();
using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
// successors ops use the output variables.
NodeListMap successors_;
// predecessors ops generated input variables.
NodeListMap predecessors_;
// variables lived before run current op.
VarSetMap live_in_;
// variables lived after run current op.
VarSetMap live_out_;
VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs
std::vector<ir::Node*> ops_; // op sequence by topology sort
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customed classical dependency graph, left row is the instruction
number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analysis these variable's liveness range
*/
namespace paddle {
namespace framework {
namespace details {
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(proto::VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"d"});
op->SetOutput("Out", {"e"});
}
return prog;
}
TEST(CFGGraph, IRGraph) {
// prepare ir graph
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
ControlFlowGraph cfg(graph);
cfg.LiveVariableAnalysis();
// test assign op
ASSERT_TRUE((std::set<std::string>{"a"} == cfg.LiveIn(cfg.Ops()[0])));
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveOut(cfg.Ops()[0])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveIn(cfg.Ops()[1])));
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveOut(cfg.Ops()[1])));
// test sum op
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveIn(cfg.Ops()[2])));
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveOut(cfg.Ops()[2])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveIn(cfg.Ops()[3])));
ASSERT_TRUE((std::set<std::string>{} == cfg.LiveOut(cfg.Ops()[3])));
}
// 1. normal test
TEST(SortOpLikeDescOrder, NormalTest) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = SortOpLikeDescOrder(graph);
auto op_descs = prog.Block(0).AllOps();
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 2. remove some op_desc
TEST(SortOpLikeDescOrder, RemoveOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto nodes = graph.Nodes();
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->outputs.back()->Name() == "e") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
graph.RemoveNode(found_node);
graph.RemoveNode(e);
// other node keeps the same order
auto remain_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < remain_nodes.size(); ++i) {
auto node = remain_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 3. add some op_desc
TEST(SortOpLikeDescOrder, AddOpDesc) {
auto prog = FillProgramDesc();
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// cached desc different with real one
// mimic the intermidiete pass modify the programdesc.
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto op_descs = prog.Block(0).AllOps();
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
op_descs.insert(op_descs.begin() + 4, op);
auto nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 4. add and delete some op_desc
TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// remove sum node
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
auto nodes = graph.Nodes();
for (auto node : nodes) {
if (node->Name() == "sum") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* c = find_node_in_graph("c");
ir::Node* e = find_node_in_graph("e");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
std::remove(c->outputs.begin(), c->outputs.end(), found_node);
ir::Node* pending_op = found_node->outputs[0]->outputs[0];
graph.RemoveNode(e);
graph.RemoveNode(pending_op);
graph.RemoveNode(found_node);
}
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.insert(op_descs.begin() + 2, op);
// check the order
auto mynodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < mynodes.size(); ++i) {
auto node = mynodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 5. add and replace some op_desc inplace.
TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
const std::vector<OpDesc*>* all_op_descs =
new std::vector<OpDesc*>(prog.Block(0).AllOps());
graph.Set(details::kAllOpDescs, all_op_descs); // take ownership
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
auto op_descs = prog.Block(0).AllOps();
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.emplace_back(op);
// replace op_desc inplace
auto nodes = graph.Nodes();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->Op() && node->Name() == "assign") {
if (node->outputs.size() == 1 && node->outputs[0]->Name() == "e") {
found_node = node;
break;
}
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* e = find_node_in_graph("e");
std::remove(d->outputs.begin(), d->outputs.end(), found_node);
std::remove(e->inputs.begin(), e->inputs.end(), found_node);
graph.RemoveNode(found_node);
}
op_descs.erase(op_descs.begin() + 3);
auto replace_op = prog.MutableBlock(0)->AppendOp();
replace_op->SetType("sum");
replace_op->SetInput("X", {"d", "d1"});
replace_op->SetOutput("Out", {"e"});
{
ir::Node* sum2 = graph.CreateOpNode(replace_op);
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
ir::Node* d1 = find_node_in_graph("d1");
sum2->inputs.emplace_back(d);
sum2->inputs.emplace_back(d1);
sum2->outputs.emplace_back(e);
e->inputs.emplace_back(sum2);
d->outputs.emplace_back(sum2);
d1->outputs.emplace_back(sum2);
}
op_descs.emplace_back(replace_op);
// compare op order
auto graph_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < graph_nodes.size(); ++i) {
auto node = graph_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle { namespace paddle {
......
...@@ -21,8 +21,6 @@ namespace paddle { ...@@ -21,8 +21,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
class SequentialExecutionPass : public ir::Pass { class SequentialExecutionPass : public ir::Pass {
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
......
...@@ -69,7 +69,7 @@ class InplaceInToOut : public InplaceOpInference { ...@@ -69,7 +69,7 @@ class InplaceInToOut : public InplaceOpInference {
bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const { bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
return in.Name() != out.Name() && details::NodeCanReused(in) && return in.Name() != out.Name() && details::NodeCanReused(in) &&
details::NodeCanReused(out) && details::NodeCanReused(out) &&
details::NodeSizeInBytes(out) <= details::NodeSizeInBytes(in); details::NodeSize(out) <= details::NodeSize(in);
} }
}; };
......
...@@ -171,14 +171,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts( ...@@ -171,14 +171,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_); eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(std::move(graph)); graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied"; VLOG(10) << "EagerDeletionPass Applied";
if (build_strategy_.memory_early_delete_) {
auto early_delete_pass =
ir::PassRegistry::Instance().Get("memory_early_delete_pass");
early_delete_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graph = early_delete_pass->Apply(std::move(graph));
}
VLOG(10) << "MemoryEarlyDeletePass Applied.";
} }
return graph; return graph;
...@@ -288,6 +280,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -288,6 +280,8 @@ ParallelExecutor::ParallelExecutor(
graphs.push_back(std::move(graph)); graphs.push_back(std::move(graph));
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30);
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
for (size_t i = 0; i < graphs.size(); ++i) { for (size_t i = 0; i < graphs.size(); ++i) {
graphs[i] = member_->PrepareGCAndRefCnts( graphs[i] = member_->PrepareGCAndRefCnts(
...@@ -506,6 +500,5 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -506,6 +500,5 @@ ParallelExecutor::~ParallelExecutor() {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(memory_early_delete_pass);
USE_PASS(reference_count_pass); USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass); USE_PASS(eager_deletion_pass);
...@@ -22,11 +22,7 @@ limitations under the License. */ ...@@ -22,11 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
DEFINE_bool(benchmark, false, DECLARE_bool(benchmark);
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode.");
DEFINE_bool( DEFINE_bool(
eager_delete_scope, true, eager_delete_scope, true,
......
...@@ -52,8 +52,8 @@ cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_ ...@@ -52,8 +52,8 @@ cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_
if (WITH_ANAKIN AND WITH_MKL) # only needed in CI if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
# compile the libinference_anakin_api.a and anakin.so. # compile the libinference_anakin_api.a and anakin.so.
cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber mklml zero_copy_tensor_dummy) cc_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber mklml zero_copy_tensor_dummy device_context)
cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber zero_copy_tensor_dummy) cc_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber zero_copy_tensor_dummy device_context)
function(anakin_target target_name) function(anakin_target target_name)
target_compile_options(${target_name} BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) target_compile_options(${target_name} BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
endfunction() endfunction()
......
...@@ -36,6 +36,7 @@ DEFINE_bool(init_allocated_mem, false, ...@@ -36,6 +36,7 @@ DEFINE_bool(init_allocated_mem, false,
"that initializing the allocated memory with a small value " "that initializing the allocated memory with a small value "
"during unit testing."); "during unit testing.");
DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_bool(benchmark);
namespace paddle { namespace paddle {
namespace memory { namespace memory {
...@@ -198,7 +199,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place, ...@@ -198,7 +199,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
<< string::HumanReadableSize(Used<platform::CUDAPlace>(place)); << string::HumanReadableSize(Used<platform::CUDAPlace>(place));
platform::SetDeviceId(cur_dev); platform::SetDeviceId(cur_dev);
} else { } else {
if (VLOG_IS_ON(3)) { if (FLAGS_benchmark) {
allocation::GPUMemMonitor.Add(place.device, size); allocation::GPUMemMonitor.Add(place.device, size);
} }
if (FLAGS_init_allocated_mem) { if (FLAGS_init_allocated_mem) {
...@@ -216,7 +217,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p, ...@@ -216,7 +217,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
size_t size) { size_t size) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
GetGPUBuddyAllocator(place.device)->Free(p); GetGPUBuddyAllocator(place.device)->Free(p);
if (VLOG_IS_ON(3)) { if (FLAGS_benchmark) {
allocation::GPUMemMonitor.Minus(place.device, size); allocation::GPUMemMonitor.Minus(place.device, size);
} }
#else #else
...@@ -257,7 +258,7 @@ void *Alloc<platform::CUDAPinnedPlace>(const platform::CUDAPinnedPlace &place, ...@@ -257,7 +258,7 @@ void *Alloc<platform::CUDAPinnedPlace>(const platform::CUDAPinnedPlace &place,
void *ptr = buddy_allocator->Alloc(size); void *ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) { if (ptr == nullptr) {
LOG(WARNING) << "cudaMallocHost Cannot allocate " << size LOG(WARNING) << "cudaHostAlloc Cannot allocate " << size
<< " bytes in CUDAPinnedPlace"; << " bytes in CUDAPinnedPlace";
} }
if (FLAGS_init_allocated_mem) { if (FLAGS_init_allocated_mem) {
......
...@@ -32,7 +32,7 @@ Allocation *CPUPinnedAllocator::AllocateImpl(size_t size, ...@@ -32,7 +32,7 @@ Allocation *CPUPinnedAllocator::AllocateImpl(size_t size,
// "CPUPinnedAllocator should be used for Cross-Device Communication"); // "CPUPinnedAllocator should be used for Cross-Device Communication");
void *ptr; void *ptr;
PADDLE_ENFORCE(cudaMallocHost(&ptr, size)); PADDLE_ENFORCE(cudaHostAlloc(&ptr, size, cudaHostAllocPortable));
return new CPUPinnedAllocation(ptr, size); return new CPUPinnedAllocation(ptr, size);
} }
} // namespace allocation } // namespace allocation
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
// Allocator uses `cudaMallocHost` // Allocator uses `cudaHostAlloc`
class CPUPinnedAllocation : public Allocation { class CPUPinnedAllocation : public Allocation {
public: public:
CPUPinnedAllocation(void *ptr, size_t size) CPUPinnedAllocation(void *ptr, size_t size)
......
...@@ -173,14 +173,14 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) { ...@@ -173,14 +173,14 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) {
void* p; void* p;
// PINNED memory is visible to all CUDA contexts. // PINNED memory is visible to all CUDA contexts.
cudaError_t result = cudaMallocHost(&p, size); cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable);
if (result == cudaSuccess) { if (result == cudaSuccess) {
*index = 1; // PINNED memory *index = 1; // PINNED memory
cuda_pinnd_alloc_size_ += size; cuda_pinnd_alloc_size_ += size;
return p; return p;
} else { } else {
LOG(WARNING) << "cudaMallocHost failed."; LOG(WARNING) << "cudaHostAlloc failed.";
return nullptr; return nullptr;
} }
......
...@@ -264,6 +264,23 @@ class ElementwiseOpInplace : public framework::InplaceInToOut { ...@@ -264,6 +264,23 @@ class ElementwiseOpInplace : public framework::InplaceInToOut {
} }
}; };
class ElementwiseGradOpInplace : public framework::InplaceInToOut {
public:
using framework::InplaceInToOut::InplaceInToOut;
protected:
std::unordered_map<std::string, std::string> Apply(
const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
std::unordered_map<std::string, std::string> ret;
if (block->HasVar(framework::GradVarName("X")) &&
block->HasVar(framework::GradVarName("Out"))) {
ret[framework::GradVarName("Out")] = framework::GradVarName("X");
}
return ret;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -316,4 +333,5 @@ class ElementwiseOpInplace : public framework::InplaceInToOut { ...@@ -316,4 +333,5 @@ class ElementwiseOpInplace : public framework::InplaceInToOut {
op_type##GradMaker, \ op_type##GradMaker, \
::paddle::operators::ElementwiseOpInplace); \ ::paddle::operators::ElementwiseOpInplace); \
REGISTER_OPERATOR(op_type##_grad, \ REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad) ::paddle::operators::ElementwiseOpExplicitGrad, \
::paddle::operators::ElementwiseGradOpInplace)
...@@ -35,7 +35,7 @@ class NgraphEngineOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,7 @@ class NgraphEngineOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::proto::VarType::FP32, ctx.GetPlace()); framework::proto::VarType::FP32, platform::CPUPlace());
return kt; return kt;
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h"
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -24,6 +25,13 @@ BufferedReader::~BufferedReader() { ...@@ -24,6 +25,13 @@ BufferedReader::~BufferedReader() {
position_.front().wait(); position_.front().wait();
position_.pop(); position_.pop();
} }
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaStreamDestroy(stream));
for (auto &event : events) PADDLE_ENFORCE(cudaEventDestroy(event));
}
#endif
} }
BufferedReader::BufferedReader( BufferedReader::BufferedReader(
...@@ -33,6 +41,19 @@ BufferedReader::BufferedReader( ...@@ -33,6 +41,19 @@ BufferedReader::BufferedReader(
thread_pool_(1), thread_pool_(1),
place_(place), place_(place),
buffer_size_(buffer_size) { buffer_size_(buffer_size) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
compute_stream =
((platform::CUDADeviceContext *)(platform::DeviceContextPool::Instance()
.Get(place_)))
->stream();
events.resize(buffer_size);
for (auto &event : events)
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
}
#endif
cpu_buffer_.resize(buffer_size); cpu_buffer_.resize(buffer_size);
gpu_buffer_.resize(buffer_size); gpu_buffer_.resize(buffer_size);
ReadTillBufferFullAsync(); ReadTillBufferFullAsync();
...@@ -46,6 +67,12 @@ void BufferedReader::ReadTillBufferFullAsync() { ...@@ -46,6 +67,12 @@ void BufferedReader::ReadTillBufferFullAsync() {
} }
void BufferedReader::ReadAsync(size_t i) { void BufferedReader::ReadAsync(size_t i) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaEventRecord(events[i], compute_stream));
}
#endif
position_.emplace(thread_pool_.enqueue([this, i]() -> size_t { position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
TensorVec &cpu = cpu_buffer_[i]; TensorVec &cpu = cpu_buffer_[i];
reader_->ReadNext(&cpu); reader_->ReadNext(&cpu);
...@@ -54,14 +81,41 @@ void BufferedReader::ReadAsync(size_t i) { ...@@ -54,14 +81,41 @@ void BufferedReader::ReadAsync(size_t i) {
return -1UL; return -1UL;
} }
#ifdef PADDLE_WITH_CUDA
// NOTE(liangdun): using async copy instead of TensorCopySync
// TensorCopySync would block other stream
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events[i], 0));
TensorVec &gpu = gpu_buffer_[i]; TensorVec &gpu = gpu_buffer_[i];
gpu.resize(cpu.size()); gpu.resize(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) { for (size_t i = 0; i < cpu.size(); ++i) {
framework::TensorCopySync(cpu[i], place_, &gpu[i]); gpu[i].Resize(cpu[i].dims());
gpu[i].set_layout(cpu[i].layout());
auto cpu_place = cpu[i].place();
auto cpu_ptr = cpu[i].data<void>();
auto gpu_ptr = gpu[i].mutable_data(place_, cpu[i].type());
auto size =
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
if (platform::is_cuda_pinned_place(cpu_place))
memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
boost::get<platform::CUDAPinnedPlace>(cpu_place),
cpu_ptr, size, stream);
else if ((platform::is_gpu_place(cpu_place)))
memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
boost::get<platform::CUDAPlace>(cpu_place), cpu_ptr,
size, stream);
else
// if cpu place is not pinned, async copy is slower than sync copy,
// so we use sync copy instead.
memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
boost::get<platform::CPUPlace>(cpu_place), cpu_ptr, size,
0);
gpu[i].set_lod(cpu[i].lod()); gpu[i].set_lod(cpu[i].lod());
} }
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
} }
#endif
return i; return i;
})); }));
} }
......
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
#include <vector> #include <vector>
#include "ThreadPool.h" #include "ThreadPool.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -59,6 +62,11 @@ class BufferedReader : public framework::DecoratedReader { ...@@ -59,6 +62,11 @@ class BufferedReader : public framework::DecoratedReader {
std::vector<TensorVec> cpu_buffer_; std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> gpu_buffer_; std::vector<TensorVec> gpu_buffer_;
size_t prev_pos_{-1UL}; size_t prev_pos_{-1UL};
#ifdef PADDLE_WITH_CUDA
cudaStream_t stream;
cudaStream_t compute_stream;
std::vector<cudaEvent_t> events;
#endif
}; };
} // namespace reader } // namespace reader
......
...@@ -14,6 +14,12 @@ limitations under the License. */ ...@@ -14,6 +14,12 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
DEFINE_bool(benchmark, false,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode.");
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -295,6 +295,7 @@ PYBIND11_MODULE(core, m) { ...@@ -295,6 +295,7 @@ PYBIND11_MODULE(core, m) {
.def("_get_float_element", TensorGetElement<float>) .def("_get_float_element", TensorGetElement<float>)
.def("_set_double_element", TensorSetElement<double>) .def("_set_double_element", TensorSetElement<double>)
.def("_get_double_element", TensorGetElement<double>) .def("_get_double_element", TensorGetElement<double>)
.def("_place", [](Tensor &self) { return self.place(); })
.def("_dtype", [](Tensor &self) { return self.type(); }); .def("_dtype", [](Tensor &self) { return self.type(); });
py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC( py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC(
...@@ -673,6 +674,12 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -673,6 +674,12 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<platform::Place>(m, "Place") py::class_<platform::Place>(m, "Place")
.def(py::init<>()) .def(py::init<>())
.def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); })
.def("gpu_device_id",
[](platform::Place &self) {
return boost::get<platform::CUDAPlace>(self).device;
})
.def("set_place", .def("set_place",
[](platform::Place &self, const platform::CPUPlace &cpu_place) { [](platform::Place &self, const platform::CPUPlace &cpu_place) {
self = cpu_place; self = cpu_place;
...@@ -1092,10 +1099,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1092,10 +1099,6 @@ All parameter, weight, gradient are variables in Paddle.
"is_distribution", "is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; }, [](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) { self.is_distribution_ = b; }) [](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
.def_property(
"memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; },
[](BuildStrategy &self, bool b) { self.memory_early_delete_ = b; })
.def_property( .def_property(
"enable_inplace", "enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; }, [](const BuildStrategy &self) { return self.enable_inplace_; },
......
...@@ -25,4 +25,5 @@ import paddle.reader ...@@ -25,4 +25,5 @@ import paddle.reader
import paddle.dataset import paddle.dataset
import paddle.batch import paddle.batch
import paddle.compat import paddle.compat
import paddle.distributed
batch = batch.batch batch = batch.batch
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -37,7 +37,7 @@ default_envs = { ...@@ -37,7 +37,7 @@ default_envs = {
GPUS = 8 GPUS = 8
def start_procs(gpus, cmd, log_dir): def start_procs(gpus, entrypoint, entrypoint_args, log_dir):
procs = [] procs = []
log_fns = [] log_fns = []
os.system("mkdir -p %s" % log_dir) os.system("mkdir -p %s" % log_dir)
...@@ -73,12 +73,11 @@ def start_procs(gpus, cmd, log_dir): ...@@ -73,12 +73,11 @@ def start_procs(gpus, cmd, log_dir):
"PADDLE_TRAINER_ENDPOINTS": all_nodes_devices_endpoints "PADDLE_TRAINER_ENDPOINTS": all_nodes_devices_endpoints
}) })
print("starting process ", i, cmd, curr_env) print("starting process ", i, entrypoint, entrypoint_args, curr_env)
fn = open("%s/workerlog.%d" % (log_dir, i), "w") fn = open("%s/workerlog.%d" % (log_dir, i), "w")
log_fns.append(fn) log_fns.append(fn)
procs.append( cmd = [sys.executable, "-u", entrypoint] + entrypoint_args
subprocess.Popen( procs.append(subprocess.Popen(cmd, stdout=fn, stderr=fn, env=curr_env))
cmd.strip().split(" "), stdout=fn, stderr=fn, env=curr_env))
for i in range(gpus): for i in range(gpus):
try: try:
...@@ -89,7 +88,8 @@ def start_procs(gpus, cmd, log_dir): ...@@ -89,7 +88,8 @@ def start_procs(gpus, cmd, log_dir):
pass pass
def main(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='''start paddle training using multi-process mode. description='''start paddle training using multi-process mode.
NOTE: your train program ***must*** run as distributed nccl2 mode, NOTE: your train program ***must*** run as distributed nccl2 mode,
...@@ -108,21 +108,27 @@ POD_IP (current node ip address, not needed for local training) ...@@ -108,21 +108,27 @@ POD_IP (current node ip address, not needed for local training)
type=int, type=int,
default=8, default=8,
help='start number of processes for every gpu') help='start number of processes for every gpu')
parser.add_argument(
'--cmd',
type=str,
default="",
help='command to run for each process, e.g. python train.py --lr 0.1')
parser.add_argument( parser.add_argument(
'--log_dir', '--log_dir',
type=str, type=str,
default="mylog", default="mylog",
help='directory to put logs per process.') help='directory to put logs per process.')
args = parser.parse_args() parser.add_argument(
if args.cmd == "": 'entrypoint_script',
parser.print_help() type=str,
exit(0) help="The entrypoint script to be launched in parallel,"
start_procs(args.gpus, args.cmd, args.log_dir) "followed by all the arguments for each process,"
"e.g. train.py --lr 0.1")
parser.add_argument('entrypoint_args', nargs=argparse.REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# launch multiple training process
start_procs(args.gpus, args.entrypoint_script, args.entrypoint_args,
args.log_dir)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -161,7 +161,6 @@ def __bootstrap__(): ...@@ -161,7 +161,6 @@ def __bootstrap__():
'times_excess_than_required_tmp_allocation', 'times_excess_than_required_tmp_allocation',
'enable_inplace_whitelist' 'enable_inplace_whitelist'
] ]
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)]) ["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0]) core.init_glog(sys.argv[0])
......
...@@ -302,7 +302,8 @@ class LayerHelper(object): ...@@ -302,7 +302,8 @@ class LayerHelper(object):
if default_initializer is None and attr.initializer is None: if default_initializer is None and attr.initializer is None:
if isinstance(dtype, core.VarDesc.VarType): if isinstance(dtype, core.VarDesc.VarType):
if dtype != core.VarDesc.VarType.FP32 and \ if dtype != core.VarDesc.VarType.FP32 and \
dtype != core.VarDesc.VarType.FP64: dtype != core.VarDesc.VarType.FP64 and \
dtype != core.VarDesc.VarType.FP16:
raise TypeError( raise TypeError(
"Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" "Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!"
) )
......
...@@ -148,7 +148,8 @@ class ParallelExecutor(object): ...@@ -148,7 +148,8 @@ class ParallelExecutor(object):
else framework.default_main_program() else framework.default_main_program()
# FIXME(dzhwinter): enable_inplace should be after memory_optimize # FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass. # if turn on python memory optimize, turn off the inplace_pass.
build_strategy.enable_inplace = False if main._is_mem_optimized else True if build_strategy.enable_inplace is None:
build_strategy.enable_inplace = False if main._is_mem_optimized else True
scope = scope if scope is not None else executor.global_scope() scope = scope if scope is not None else executor.global_scope()
if share_vars_from and not isinstance(share_vars_from, if share_vars_from and not isinstance(share_vars_from,
......
...@@ -16,14 +16,37 @@ from __future__ import print_function ...@@ -16,14 +16,37 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_accuracy_op import TestAccuracyOp
class TestNGRAPHAccuracyOp(TestAccuracyOp): class TestNGRAPHAccuracyOp(OpTest):
def setUp(self): def setUp(self):
super(TestNGRAPHAccuracyOp, self).setUp() self.op_type = "accuracy"
self.dtype = np.float32
self.init_dtype()
n = 128
infer = np.random.random((n, 1)).astype(self.dtype)
indices = np.random.randint(0, 2, (n, 1))
label = np.random.randint(0, 2, (n, 1))
self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
num_correct = 0
for rowid in range(n):
for ele in indices[rowid]:
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
'Correct': np.array([num_correct]).astype("int64"),
'Total': np.array([n]).astype("int64")
}
self._cpu_only = True
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,35 +15,59 @@ ...@@ -15,35 +15,59 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_conv2d_op import * from paddle.fluid.tests.unittests.test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1
class TestNGRAPH(TestConv2dOp): class TestNGRAPH(TestConv2dOp):
def setUp(self):
super(TestNGRAPH, self).setUp()
self._cpu_only = True
def init_kernel_type(self): def init_kernel_type(self):
super(TestNGRAPH, self).init_kernel_type() super(TestNGRAPH, self).init_kernel_type()
class TestNGRAPHWithPad(TestWithPad): class TestNGRAPHWithPad(TestWithPad):
def setUp(self):
super(TestNGRAPHWithPad, self).setUp()
self._cpu_only = True
def init_kernel_type(self): def init_kernel_type(self):
super(TestNGRAPHWithPad, self).init_kernel_type() super(TestNGRAPHWithPad, self).init_kernel_type()
class TestNGRAPHWithStride(TestWithStride): class TestNGRAPHWithStride(TestWithStride):
def setUp(self):
super(TestNGRAPHWithStride, self).setUp()
self._cpu_only = True
def init_kernel_type(self): def init_kernel_type(self):
super(TestNGRAPHWithStride, self).init_kernel_type() super(TestNGRAPHWithStride, self).init_kernel_type()
class TestNGRAPHWithGroup(TestWithGroup): class TestNGRAPHWithGroup(TestWithGroup):
def setUp(self):
super(TestNGRAPHWithGroup, self).setUp()
self._cpu_only = True
def init_kernel_type(self): def init_kernel_type(self):
super(TestNGRAPHWithGroup, self).init_kernel_type() super(TestNGRAPHWithGroup, self).init_kernel_type()
class TestNGRAPHWith1x1(TestWith1x1): class TestNGRAPHWith1x1(TestWith1x1):
def setUp(self):
super(TestNGRAPHWith1x1, self).setUp()
self._cpu_only = True
def init_kernel_type(self): def init_kernel_type(self):
super(TestNGRAPHWith1x1, self).init_kernel_type() super(TestNGRAPHWith1x1, self).init_kernel_type()
class TestNGRAPHWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): class TestNGRAPHWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def setUp(self):
super(TestNGRAPHWithInput1x1Filter1x1, self).setUp()
self._cpu_only = True
def init_kernel_type(self): def init_kernel_type(self):
super(TestNGRAPHWithInput1x1Filter1x1, self).init_kernel_type() super(TestNGRAPHWithInput1x1Filter1x1, self).init_kernel_type()
......
...@@ -14,73 +14,16 @@ ...@@ -14,73 +14,16 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_elementwise_add_op import * from paddle.fluid.tests.unittests.test_elementwise_add_op import TestElementwiseAddOp
class TestNGRAPHElementwiseAddOp(TestElementwiseAddOp): class TestNGRAPHElementwiseAddOp(TestElementwiseAddOp):
def init_input_output(self): def setUp(self):
super(TestNGRAPHElementwiseAddOp, self).init_input_output() super(TestNGRAPHElementwiseAddOp, self).setUp()
self._cpu_only = True
class TestNGRAPHElementwiseAddOp_scalar(TestElementwiseAddOp_scalar):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_scalar, self).init_input_output()
class TestNGRAPHElementwiseAddOp_scalar2(TestElementwiseAddOp_scalar2):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_scalar2, self).init_input_output()
class TestNGRAPHElementwiseAddOp_Vector(TestElementwiseAddOp_Vector):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_Vector, self).init_input_output()
class TesNGRAPHtElementwiseAddOp_broadcast_0(TestElementwiseAddOp_broadcast_0):
def init_input_output(self):
super(TesNGRAPHtElementwiseAddOp_broadcast_0, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_1(TestElementwiseAddOp_broadcast_1):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_1, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_2(TestElementwiseAddOp_broadcast_2):
def init_input_output(self): def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_2, self).init_input_output() super(TestNGRAPHElementwiseAddOp, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_3(TestElementwiseAddOp_broadcast_3):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_3, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_4(TestElementwiseAddOp_broadcast_4):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_4, self).init_input_output()
class TestNGRAPHElementwiseAddOp_rowwise_add_0(
TestElementwiseAddOp_rowwise_add_0):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_rowwise_add_0,
self).init_input_output()
class TestNGRAPHElementwiseAddOp_rowwise_add_1(
TestElementwiseAddOp_rowwise_add_1):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_rowwise_add_1,
self).init_input_output()
class TestNGRAPHElementwiseAddOp_channelwise_add(
TestElementwiseAddOp_channelwise_add):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_channelwise_add,
self).init_input_output()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -14,17 +14,13 @@ ...@@ -14,17 +14,13 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_mean_op import TestMeanOp, TestFP16MeanOp from paddle.fluid.tests.unittests.test_mean_op import TestMeanOp
class TestNGRAPHMeanOp(TestMeanOp): class TestNGRAPHMeanOp(TestMeanOp):
def setUp(self): def setUp(self):
super(TestNGRAPHMeanOp, self).setUp() super(TestNGRAPHMeanOp, self).setUp()
self._cpu_only = True
class TestNGRAPHFP16MeanOp(TestFP16MeanOp):
def setUp(self):
super(TestNGRAPHFP16MeanOp, self).setUp()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -15,27 +15,38 @@ ...@@ -15,27 +15,38 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_mul_op import TestMulOp, TestMulOp2, TestFP16MulOp1, TestFP16MulOp2 import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
class TestNGRAPHMulOp(OpTest):
def setUp(self):
self.op_type = "mul"
self.dtype = np.float32
self.init_dtype_type()
self.inputs = {
'X': np.random.random((2, 4)).astype(self.dtype),
'Y': np.random.random((4, 4)).astype(self.dtype)
}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
self._cpu_only = True
class TestNGRAPHMulOp(TestMulOp):
def init_dtype_type(self): def init_dtype_type(self):
pass pass
def test_check_output(self):
self.check_output()
class TestNGRAPHMulOp2(TestMulOp2): def test_check_grad_normal(self):
def init_dtype_type(self): self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)
pass
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
class TestNGRAPHFP16MulOp1(TestFP16MulOp1): def test_check_grad_ingore_y(self):
def init_dtype_type(self): self.check_grad(
pass ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
class TestNGRAPHFP16MulOp2(TestFP16MulOp2):
def init_dtype_type(self):
pass
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -14,35 +14,59 @@ ...@@ -14,35 +14,59 @@
from __future__ import print_function from __future__ import print_function
from paddle.fluid.tests.unittests.test_pool2d_op import * from paddle.fluid.tests.unittests.test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5
class TestNGRAPHPool2D_Op(TestPool2D_Op): class TestNGRAPHPool2D_Op(TestPool2D_Op):
def setUp(self):
super(TestNGRAPHPool2D_Op, self).setUp()
self._cpu_only = True
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHPool2D_Op, self).init_test_case() super(TestNGRAPHPool2D_Op, self).init_test_case()
class TestNGRAPHCase1(TestCase1): class TestNGRAPHCase1(TestCase1):
def setUp(self):
super(TestNGRAPHCase1, self).setUp()
self._cpu_only = True
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHCase1, self).init_test_case() super(TestNGRAPHCase1, self).init_test_case()
class TestNGRAPHCase2(TestCase2): class TestNGRAPHCase2(TestCase2):
def setUp(self):
super(TestNGRAPHCase2, self).setUp()
self._cpu_only = True
def init_test_case(self): def init_test_case(self):
super(TestNGRAPHCase2, self).init_test_case() super(TestNGRAPHCase2, self).init_test_case()
class TestNGRAPHCase3(TestCase3): class TestNGRAPHCase3(TestCase3):
def setUp(self):
super(TestNGRAPHCase3, self).setUp()
self._cpu_only = True
def init_pool_type(self): def init_pool_type(self):
super(TestNGRAPHCase3, self).init_pool_type() super(TestNGRAPHCase3, self).init_pool_type()
class TestNGRAPHCase4(TestCase4): class TestNGRAPHCase4(TestCase4):
def setUp(self):
super(TestNGRAPHCase4, self).setUp()
self._cpu_only = True
def init_pool_type(self): def init_pool_type(self):
super(TestNGRAPHCase4, self).init_pool_type() super(TestNGRAPHCase4, self).init_pool_type()
class TestNGRAPHCase5(TestCase5): class TestNGRAPHCase5(TestCase5):
def setUp(self):
super(TestNGRAPHCase5, self).setUp()
self._cpu_only = True
def init_pool_type(self): def init_pool_type(self):
super(TestNGRAPHCase5, self).init_pool_type() super(TestNGRAPHCase5, self).init_pool_type()
......
...@@ -13,25 +13,23 @@ ...@@ -13,25 +13,23 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_scale_op import TestScaleOp, TestScaleOpSelectedRows, TestScaleFp16Op, TestScaleFp16OpSelectedRows from paddle.fluid.tests.unittests.test_scale_op import TestScaleOp, TestScaleOpSelectedRows
class TestNGRAPHScaleOp(TestScaleOp): class TestNGRAPHScaleOp(TestScaleOp):
def init_dtype_type(self): def setUp(self):
pass super(TestNGRAPHScaleOp, self).setUp()
self._cpu_only = True
class TestNGRAPHScaleOpSelectedRows(TestScaleOpSelectedRows):
def init_dtype_type(self): def init_dtype_type(self):
pass pass
class TestNGRAPHScaleFp16Op(TestScaleFp16Op): class TestNGRAPHScaleOpSelectedRows(TestScaleOpSelectedRows):
def init_dtype_type(self): def setUp(self):
pass super(TestNGRAPHScaleOpSelectedRows, self).setUp()
self._cpu_only = True
class TestNGRAPHScaleFp16OpSelectedRows(TestScaleFp16OpSelectedRows):
def init_dtype_type(self): def init_dtype_type(self):
pass pass
......
...@@ -20,21 +20,25 @@ from paddle.fluid.tests.unittests.test_top_k_op import TestTopkOp, TestTopkOp3d, ...@@ -20,21 +20,25 @@ from paddle.fluid.tests.unittests.test_top_k_op import TestTopkOp, TestTopkOp3d,
class TestNGRAPHTopkOp(TestTopkOp): class TestNGRAPHTopkOp(TestTopkOp):
def setUp(self): def setUp(self):
super(TestNGRAPHTopkOp, self).setUp() super(TestNGRAPHTopkOp, self).setUp()
self._cpu_only = True
class TestNGRAPHTopkOp2(TestTopkOp2): class TestNGRAPHTopkOp2(TestTopkOp2):
def setUp(self): def setUp(self):
super(TestNGRAPHTopkOp2, self).setUp() super(TestNGRAPHTopkOp2, self).setUp()
self._cpu_only = True
class TestNGRAPHTopkOp3(TestTopkOp3): class TestNGRAPHTopkOp3(TestTopkOp3):
def setUp(self): def setUp(self):
super(TestNGRAPHTopkOp3, self).setUp() super(TestNGRAPHTopkOp3, self).setUp()
self._cpu_only = True
class TestNGRAPHTopkOp4(TestTopkOp4): class TestNGRAPHTopkOp4(TestTopkOp4):
def setUp(self): def setUp(self):
super(TestNGRAPHTopkOp4, self).setUp() super(TestNGRAPHTopkOp4, self).setUp()
self._cpu_only = True
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -100,6 +100,7 @@ packages=['paddle', ...@@ -100,6 +100,7 @@ packages=['paddle',
'paddle.utils', 'paddle.utils',
'paddle.dataset', 'paddle.dataset',
'paddle.reader', 'paddle.reader',
'paddle.distributed',
'paddle.fluid', 'paddle.fluid',
'paddle.fluid.imperative', 'paddle.fluid.imperative',
'paddle.fluid.proto', 'paddle.fluid.proto',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册