Commit 0a63234c authored by dzhwinter

follow comments. test=develop

Parent 9e87fbeb
@@ -53,6 +53,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("fuse_relu_depthwise_conv_pass");
     }
+    // NOTE(dzhwinter): notes on automatic inplace:
+    // 1. passes that modify the program desc should run
+    //    before the inplace pass.
+    // 2. manually configured inplace should be set up
+    //    before inplace_pass.
     // Add automatic inplace.
     if (strategy_.enable_inplace_) {
       AppendPass("inplace_pass");
...
@@ -80,6 +80,9 @@ struct BuildStrategy {
   bool memory_early_delete_{false};
+  // TODO(dzhwinter):
+  // make enable_inplace_, memory_optimize_, and
+  // memory_early_delete_ true by default
   bool enable_inplace_{false};
   bool enable_sequential_execution_{false};
...
@@ -26,6 +26,11 @@ namespace details {
 constexpr char kGraphvizPath[] = "debug_graphviz_path";
 constexpr char kGraphviz[] = "graphviz";
+// NOTE(dzhwinter): If the graph contains cycles,
+// it cannot be topologically sorted.
+// This printer prints the whole graph
+// and highlights any cycles, which is quite
+// useful for debugging deadlocks and cycles.
 class GraphvizNode {
  public:
   GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
@@ -37,7 +42,7 @@ class GraphvizNode {
   ir::Node* node_;
   int id_;
 };
-class GraphvizNode;
+
 typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
 class SSAGraphPrinter {
...
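As a rough, non-Paddle illustration of what the NOTE above describes, a graphviz printer can emit the whole graph and draw the back edge that closes a cycle in a standout color. All names in this sketch are hypothetical, not the SSAGraphPrinter API:

```cpp
#include <iostream>

// Emits a tiny dot graph with the cycle-closing edge highlighted in red.
// Feeding this to `dot -Tpng` makes the cycle immediately visible, which
// is the kind of debugging aid the NOTE above refers to.
int main() {
  std::cout << "digraph ssa {\n"
            << "  read -> conv;\n"
            << "  conv -> relu;\n"
            << "  relu -> conv [color=red, label=\"cycle\"];\n"
            << "}\n";
  return 0;
}
```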
@@ -13,7 +13,9 @@
 // limitations under the License.
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include <functional>
 #include <iostream>
+#include <numeric>
 #include <sstream>
 #include <string>
@@ -21,15 +23,17 @@ namespace paddle {
 namespace framework {
 namespace details {

+size_t NodeSizeInBytes(const VarDesc& node) {
+  auto shape = node.GetShape();
+  int size =
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
+  size_t type_size = SizeOfType(node.GetDataType());
+  return type_size * std::abs(size);
+}
+
 size_t NodeSizeInBytes(ir::Node* n) {
   auto* desc = FindVarDescInBlock(n);
-  auto shape = desc->GetShape();
-  size_t type_size = SizeOfType(desc->GetDataType());
-  int size = 1;
-  for (auto& s : shape) {
-    size *= s;
-  }
-  return type_size * std::abs(size);
+  return NodeSizeInBytes(*desc);
 }

 std::string DebugStringImpl(VarDesc* var) {
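For intuition, the helper above estimates a tensor's footprint as the product of its dimensions times the element size; `std::abs` keeps the estimate positive when a dimension is -1 (an unknown batch dimension). A minimal standalone sketch of the same computation, outside Paddle's VarDesc API (`SizeInBytes` is an illustrative name, and a float element type is assumed):

```cpp
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Byte-size estimate in the style of NodeSizeInBytes: product of all
// dimensions times the element size. A -1 dimension (dynamic batch)
// would flip the sign, so std::abs keeps the product positive.
size_t SizeInBytes(const std::vector<int64_t>& shape, size_t elem_size) {
  int64_t numel = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return elem_size * static_cast<size_t>(std::abs(numel));
}

int main() {
  std::vector<int64_t> shape = {-1, 3, 224, 224};  // unknown batch size
  std::cout << SizeInBytes(shape, sizeof(float)) << " bytes\n";  // 602112
}
```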
@@ -154,23 +158,28 @@ std::string OrderedNodeList::ToString() const {
 bool NodeCanReused(ir::Node* node) {
   if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false;
-  auto* desc = node->Var();
-  auto type = desc->GetType();
-  if (desc->Persistable() || type != proto::VarType::LOD_TENSOR ||
-      desc->GetShape().empty()) {
-    return false;
-  }
-  // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
-  std::string name = node->Name();
-  if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
-    return false;
+  bool flag = NodeCanReused(*node->Var());
   for (auto* op : node->inputs) {
     if (op->Op()->HasAttr("force_cpu")) {
       // op output force generated in cpu, can not be reused.
-      return framework::AttrReader(op->Op()->GetAttrMap())
-                 .Get<bool>("force_cpu") == 0;
+      flag &= framework::AttrReader(op->Op()->GetAttrMap())
+                  .Get<bool>("force_cpu") == 0;
     }
   }
+  return flag;
+}
+
+bool NodeCanReused(const VarDesc& node) {
+  auto type = node.GetType();
+  if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
+      node.GetShape().empty()) {
+    return false;
+  }
+  // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
+  std::string name = node.Name();
+  if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
+    return false;
   return true;
 }
...
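The reuse rules now live in one place: a variable is reusable only if it is a non-persistable LOD_TENSOR with a known shape, and its name is not a bookkeeping alias wrapped in `@`. A standalone sketch of that name filter (`IsBookkeepingName` is a hypothetical helper name, not Paddle API):

```cpp
#include <iostream>
#include <string>

// Variables whose names are wrapped in '@' (e.g. "@EMPTY@" or
// "@LR_DECAY_REUSE_ID@", created by ops such as while_grad) are runtime
// bookkeeping aliases and must be excluded from memory reuse.
bool IsBookkeepingName(const std::string& name) {
  return !name.empty() && name.front() == '@' && name.back() == '@';
}

int main() {
  std::cout << IsBookkeepingName("@EMPTY@") << "\n";   // 1: excluded
  std::cout << IsBookkeepingName("fc_0.w_0") << "\n";  // 0: reusable
}
```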
@@ -86,12 +86,18 @@ class OrderedNodeList {
 // whether a tensor can be reused
 bool NodeCanReused(ir::Node* node);

+// whether a tensor can be reused
+bool NodeCanReused(const VarDesc& node);
+
 // check whether an op has a sub-block
 bool OpHasSubBlock(OpDesc* desc);

 // node memory size in bytes
 size_t NodeSizeInBytes(ir::Node* n);

+// node memory size in bytes
+size_t NodeSizeInBytes(const VarDesc&);
+
 std::string DebugString(ir::Node* var);

 VarDesc* FindVarDescInBlock(ir::Node* n);
...
@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
@@ -66,30 +67,9 @@ class InplaceInToOut : public InplaceOpInference {
       const OpDesc& op_desc, BlockDesc* block) const = 0;

   bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
-    auto var_can_reused = [&](const VarDesc& node) -> bool {
-      auto type = node.GetType();
-      if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
-          node.GetShape().empty()) {
-        return false;
-      }
-      // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
-      std::string name = node.Name();
-      if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
-        return false;
-      return true;
-    };
-    auto var_size_in_bytes = [&](const VarDesc& node) -> size_t {
-      auto shape = node.GetShape();
-      int size = std::accumulate(shape.begin(), shape.end(), 1,
-                                 std::multiplies<int>());
-      size_t type_size = SizeOfType(node.GetDataType());
-      return type_size * std::abs(size);
-    };
-    return in.Name() != out.Name() && var_can_reused(in) &&
-           var_can_reused(out) &&
-           var_size_in_bytes(out) <= var_size_in_bytes(in);
+    return in.Name() != out.Name() && details::NodeCanReused(in) &&
+           details::NodeCanReused(out) &&
+           details::NodeSizeInBytes(out) <= details::NodeSizeInBytes(in);
   }
 };
...
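For a concrete feel of the predicate above, here is a minimal self-contained sketch (plain C++; `Var` and `CanInplace` are hypothetical stand-ins, not Paddle types): in-place reuse requires distinct names, both variables reusable, and an output no larger than the input, so writing the output in place can never overflow the input's buffer.

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical stand-in for a variable: name, size, and the result of
// the NodeCanReused-style eligibility checks.
struct Var {
  std::string name;
  uint64_t bytes;
  bool reusable;
};

// Mirrors the refactored predicate: in-place is safe only when the
// output fits inside the input's allocation and both vars may be reused.
bool CanInplace(const Var& in, const Var& out) {
  return in.name != out.name && in.reusable && out.reusable &&
         out.bytes <= in.bytes;
}

int main() {
  Var in{"x", 32 * 128 * sizeof(float), true};   // 16384 bytes
  Var out{"y", 32 * 64 * sizeof(float), true};   // 8192 bytes
  std::cout << CanInplace(in, out) << "\n";      // 1: output fits
}
```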
@@ -174,6 +174,11 @@ class CompiledProgram(object):
             self._exec_strategy.num_threads = cpu_num * 2
         trainers_endpoints = self._program._trainers_endpoints
+
+        # FIXME(dzhwinter): enable_inplace should be set after memory_optimize.
+        # If the Python memory optimizer is turned on, turn off the inplace pass.
+        self._build_strategy.enable_inplace = False if self._program._is_mem_optimized else True
+
         if self._build_strategy.num_trainers > 1 and trainers_endpoints:
             assert self._build_strategy.num_trainers == len(
                 trainers_endpoints), "num_trainers == len(end_points)"
...
@@ -1725,18 +1725,19 @@ class Program(object):
         self._trainers_endpoints = []
         # the distributed lookup table names
         self._distributed_lookup_table = None
+        # @deprecated(the python memory optimize transpiler is deprecated)
         # whether the program is optimized by memory_optimize_transpiler
-        self.__is_optimized = False
+        self.__is_mem_optimized = False

     @property
-    def _is_optimized(self):
+    def _is_mem_optimized(self):
         # if the program is optimized, operator inputs/outputs
         # may be identical, which conflicts with save_inference_model.
-        return self.__is_optimized
+        return self.__is_mem_optimized

-    @_is_optimized.setter
-    def _is_optimized(self, target):
-        self.__is_optimized = target
+    @_is_mem_optimized.setter
+    def _is_mem_optimized(self, target):
+        self.__is_mem_optimized = target

     @property
     def op_role(self):
...
@@ -931,7 +931,7 @@ def save_inference_model(dirname,
     if main_program is None:
         main_program = default_main_program()
-    if main_program._is_optimized:
+    if main_program._is_mem_optimized:
         warnings.warn(
             "save_inference_model must put before you call memory_optimize. \
             the memory_optimize will modify the original program, \
...
@@ -148,7 +148,7 @@ class ParallelExecutor(object):
             else framework.default_main_program()
         # FIXME(dzhwinter): enable_inplace should be set after memory_optimize.
         # If the Python memory optimizer is turned on, turn off the inplace pass.
-        build_strategy.enable_inplace = False if main._is_optimized else True
+        build_strategy.enable_inplace = False if main._is_mem_optimized else True
         scope = scope if scope is not None else executor.global_scope()

         if share_vars_from and not isinstance(share_vars_from,
...
@@ -108,7 +108,7 @@ class TestSaveInferenceModel(unittest.TestCase):
         exe.run(init_program, feed={}, fetch_list=[])

         memory_optimize(program, print_log=True)
-        self.assertEqual(program._is_optimized, True)
+        self.assertEqual(program._is_mem_optimized, True)
         # will print warning message

         save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
...
@@ -540,7 +540,7 @@ def memory_optimize(input_program,
     if skip_opt_set is not None:
         skip_opt_set = set(map(to_name_str, skip_opt_set))
     cfgs = _get_cfgs(input_program)
-    input_program._is_optimized = True
+    input_program._is_mem_optimized = True
     for cfg in cfgs:
         cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
@@ -560,6 +560,6 @@ def release_memory(input_program, skip_opt_set=None):
         None
     """
     cfgs = _get_cfgs(input_program)
-    input_program._is_optimized = True
+    input_program._is_mem_optimized = True
     for cfg in cfgs:
         cfg.release_memory(skip_opt_set=skip_opt_set)