Unverified commit 8008ab4e, authored by Zeng Jinle, committed by GitHub

Remove legacy C++ memory optimization codes (#18834)

* remove legacy memory optimization codes, test=develop

* follow huihuang's comments,test=develop

* follow luotao's comments, test=develop
Parent 52c1431e
......@@ -133,7 +133,7 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......@@ -204,7 +204,6 @@ cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc)
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS inplace_op_pass op_registry proto_desc op_info memory_optimize_helper pass_builder)
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
......
......@@ -62,7 +62,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......@@ -92,6 +92,6 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass record_skip_memory_opt_vars_pass)
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass)
......@@ -24,8 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
DECLARE_bool(use_mkldnn);
......@@ -51,17 +49,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
ResolveOptionConfliction();
AppendPrintGraphPass("graph_viz_pass", "_original_graph");
// Note(zcd): record_skip_memory_opt_vars_pass should
// be the first pass.
AppendPass("record_skip_memory_opt_vars_pass");
AppendPassWithCheck(strategy_.enable_sequential_execution_,
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
AppendOpFusePasses();
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
// TODO(dev-paddle): memory optimize pass should be placed last.
AppendMemoryOptimizePasses();
AppendMultiDevPass();
AppendMultiGraphOptPasses();
......@@ -147,23 +141,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
}
void AppendMemoryOptimizePasses() { // Append Memory Optimize Pass
// TODO(zjl): refactor MemoryOptimizePass to fit
// new strategy, which does not need to set
// var.persistable = True
if (strategy_.use_legacy_memory_optimize_strategy_) {
AppendPassWithCheck(strategy_.enable_inplace_, "inplace_pass");
}
// NOTE(dzh): memory optimize should be a runtime pass.
// However, after multi_devices_pass, VarHandle and OpHandle are
// the de-facto IR, so any reuse on the Graph is meaningless.
// As a side effect, memory optimize cannot foresee the fetched vars,
// so the fetch list should be set persistable before calling the Run interface.
if (strategy_.use_legacy_memory_optimize_strategy_) {
AppendPassWithCheck(strategy_.memory_optimize_, "memory_optimize_pass");
}
}
void SetCollectiveContext() const {
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
......@@ -330,9 +307,6 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "inplace_pass") {
pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
......@@ -365,12 +339,10 @@ USE_PASS(all_reduce_mode_multi_devices_pass);
USE_PASS(dist_multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(memory_optimize_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(backward_optimizer_op_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(coalesce_grad_tensor_pass);
USE_PASS(graph_to_program_pass);
......@@ -379,7 +351,6 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(record_skip_memory_opt_vars_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......@@ -19,6 +19,7 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "boost/optional.hpp"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -108,14 +109,14 @@ struct BuildStrategy {
// FLAGS_use_mkldnn=false
std::unordered_set<std::string> mkldnn_enabled_op_types_;
bool memory_optimize_{false};
// By default, memory_optimize is turned on if gc is disabled, and
// turned off if gc is enabled.
// Users can forcibly enable/disable memory_optimize by setting it to True/False.
boost::optional<bool> memory_optimize_{boost::none};
// Turn on inplace by default.
bool enable_inplace_{true};
// TODO(zjl): Remove this flag when MemoryOptimizePass is refactored
bool use_legacy_memory_optimize_strategy_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
......
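The BuildStrategy change above replaces the plain bool memory_optimize_ with a tri-state boost::optional<bool>: unset (boost::none), or explicitly true/false. The sketch below only illustrates how such a tri-state default could be resolved against a garbage-collection switch; it uses std::optional instead of the boost::optional in the header and a hypothetical ResolveMemoryOptimize helper, not any actual Paddle API.

#include <iostream>
#include <optional>

// Hypothetical helper: if the user never set the flag, derive the default
// from whether garbage collection is enabled (on only when gc is off).
bool ResolveMemoryOptimize(std::optional<bool> memory_optimize, bool gc_enabled) {
  if (memory_optimize.has_value()) {
    return *memory_optimize;  // explicit user choice always wins
  }
  return !gc_enabled;         // default: enable memory_optimize only without gc
}

int main() {
  std::cout << ResolveMemoryOptimize(std::nullopt, /*gc_enabled=*/true) << "\n";   // 0
  std::cout << ResolveMemoryOptimize(std::nullopt, /*gc_enabled=*/false) << "\n";  // 1
  std::cout << ResolveMemoryOptimize(false, /*gc_enabled=*/false) << "\n";         // 0
  return 0;
}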
......@@ -13,13 +13,8 @@
// limitations under the License.
#pragma once
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
USE_PASS(inplace_pass);
namespace paddle {
namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
pass->Set<bool>(ir::kUseCuda, new bool(true));
return pass;
}
class NOP : public OperatorBase {
public:
NOP(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SingleOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
}
};
class SingleGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("single_op_grad");
op->SetInput("Out", OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<OpDesc>(op);
}
};
class SingleOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->HasInput("X");
ctx->HasOutput("Out");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
};
class SingleGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->HasInput(framework::GradVarName("Out"));
ctx->HasOutput(framework::GradVarName("X"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
};
class MultiOutOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "").AsDuplicable();
AddInput("Y", "").AsDuplicable();
AddInput("Z", "").AsDuplicable();
AddOutput("Out", "");
AddOutput("YOut", "");
AddOutput("ZOut", "");
AddOutput("NotReuseOut", "");
AddComment("");
}
};
class MultiOutShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("X", "Out");
ctx->ShareDim("Y", "YOut");
ctx->ShareDim("Z", "ZOut");
}
};
class MultiGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("multi_out_grad");
op->SetInput("X", Input("X"));
op->SetOutput(framework::GradVarName("Y"), OutputGrad("YOut"));
op->SetOutput(framework::GradVarName("X"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Z"), OutputGrad("ZOut"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
class MultiOutGradShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Y"),
ctx->GetInputDim(framework::GradVarName("YOut")));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->SetOutputDim(framework::GradVarName("Z"),
ctx->GetInputDim(framework::GradVarName("ZOut")));
}
};
class MultiOutInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
return std::unordered_map<std::string, std::string>{
{"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"},
};
}
};
class MultiOutGradInplaceInToOut : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
return std::unordered_map<std::string, std::string>{
{framework::GradVarName("YOut"), framework::GradVarName("Y")},
{framework::GradVarName("Out"), framework::GradVarName("X")},
{framework::GradVarName("ZOut"), framework::GradVarName("Z")},
};
}
};
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
REGISTER_OPERATOR(single_op, f::NOP, f::SingleOpMaker, f::SingleGradOpMaker,
f::SingleOpInplaceInToOut, f::SingleOpShapeInference);
REGISTER_OPERATOR(single_op_grad, f::NOP, f::SingleOpInplaceInToOut,
f::SingleGradOpShapeInference);
REGISTER_OPERATOR(multi_out_op, f::NOP, f::MultiOutOpMaker, f::MultiGradOpMaker,
f::MultiOutInplaceInToOut, f::MultiOutShapeInference);
REGISTER_OPERATOR(multi_out_grad, f::NOP, f::MultiOutGradInplaceInToOut,
f::MultiOutGradShapeInference);
namespace paddle {
namespace framework {
void FakeSuccData(ProgramDesc* prog) { // NOLINT
prog->MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
prog->MutableBlock(0)->Var("test2_a")->SetShape({32, 64, 128, 128});
prog->MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
prog->MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
prog->MutableBlock(0)->Var("test2_out");
prog->MutableBlock(0)->Var("test2_out")->SetShape({64, 32, 128, 128});
}
void FakeNoInplaceData(ProgramDesc* prog) { // NOLINT
prog->MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
prog->MutableBlock(0)->Var("test2_a")->SetShape({32, 64, 128, 128});
prog->MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
prog->MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
prog->MutableBlock(0)->Var("test2_out");
prog->MutableBlock(0)->Var("test2_out")->SetShape({64, 31, 128, 128});
}
ir::Node* GetNodeFromGraph(ir::Graph* g, std::string name) {
ir::Node* op_node = nullptr;
for (auto& item : g->Nodes()) {
if (item->Name() == name) {
op_node = item;
break;
}
}
return op_node;
}
std::unique_ptr<ir::Graph> test_SingleOpInplaceInToOut(
std::unique_ptr<ir::Graph> g) {
auto pass = CreateInplacePass();
ir::Node* op_node = GetNodeFromGraph(g.get(), "single_op");
EXPECT_NE(op_node, nullptr);
pass->Apply(g.get());
return g;
}
TEST(InferInplace, SingleOpInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("single_op");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
FakeSuccData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
EXPECT_EQ(op_node->outputs[0]->Name(), "test2_a");
}
TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("single_op");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
FakeNoInplaceData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
EXPECT_EQ(op_node->outputs[0]->Name(), "test2_out");
}
TEST(InferInplace, MultiOutInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("multi_out_op");
op->SetInput("X", {"a0", "a1"});
op->SetInput("Y", {"b0"});
op->SetInput("Z", {"c0", "c1"});
op->SetOutput("Out", {"o0"});
op->SetOutput("YOut", {"y0"});
op->SetOutput("ZOut", {"z0"});
prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("o0");
prog.MutableBlock(0)->Var("y0");
prog.MutableBlock(0)->Var("z0");
prog.MutableBlock(0)->Var("a0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("b0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("c0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("o0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("y0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
ASSERT_TRUE(op_node != nullptr);
EXPECT_EQ(op_node->outputs[0]->Name(), "a0");
EXPECT_EQ(op_node->outputs[1]->Name(), "b0");
EXPECT_EQ(op_node->outputs[2]->Name(), "c0");
}
TEST(InferInplace, MultiGradInplaceInToOut) {
ProgramDesc prog;
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("multi_out_grad");
op->SetInput(GradVarName("Out"), {"o0"});
op->SetInput(GradVarName("YOut"), {"y0"});
op->SetInput(GradVarName("ZOut"), {"z0"});
op->SetOutput(GradVarName("X"), {"a0", "a1"});
op->SetOutput(GradVarName("Y"), {"b0"});
op->SetOutput(GradVarName("Z"), {"c0", "c1"});
prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("o0");
prog.MutableBlock(0)->Var("y0");
prog.MutableBlock(0)->Var("z0");
prog.MutableBlock(0)->Var("a0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("b0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("c0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("o0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("y0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
ASSERT_TRUE(op_node != nullptr);
EXPECT_EQ(op_node->outputs[0]->Name(), "o0");
EXPECT_EQ(op_node->outputs[2]->Name(), "y0");
EXPECT_EQ(op_node->outputs[3]->Name(), "c0");
std::unordered_map<std::string, std::string> expects = {
{"o0", "a0"}, {"y0", "b0"}, {"z0", "c0"},
};
}
} // namespace framework
} // namespace paddle
......@@ -4,20 +4,8 @@ cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pas
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means an op's output variable reuses the input space.
// By our design, an operator can only read its input (const Variable) and
// write its output (non-const Variable). If an operator is inplaced, the
// kernel has a chance to write the space before the read happens,
// especially when certain optimized coding styles are used.
//
//
// /* wrong case in an operator */
// /* In this case, a larger allocation is made and the input content is lost */
// const Tensor* in = ctx.Input<Tensor>("In");
// Tensor* out = ctx.Output<Tensor>("Out");
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = 0; // input content is overwritten.
// NOTE(dzhwinter):
// Only for backward compatibility and stability. If enable_inplace_whitelist
// is turned on, only the ops in the whitelist will use the inplace strategy;
// otherwise, every op that registers an inplace inference will be inplaced.
DEFINE_bool(
enable_inplace_whitelist, false,
"If this option is turned on, only the ops in the whitelist can be inplaced. "
"If it is turned off, every running op is a candidate for inplacing, "
"such as scale and elementwise_add. "
"By default, it is turned off.");
namespace paddle {
namespace framework {
namespace ir {
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
"sigmoid",
"exp",
"relu",
"tanh",
"sqrt",
"ceil",
"floor",
"reciprocal",
"relu6",
"soft_relu",
"hard_sigmoid",
"batch_norm",
"batch_norm_grad",
"sum",
"sum_grad",
"scale",
"reshape",
"elementwise_add",
"elementwise_add_grad",
};
// FIXME(zjl): The in/out shapes of some ops are exactly the same,
// but the static size computed at compile time would be wrong.
// Use a flag to mark such ops. Please fix me when a better way is found.
static const std::unordered_set<std::string> kSameShapeOpWhiteSet{ // NOLINT
"reshape2", "reshape2_grad"
};
// clang-format on
class InplacePass : public ir::Pass {
public:
InplacePass();
protected:
void ApplyImpl(ir::Graph *graph) const override;
private:
// Collect vars that cannot be reused
// e.g.: subblock ops in/out, distributed ops in/out, op_role_var
void CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const;
// Check whether var_name should be skipped
bool IsSkipVar(const std::string &var_name) const;
// Rename out with the name of in, and guarantee that the graph is
// still an SSA graph
void RenameInOut(ir::Node *op, ir::Node *in, ir::Node *out) const;
// Check whether var is the last version one in SSA graph
bool IsLastVersionVar(ir::Node *var) const;
// Check whether var is the first version one in SSA graph
bool IsFirstVersionVar(ir::Node *var) const;
// Check whether all `ops` are preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
// Find nodes whose names are equal to the given name
static std::unordered_set<ir::Node *> FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes);
// Collect the input arguments of op_desc
static void CollectInputArgsOfOpDesc(
const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args);
// Get all versions of vars named var_name
std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
private:
// SSA graph. var_name -> each version of vars
mutable std::map<std::string, std::vector<ir::Node *>> ssa_map_;
// Skip vars, including subblock ops in/out, distributed ops in/out,
// op_role_var
mutable std::unordered_set<std::string> skip_vars_;
// Op whitelist which should not perform inplace
// Only enabled when FLAGS_enable_inplace_whitelist is true.
mutable std::unordered_set<std::string> whitelist_ops_;
};
InplacePass::InplacePass() {
if (FLAGS_enable_inplace_whitelist) {
for (auto &s : kInplacedOpWhiteList) {
whitelist_ops_.emplace(s);
}
}
}
std::vector<ir::Node *> *InplacePass::AllVersionVars(
const std::string &var_name) const {
auto iter = ssa_map_.find(var_name);
PADDLE_ENFORCE(iter != ssa_map_.end(), "cannot find var %s in ssa graph",
var_name);
PADDLE_ENFORCE(!iter->second.empty(), "var %s is empty in ssa graph",
var_name);
return &(iter->second);
}
bool InplacePass::IsSkipVar(const std::string &var_name) const {
return skip_vars_.count(var_name) > 0;
}
bool InplacePass::IsFirstVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->front() == var;
}
bool InplacePass::IsLastVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->back() == var;
}
bool InplacePass::CheckOpDeps(ir::Node *op,
const std::vector<ir::Node *> &ops) const {
std::unordered_set<ir::Node *> other_ops(ops.begin(), ops.end());
other_ops.erase(op);
if (other_ops.empty()) return true;
// Traverse all preceding ops of op
std::queue<ir::Node *> queue;
std::unordered_set<ir::Node *> visited_ops;
queue.push(op);
visited_ops.insert(op);
// Visit all preceding ops of `op`, and erase it from other_ops if it is
// inside other_ops. Return true only if other_ops is empty(), which means
// that all `ops` are preceding ops of `op`.
while (!queue.empty()) {
auto *cur_op = queue.front();
queue.pop();
for (auto *in_var : cur_op->inputs) {
for (auto *in_op : in_var->inputs) {
if (visited_ops.count(in_op) != 0) {
continue;
}
visited_ops.insert(in_op);
queue.push(in_op);
other_ops.erase(in_op);
if (other_ops.empty()) return true;
}
}
}
return false;
}
void InplacePass::CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const {
// 1. Collect op role vars
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars), "Graph should have attr %s",
kMemOptSkipVars);
auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto &var : mem_opt_whitelist) {
skip_vars_.emplace(var);
}
}
void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
ir::Node *out_var) const {
auto out_var_name = out_var->Name();
auto in_var_name = in_var->Name();
auto &all_out_nodes = *AllVersionVars(out_var_name);
auto &all_in_nodes = *AllVersionVars(in_var_name);
auto iter = std::find(all_out_nodes.begin(), all_out_nodes.end(), out_var);
PADDLE_ENFORCE(iter != all_out_nodes.end(), "Cannot find out var %s",
out_var_name);
// The following codes are designed to guarantee that ssa_map_ is still
// an ssa graph after inplace is performed.
// Step 1: Rename the following versions of out_var as the name of in_var
// Step 2: Remove the following versions of out_var and append them to in_var
// Be careful that the inputs of input op of out_var should not be renamed,
// but outputs should be renamed.
auto original_iter = iter;
while (iter != all_out_nodes.end()) {
auto *node = *iter;
/* Step 1 */
node->RenameVar(in_var_name);
if (iter != original_iter) {
for (auto *in : node->inputs) {
if (in->IsOp() && in->Op()) {
in->Op()->RenameOutput(out_var_name, in_var_name);
in->Op()->RenameInput(out_var_name, in_var_name);
in->Op()->Flush();
}
}
}
for (auto *out : node->outputs) {
if (out->IsOp() && out->Op()) {
out->Op()->RenameOutput(out_var_name, in_var_name);
out->Op()->RenameInput(out_var_name, in_var_name);
out->Op()->Flush();
}
}
/* Step 2 */
all_in_nodes.emplace_back(node);
++iter;
}
/* Step 2 */
all_out_nodes.erase(original_iter, all_out_nodes.end());
if (all_out_nodes.empty()) {
ssa_map_.erase(out_var_name);
}
op->Op()->RenameOutput(out_var_name, in_var_name);
op->Op()->Flush();
}
std::unordered_set<ir::Node *> InplacePass::FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes) {
std::unordered_set<ir::Node *> ret;
for (auto *node : nodes) {
if (node->Name() == name) {
ret.insert(node);
}
}
return ret;
}
void InplacePass::CollectInputArgsOfOpDesc(
const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args) {
in_args->clear();
for (auto &in_name : op_desc->InputArgumentNames()) {
in_args->insert(in_name);
}
}
void InplacePass::ApplyImpl(ir::Graph *graph) const {
// Step 1: topo sort ops, collect skip vars
auto ops = ir::TopologySortOperations(*graph);
CollectSkipVars(graph, ops);
// Step 2: build ssa var map
for (auto *op_node : ops) {
for (auto *in : op_node->inputs) {
PADDLE_ENFORCE(in->IsVar());
// Only create a new var node when var first occurs in input of op.
if (ssa_map_.count(in->Name()) == 0) {
ssa_map_[in->Name()].emplace_back(in);
}
}
// Always create a new var node for each output of op.
for (auto *out : op_node->outputs) {
PADDLE_ENFORCE(out->IsVar());
ssa_map_[out->Name()].emplace_back(out);
}
}
// Step 3: traverse ops and try inplace if possible
bool use_cuda = Get<bool>(kUseCuda);
VLOG(4) << "Inplace pass is applied when use_cuda = "
<< (use_cuda ? "true" : "false");
for (auto *op_node : ops) {
PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr");
auto *op_desc = op_node->Op();
auto op_type = op_desc->Type();
// Skip op inside whitelist
if (whitelist_ops_.count(op_type) > 0) {
continue;
}
auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
if (!infer_inplace) {
continue;
}
auto in_to_outs = infer_inplace(*op_desc, use_cuda);
if (in_to_outs.empty()) continue;
std::unordered_multiset<std::string> all_in_args;
CollectInputArgsOfOpDesc(op_desc, &all_in_args);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &out_param = pair.second;
auto &in_args = op_desc->Input(in_param);
auto &out_args = op_desc->Output(out_param);
if (in_args.empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param
<< ") is empty in " << op_type;
continue;
}
if (out_args.empty()) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ") is empty in " << op_type;
continue;
}
auto &in_arg = in_args[0];
auto &out_arg = out_args[0];
if (IsSkipVar(in_arg)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is skipped in " << op_type;
continue;
}
if (IsSkipVar(out_arg)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " is skipped in " << op_type;
continue;
}
if (in_arg == out_arg) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is the same with Output(" << out_param << ")=" << out_arg
<< " in " << op_type;
continue;
}
size_t in_arg_occur_times = all_in_args.count(in_arg);
if (in_arg_occur_times > 1) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs " << in_arg_occur_times << " times in input of op "
<< op_type;
continue;
}
auto in_nodes = FindNodesByName(in_arg, op_node->inputs);
PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
if (in_nodes.size() > 1) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs in other inputs of " << op_type;
continue;
}
auto *in_node = *in_nodes.begin();
if (!NodeCanReused(in_node)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not reusable in " << op_type;
continue;
}
if (!IsLastVersionVar(in_node)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the last version in " << op_type;
continue;
}
// If in_node is used as an input of many ops, check whether all of those ops
// depend on op_node. If not, in_node cannot be inplaced.
if (in_node->outputs.size() > 1 &&
!CheckOpDeps(op_node, in_node->outputs)) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not lastly used in " << op_type;
continue;
}
auto out_nodes = FindNodesByName(out_arg, op_node->outputs);
PADDLE_ENFORCE(!out_nodes.empty(),
"Output(%s)=%s cannot be found in op %s", out_param,
out_arg, op_type);
PADDLE_ENFORCE_EQ(
out_nodes.size(), 1,
"Wrong graph: Output(%s)=%s occurs in other outputs of op %s",
out_param, out_arg, op_type);
if (!FindNodesByName(in_arg, op_node->outputs).empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs in output of op " << op_type;
continue;
}
if (!FindNodesByName(out_arg, op_node->inputs).empty()) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " occurs in input of op " << op_type;
continue;
}
auto *out_node = *out_nodes.begin();
if (!IsFirstVersionVar(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " does not occur first in op " << op_type;
continue;
}
if (!NodeCanReused(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " is not reusable in " << op_type;
continue;
}
if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same type with "
<< "Output(" << out_param << ")=" << out_arg << " in "
<< op_type;
continue;
}
if (NodeSize(*in_node->Var()) != NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with "
<< "Output(" << out_param << ")=" << out_arg << " in "
<< op_type;
continue;
}
VLOG(4) << "Rename " << out_node->Name() << " with " << in_node->Name()
<< " in " << op_type;
RenameInOut(op_node, in_node, out_node);
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::ir::InplacePass)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
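To make the hazard described in the NOTE at the top of this deleted pass concrete: once an op runs inplace, its output buffer aliases its input buffer, so a kernel that writes its output before it has finished reading its input silently corrupts that input. The following self-contained sketch reproduces the failure with plain pointers standing in for Tensor buffers; it is illustrative only and uses no Paddle API.

#include <iostream>
#include <vector>

// A kernel that writes out[0] before it has finished reading the input.
// Correct when in and out are distinct buffers, wrong when they alias.
void BuggyKernel(const float* in, float* out, int n) {
  out[0] = 0.f;  // premature write; clobbers in[0] if out aliases in
  float sum = 0.f;
  for (int i = 0; i < n; ++i) sum += in[i];
  out[0] = sum;
}

int main() {
  std::vector<float> buf = {1.f, 2.f, 3.f};
  BuggyKernel(buf.data(), buf.data(), static_cast<int>(buf.size()));  // inplaced
  std::cout << buf[0] << "\n";  // prints 5 instead of the expected 6
  return 0;
}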
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <deque>
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h"
#endif // PADDLE_WITH_CUDA
namespace paddle {
namespace framework {
namespace ir {
using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(details::kStaleProgramOpDescs),
"Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order
auto& op_descs =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
std::deque<ir::Node*> ops;
FilterVariables(nodes, [&](ir::Node* op) {
if (op->IsOp() && op->Op() != nullptr) {
ops.emplace_back(op);
}
});
std::unordered_map<ir::Node*, size_t> op_deps;
std::list<ir::Node*> ready_ops;
std::unordered_map<ir::Node*, std::unordered_set<ir::Node*>> pending_ops;
for (auto* op : ops) {
std::unordered_set<ir::Node*> preceding_op;
for (auto* in : op->inputs) {
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp());
preceding_op.emplace(in->inputs[0]);
pending_ops[in->inputs[0]].emplace(op);
}
op_deps[op] = preceding_op.size();
if (preceding_op.empty()) {
ready_ops.emplace_back(op);
}
}
// 3. generate the op list based on the desc order and the topology order
std::vector<ir::Node*> ret;
std::list<OpDesc*> op_descs_list(op_descs.begin(), op_descs.end());
auto update_by_found_node = [&](ir::Node* found_node) {
for (auto* pending_op : pending_ops[found_node]) {
if (--op_deps[pending_op] == 0) {
ready_ops.emplace_back(pending_op);
}
}
ready_ops.remove(found_node);
ret.emplace_back(found_node);
};
while (!ready_ops.empty()) {
bool all_of_ready_op_unmatched = true;
for (auto it = op_descs_list.begin(); it != op_descs_list.end();) {
auto op_desc = *it;
ir::Node* found_node = nullptr;
for (auto* op : ready_ops) {
if (IsSameDesc(op->Op(), op_desc)) {
found_node = op;
break;
}
}
// 3.1 op desc deleted by other pass
if (found_node == nullptr) {
++it;
continue;
} else {
all_of_ready_op_unmatched = false;
it = op_descs_list.erase(it);
}
update_by_found_node(found_node);
}
// 3.2 op descs added by other passes
// A non-empty preceding op set means some new op descs were
// created but are not contained in the returned node list.
// These new op descs may depend on each other.
std::list<ir::Node*> prev_ready_ops(ready_ops);
if (all_of_ready_op_unmatched) {
for (auto op : prev_ready_ops) {
update_by_found_node(op);
}
}
}
PADDLE_ENFORCE(std::all_of(
op_deps.begin(), op_deps.end(),
[&](const std::pair<ir::Node*, size_t>& p) { return p.second == 0; }));
return ret;
}
size_t NodeSize(const VarDesc& node) {
auto shape = node.GetShape();
int size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
size_t type_size = SizeOfType(node.GetDataType());
return type_size * std::abs(size);
}
size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
std::string DebugStringImpl(VarDesc* var) {
std::stringstream ss;
ss << var->Name();
ss << "[";
try {
auto shape = var->GetShape();
for (size_t i = 0; i < shape.size(); ++i) {
if (i != shape.size() - 1) {
ss << shape[i] << ",";
} else {
ss << shape[i];
}
}
ss << "]";
} catch (...) {
ss << "Var has no VarDesc !!! Name:" << var->Name();
}
return ss.str();
}
std::string DebugString(ir::Node* var) {
return DebugStringImpl(GetVarDesc(var));
}
// NOTE(dzh): based on the ir node, if a large node has been reused
// by a smaller node, then the next time it appears in the pool it will
// have the smaller size. Find the original node shape from the blockdesc.
VarDesc* GetVarDesc(ir::Node* n) {
PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
return n->Var();
}
struct NodeComparator {
bool operator()(ir::Node* lhs, ir::Node* rhs) const {
if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
auto* lhs_desc = GetVarDesc(lhs);
auto* rhs_desc = GetVarDesc(rhs);
// match data type
if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
return false;
}
// match shape
auto lhs_shape = lhs_desc->GetShape();
auto rhs_shape = rhs_desc->GetShape();
if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
(lhs_shape[0] != -1 && rhs_shape[0] != -1)) {
return NodeSize(lhs) == NodeSize(rhs);
} else {
return false;
}
}
};
void OrderedSet::Insert(ir::Node* var) {
PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
if (mark_table_.count(var->Name()) != 0) {
mark_table_[var->Name()]->emplace_back(var);
return;
}
auto* var_desc = var->Var();
auto var_shape = var_desc->GetShape();
int batch_size = static_cast<int>(var_shape[0]);
NodeComparator functor;
Iter it = nodes_.begin();
while (it != nodes_.end()) {
auto& prev = it->front();
auto* cache_desc = GetVarDesc(prev);
int cache_batch_size = cache_desc->GetShape()[0];
if ((cache_batch_size == -1 && batch_size == -1) ||
(cache_batch_size != -1 && batch_size != -1)) {
if (functor(prev, var)) {
++it;
} else {
break;
}
} else if (cache_batch_size == -1 && batch_size != -1) {
++it;
} else if (cache_batch_size != -1 && batch_size == -1) {
break;
}
}
it = nodes_.insert(it, {var});
mark_table_[var->Name()] = it;
}
int OrderedSet::GetNodeIndexInPool(ir::Node* var) {
return std::distance(nodes_.begin(), mark_table_[var->Name()]);
}
ir::Node* OrderedSet::FindBestFitNode(ir::Node* var) const {
ir::Node* found_node = nullptr;
NodeComparator functor;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
auto& candidate = it->front();
if (functor(var, candidate)) {
found_node = candidate;
break;
}
}
return found_node;
}
ir::Node* OrderedSet::FindNextBestFitNode(ir::Node* var, ir::Node* prev) const {
ir::Node* found_node = nullptr;
NodeComparator functor;
auto it =
std::find_if(nodes_.begin(), nodes_.end(), [&](const NodeVector& v) {
if (v.front() == prev)
return true;
else
return false;
});
PADDLE_ENFORCE(it != nodes_.end(), "Not found previous in node list!");
for (it = std::next(it); it != nodes_.end(); ++it) {
auto& candidate = it->front();
if (functor(var, candidate)) {
found_node = candidate;
break;
}
}
return found_node;
}
bool OrderedSet::Has(ir::Node* var) const {
if (mark_table_.count(var->Name())) {
auto& node_in_samename = mark_table_.at(var->Name());
auto iter =
std::find_if(node_in_samename->begin(), node_in_samename->end(),
[&](ir::Node* n) { return n->Name() == var->Name(); });
return iter != node_in_samename->end();
}
return false;
}
void OrderedSet::Erase(const std::string& var) {
PADDLE_ENFORCE(mark_table_.count(var));
nodes_.erase(mark_table_[var]);
mark_table_.erase(var);
}
void OrderedSet::Erase(ir::Node* var) {
PADDLE_ENFORCE(var != nullptr);
Erase(var->Name());
}
std::string OrderedSet::ToString() const {
std::stringstream ss;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
for (auto& node : *it) {
ss << DebugString(node) << " ";
}
}
return ss.str();
}
bool NodeCanReused(ir::Node* node) {
// validate that the node is a var node
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
if (node == nullptr || !node->IsVar() || node->IsCtrlVar() ||
node->Name() == kEmptyVarName)
return false;
bool flag = true;
// An op output forced to be generated on the CPU cannot be reused.
for (auto* op : node->inputs) {
if (op->Op()->HasAttr("force_cpu")) {
flag &= framework::AttrReader(op->Op()->GetAttrMap())
.Get<bool>("force_cpu") == 0;
}
}
// var desc validation.
flag &= NodeCanReused(*node->Var());
return flag;
}
int MinChunkSize() {
int size{0};
#ifdef PADDLE_WITH_CUDA
size = platform::GpuMinChunkSize();
#else
size = platform::CpuMinChunkSize();
#endif // PADDLE_WITH_CUDA
return size;
}
bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType();
// only these types hold the bulk of GPU memory
// FIXME(liuwei1031): did not find a good way to test the SELECTED_ROWS and
// LOD_TENSOR_ARRAY reuse logic,
// so disable them in version 1.4
// if (!(type == proto::VarType::LOD_TENSOR ||
// type == proto::VarType::SELECTED_ROWS ||
// type == proto::VarType::LOD_TENSOR_ARRAY)) {
// return false;
// }
if (type != proto::VarType::LOD_TENSOR) return false;
// a persistable variable is a parameter
if (node.Persistable()) {
return false;
}
// a shape smaller than min_chunk_size is meaningless.
// furthermore, the fetched loss always has size = 1
// and should not be reused.
auto shape = node.GetShape();
int size = std::abs(
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()));
if (shape.empty() || size < MinChunkSize()) {
return false;
}
return true;
}
bool OpHasSubBlock(OpDesc* desc) {
const AttributeMap& attrs = desc->GetAttrMap();
for (auto& attr : attrs) {
if (attr.second.type() == typeid(BlockDesc*) || // NOLINT
attr.second.type() == typeid(std::vector<BlockDesc*>)) // NOLINT
return true;
}
return false;
}
ControlFlowGraph::ControlFlowGraph(const ir::Graph& graph) {
ops_ = SortOpLikeDescOrder(graph);
ConnectNodes();
}
void ControlFlowGraph::BuildCFGGraph() {
// FIXME(dzh): same effect as ConnectNodes, but uses the control
// links to build the dependency graph; it goes wrong in transformer.
for (ir::Node* op : ops_) {
for (auto& input_var : op->inputs) {
if (!input_var->inputs.empty()) {
PADDLE_ENFORCE(
input_var->inputs.size() == 1 && input_var->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
auto* pred_op = input_var->inputs[0];
if (pred_op->Op() != nullptr) {
predecessors_[op].insert(pred_op);
successors_[pred_op].insert(op);
}
}
if (input_var->IsVar() && !input_var->IsCtrlVar()) {
uses_[op].insert(input_var->Name());
}
}
for (auto& output_var : op->outputs) {
// an output var may be used by many ops
for (auto* succ_op : output_var->outputs) {
if (succ_op->Op() != nullptr) {
successors_[op].insert(succ_op);
predecessors_[succ_op].insert(op);
}
}
if (output_var->IsVar() && !output_var->IsCtrlVar()) {
defs_[op].insert(output_var->Name());
}
}
}
}
void ControlFlowGraph::ConnectNodes() {
for (size_t i = 0; i < ops_.size(); ++i) {
auto& op = ops_[i];
try {
auto& next_op = ops_.at(i + 1);
successors_[op].insert(next_op);
predecessors_[next_op].insert(op);
} catch (...) {
// do nothing
}
FilterVariables(op->inputs,
[&](ir::Node* var) { uses_[op].emplace(var->Name()); });
FilterVariables(op->outputs,
[&](ir::Node* var) { defs_[op].emplace(var->Name()); });
}
}
void ControlFlowGraph::LiveVariableAnalysis() {
// NOTE(dzh): variable liveness analysis (a.k.a. the reversed_ops algorithm)
// computes the liveness of each variable through the reversed_ops algorithm.
// It iterates over the operators from end to begin, computing the live-in/live-out
// variable set for each op; the diff between in/out is then used for
// variable reuse. For details refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std::list<ir::Node*> work_list(ops_.rbegin(), ops_.rend());
while (!work_list.empty()) {
ir::Node* op = work_list.front();
work_list.pop_front();
// get the live_in calculated before. Empty if first.
auto prev_live_in = std::move(live_in_[op]);
for (auto& s : successors_[op]) {
for (auto& var : live_in_[s]) {
live_out_[op].insert(var);
}
}
for (auto& var : uses_[op]) {
live_in_[op].insert(var);
}
for (auto& var : live_out_[op]) {
live_in_[op].insert(var);
}
for (auto& var : defs_[op]) {
if (uses_[op].count(var)) continue;
live_in_[op].erase(var);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if (live_in_[op] != prev_live_in) {
for (auto& pre : predecessors_[op]) {
work_list.push_back(pre);
}
}
}
for (auto* op : ops_) {
unlived_vars_[op] = std::set<std::string>();
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
}
void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node,
int begin_idx) {
std::vector<bool> need_update(ops_.size(), false);
// update graph from begin idx to the end
for (size_t i = begin_idx; i != ops_.size(); ++i) {
auto* op = ops_[i];
if (uses_[op].find(old_node) != uses_[op].end()) {
uses_[op].erase(old_node);
uses_[op].insert(new_node);
}
if (defs_[op].find(old_node) != defs_[op].end()) {
defs_[op].erase(old_node);
defs_[op].insert(new_node);
}
if (live_in_[op].find(old_node) != live_in_[op].end()) {
live_in_[op].erase(old_node);
live_in_[op].insert(new_node);
need_update[i] = true;
}
if (live_out_[op].find(old_node) != live_out_[op].end()) {
live_out_[op].erase(old_node);
live_out_[op].insert(new_node);
need_update[i] = true;
}
}
for (size_t i = begin_idx; i < ops_.size(); ++i) {
if (!need_update[i]) continue;
auto* op = ops_[i];
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
}
const std::set<std::string>& ControlFlowGraph::LiveIn(ir::Node* op) const {
auto it = live_in_.find(op);
PADDLE_ENFORCE(
it != live_in_.end(),
string::Sprintf("Expect %s in live_in, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string>& ControlFlowGraph::LiveOut(ir::Node* op) const {
auto it = live_out_.find(op);
PADDLE_ENFORCE(
it != live_out_.end(),
string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string>& ControlFlowGraph::Use(ir::Node* op) const {
auto it = uses_.find(op);
PADDLE_ENFORCE(
it != uses_.end(),
string::Sprintf("Expect %s in use, but Not Found.", op->Name()));
return it->second;
}
const std::set<std::string>& ControlFlowGraph::Unlived(ir::Node* op) const {
auto it = unlived_vars_.find(op);
PADDLE_ENFORCE(
it != unlived_vars_.end(),
string::Sprintf("Expect %s in unlived_set, but Not Found.", op->Name()));
return it->second;
}
const std::vector<ir::Node*>& ControlFlowGraph::Ops() const { return ops_; }
std::vector<ir::Node*>& ControlFlowGraph::Ops() { return ops_; }
ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
ir::Node* op) const {
// in an ssa-graph, different versions of a node have the same name;
// this function gets the latest version of the var before the target op.
// It may return nullptr, e.g. for a data node.
ir::Node* found_node = nullptr;
for (auto* node : ops_) {
if (node == op) break;
for (auto& output : node->outputs) {
PADDLE_ENFORCE((output != nullptr && output->IsVar()),
"Output is empty!");
if (output->Var() && output->Name() == name) {
found_node = output;
}
}
}
return found_node;
}
} // namespace ir
} // namespace framework
} // namespace paddle
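As a concrete illustration of the dataflow iteration in LiveVariableAnalysis above, the sketch below applies the same live-in/live-out equations (live_out(op) = union of live_in over successors; live_in(op) = use(op) ∪ (live_out(op) − def(op))) to the toy straight-line program used in the helper's test comment (a = 1; b = a; c = a; d = b + c). It is a standalone simplification with hypothetical types, not the ControlFlowGraph API.

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Toy straight-line program: 1. a = 1; 2. b = a; 3. c = a; 4. d = b + c
  struct Inst { std::set<std::string> use, def; };
  std::vector<Inst> prog = {
      {{}, {"a"}}, {{"a"}, {"b"}}, {{"a"}, {"c"}}, {{"b", "c"}, {"d"}}};
  std::vector<std::set<std::string>> live_in(prog.size()), live_out(prog.size());
  // Backward iteration; one pass suffices for straight-line code.
  for (int i = static_cast<int>(prog.size()) - 1; i >= 0; --i) {
    if (i + 1 < static_cast<int>(prog.size())) live_out[i] = live_in[i + 1];
    live_in[i] = prog[i].use;  // live_in = use ∪ (live_out − def)
    for (const auto& v : live_out[i]) {
      if (prog[i].def.count(v) == 0) live_in[i].insert(v);
    }
  }
  for (size_t i = 0; i < prog.size(); ++i) {
    std::cout << "inst " << i + 1 << " live_out:";
    for (const auto& v : live_out[i]) std::cout << " " << v;
    std::cout << "\n";  // prints: a | a b | b c | (empty)
  }
  return 0;
}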
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
/// this attribute is used to prevent some core variables from being
/// removed/reused in memory-optimize-related passes
constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
typedef std::unordered_set<std::string> MemOptSkipVars;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
// NOTE(dzh): An ordered set for node reuse in memory optimize.
// The ordered set sorts nodes in ascending order (by node size in bytes).
// In fluid, -1 means the batch_size, which is determined at runtime.
// So reuse only happens between nodes whose batch_sizes are both -1
// or both not -1.
//
// sort rule:
// rule 0 : smaller nodes rank in front.
// rule 1 : nodes with batch_size equal to -1 rank in front of those without.
//
// For example,
// node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], ..
class OrderedSet {
public:
// nodes with the same name may exist in the pool.
using NodeVector = std::vector<ir::Node*>;
using Iter = typename std::list<NodeVector>::iterator;
using ConstIter = typename std::list<NodeVector>::const_iterator;
void Insert(ir::Node* var);
void Erase(ir::Node* var);
void Erase(const std::string& var);
bool Has(ir::Node* var) const;
void Clear() {
mark_table_.clear();
nodes_.clear();
}
// find the best-fit node (by shape) for var.
ir::Node* FindBestFitNode(ir::Node* var) const;
ir::Node* FindNextBestFitNode(ir::Node* var, ir::Node* prev) const;
// the map stores non-const iterators, so constness cannot be promised
int GetNodeIndexInPool(ir::Node* var);
// dump all nodes in the pool to a string
std::string ToString() const;
Iter begin() { return nodes_.begin(); }
Iter end() { return nodes_.end(); }
ConstIter begin() const { return nodes_.begin(); }
ConstIter end() const { return nodes_.end(); }
size_t size() const { return nodes_.size(); }
private:
// for searching.
std::unordered_map<std::string, Iter> mark_table_;
// node pool
std::list<NodeVector> nodes_;
};
class ControlFlowGraph {
public:
ControlFlowGraph() = default;
// IR Graph
explicit ControlFlowGraph(const ir::Graph& graph);
void LiveVariableAnalysis();
void RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node, int begin_idx);
const std::set<std::string>& LiveIn(ir::Node* op) const;
const std::set<std::string>& LiveOut(ir::Node* op) const;
const std::set<std::string>& Use(ir::Node* op) const;
const std::set<std::string>& Unlived(ir::Node* op) const;
const std::vector<ir::Node*>& Ops() const;
std::vector<ir::Node*>& Ops();
// for ssa-graph nodes
ir::Node* GetNodeByName(const std::string& name, ir::Node* op) const;
private:
void BuildCFGGraph();
void ConnectNodes();
using NodeListMap = std::unordered_map<ir::Node*, std::set<ir::Node*>>;
using VarSetMap = std::map<ir::Node*, std::set<std::string>>;
// successor ops that use the output variables.
NodeListMap successors_;
// predecessor ops that generate the input variables.
NodeListMap predecessors_;
// variables live before the current op runs.
VarSetMap live_in_;
// variables live after the current op runs.
VarSetMap live_out_;
VarSetMap uses_; // op inputs
VarSetMap defs_; // op outputs
std::unordered_map<ir::Node*, std::set<std::string>> unlived_vars_;
std::vector<ir::Node*> ops_; // op sequence by topology sort
};
// validate whether a tensor can be reused or not
bool NodeCanReused(ir::Node* node);
// validate whether a tensor can be reused or not.
bool NodeCanReused(const VarDesc& node);
// check op has subblock or not
bool OpHasSubBlock(OpDesc* desc);
// node memory size in bytes
size_t NodeSize(ir::Node* n);
// node memory size in bytes
size_t NodeSize(const VarDesc&);
std::string DebugString(ir::Node* var);
VarDesc* GetVarDesc(ir::Node* n);
static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
template <typename Container, typename Callback>
class FilterVariableImpl {
public:
void operator()(const Container& nodes, Callback callback) {
for (auto* node : nodes) {
callback(node);
}
}
};
// filter var node for op->inputs/outputs
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
public:
void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
for (auto* var : nodes) {
if (var->IsVar() && !var->IsCtrlVar()) {
callback(var);
}
}
}
};
template <typename Container, typename Callback>
void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(OrderedSet, Normal) {
OrderedSet pool;
std::vector<std::unique_ptr<ir::Node>> nodes;
// clang-format off
std::vector<std::vector<int64_t>> shapes = {{-1, 10},
{-1, 20},
{1, 2},
{5, 2},
{10, 20},
{-1, 2, 5},
{-1, 1, 5},
{-1, 1}};
// clang-format on
const int COUNT = shapes.size();
ProgramDesc prog;
BlockDesc* block_desc = prog.MutableBlock(0);
auto* op_desc = block_desc->AppendOp();
op_desc->SetType("dummy");
std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);
for (int i = 0; i < COUNT; ++i) {
auto desc = block_desc->Var(std::to_string(i));
desc->SetShape(shapes[i]);
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
// Insert
for (auto& node : nodes) {
pool.Insert(node.get());
}
// Has/size
ASSERT_EQ(pool.size(), shapes.size());
for (auto& node : nodes) {
ASSERT_TRUE(pool.Has(node.get()));
}
// assert its order and interface.
std::cout << pool.ToString() << std::endl;
pool.Erase(nodes.front().get());
std::cout << pool.ToString() << std::endl;
ASSERT_EQ(pool.size(), static_cast<size_t>(COUNT - 1));
ASSERT_EQ(pool.GetNodeIndexInPool(nodes.back().get()), 0);
{
auto v1 = block_desc->Var("11");
v1->SetShape({-1, 256, 56, 56});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v1);
node1->inputs.emplace_back(op.get());
auto* cache = pool.FindBestFitNode(node1.get());
ASSERT_EQ(cache, nullptr);
}
{
auto v2 = block_desc->Var("12");
v2->SetShape({-1, 2, 5});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v2);
node1->inputs.emplace_back(op.get());
auto* cache = pool.FindBestFitNode(node1.get());
ASSERT_EQ(pool.GetNodeIndexInPool(cache), 2); // match 6:[-1,2,5]
}
{
auto v3 = block_desc->Var("13");
v3->SetShape({2, 5});
std::unique_ptr<ir::Node> node1 = ir::CreateNodeForTest(v3);
node1->inputs.emplace_back(op.get());
auto* cache = pool.FindBestFitNode(node1.get());
ASSERT_EQ(pool.GetNodeIndexInPool(cache), 5); // match 4:[5,2]
}
}
TEST(OrderedSet, FindBestFitNode) {
OrderedSet pool;
std::vector<std::unique_ptr<ir::Node>> nodes;
ProgramDesc prog;
BlockDesc* block_desc = prog.MutableBlock(0);
auto* op_desc = block_desc->AppendOp();
op_desc->SetType("dummy");
std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);
{
auto desc = block_desc->Var("a");
desc->SetShape({128, 128});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
{
auto desc = block_desc->Var("b");
desc->SetShape({128, 129});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
{
auto desc = block_desc->Var("c");
desc->SetShape({128, 128});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
for (auto& node : nodes) {
pool.Insert(node.get());
}
auto* n = nodes[0].get();
auto* cache = pool.FindBestFitNode(n);
ASSERT_TRUE(cache->Name() == "a" || cache->Name() == "c");
auto* cache_b = pool.FindNextBestFitNode(n, cache);
ASSERT_TRUE(cache_b->Name() != cache->Name());
ASSERT_TRUE(cache_b->Name() == "a" || cache_b->Name() == "c");
cache = pool.FindNextBestFitNode(n, cache_b);
ASSERT_TRUE(cache == nullptr);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
paddle::framework::SumOpMaker,
paddle::framework::DummyVarTypeInference);
/*
https://en.wikipedia.org/wiki/Live_variable_analysis
Create a customized classical dependency graph; the left column is the
instruction number.
1. a = 1
2. b = a
3. c = a
4. d = b + c
5. e = d
a--------+
| |
b c
| |
d--------+
|
e
Then analyze these variables' liveness ranges
*/
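/*
 For reference, a sketch of the classical backward dataflow equations that
 LiveVariableAnalysis solves (see the Wikipedia link above), written with the
 member names of ControlFlowGraph:
   live_out(op) = union of live_in(succ) over all successor ops succ
   live_in(op)  = uses(op) | (live_out(op) - defs(op))
 For the program above, the fixed point (also asserted by CFGGraph.IRGraph
 below) is:
   Ops()[0] (b = a)     : live_in = {a},    live_out = {a, b}
   Ops()[1] (c = a)     : live_in = {a, b}, live_out = {b, c}
   Ops()[2] (d = b + c) : live_in = {b, c}, live_out = {d}
   Ops()[3] (e = d)     : live_in = {d},    live_out = {}
*/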
namespace paddle {
namespace framework {
namespace ir {
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("e")->SetType(proto::VarType::LOD_TENSOR);
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"c"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"d"});
op->SetOutput("Out", {"e"});
}
return prog;
}
TEST(CFGGraph, IRGraph) {
// prepare ir graph
auto prog = FillProgramDesc();
ir::Graph graph(prog);
ControlFlowGraph cfg(graph);
cfg.LiveVariableAnalysis();
// test assign op
ASSERT_TRUE((std::set<std::string>{"a"} == cfg.LiveIn(cfg.Ops()[0])));
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveOut(cfg.Ops()[0])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"a", "b"} == cfg.LiveIn(cfg.Ops()[1])));
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveOut(cfg.Ops()[1])));
// test sum op
ASSERT_TRUE((std::set<std::string>{"b", "c"} == cfg.LiveIn(cfg.Ops()[2])));
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveOut(cfg.Ops()[2])));
// test assign op
ASSERT_TRUE((std::set<std::string>{"d"} == cfg.LiveIn(cfg.Ops()[3])));
ASSERT_TRUE((std::set<std::string>{} == cfg.LiveOut(cfg.Ops()[3])));
}
// 1. normal test
TEST(SortOpLikeDescOrder, NormalTest) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
auto nodes = SortOpLikeDescOrder(graph);
auto op_descs = prog.Block(0).AllOps();
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 2. remove some op_desc
TEST(SortOpLikeDescOrder, RemoveOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
auto nodes = graph.Nodes();
auto op_descs = prog.Block(0).AllOps();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->outputs.back()->Name() == "e") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
d->outputs.erase(std::remove(d->outputs.begin(), d->outputs.end(), found_node),
d->outputs.end());
graph.RemoveNode(found_node);
graph.RemoveNode(e);
// other nodes keep the same order
auto remain_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < remain_nodes.size(); ++i) {
auto node = remain_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 3. add some op_desc
TEST(SortOpLikeDescOrder, AddOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// cached descs differ from the real ones;
// mimic an intermediate pass modifying the ProgramDesc.
std::vector<OpDesc*> op_descs = graph.OriginProgram().Block(0).AllOps();
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
op_descs.insert(op_descs.begin() + 4, op);
auto nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 4. add and delete some op_desc
TEST(SortOpLikeDescOrder, AddAndDeleteOpDesc) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
std::vector<OpDesc*> op_descs = graph.OriginProgram().Block(0).AllOps();
// remove sum node
ir::Node* found_node = nullptr;
auto nodes = graph.Nodes();
for (auto node : nodes) {
if (node->Name() == "sum") {
found_node = node;
break;
}
}
PADDLE_ENFORCE(found_node != nullptr);
for (auto it = op_descs.begin(); it != op_descs.end();) {
if (IsSameDesc(*it, found_node->Op())) {
it = op_descs.erase(it);
} else {
++it;
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* c = find_node_in_graph("c");
ir::Node* e = find_node_in_graph("e");
d->outputs.erase(std::remove(d->outputs.begin(), d->outputs.end(), found_node),
d->outputs.end());
c->outputs.erase(std::remove(c->outputs.begin(), c->outputs.end(), found_node),
c->outputs.end());
ir::Node* pending_op = found_node->outputs[0]->outputs[0];
graph.RemoveNode(e);
graph.RemoveNode(pending_op);
graph.RemoveNode(found_node);
}
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.insert(op_descs.begin() + 2, op);
// check the order
auto mynodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < mynodes.size(); ++i) {
auto node = mynodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
// 5. add and replace some op_desc inplace.
TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
auto prog = FillProgramDesc();
ir::Graph graph(prog);
std::vector<OpDesc*> op_descs = graph.OriginProgram().Block(0).AllOps();
auto find_node_in_graph = [&](std::string s) {
ir::Node* ret = nullptr;
for (auto n : graph.Nodes()) {
if (n->Name() == s) {
ret = n;
break;
}
}
PADDLE_ENFORCE(ret != nullptr);
return ret;
};
// add node
auto op = prog.MutableBlock(0)->AppendOp();
prog.MutableBlock(0)->Var("d1")->SetType(proto::VarType::LOD_TENSOR);
op->SetType("sum");
op->SetInput("X", {"b", "c"});
op->SetOutput("Out", {"d1"});
{
ir::Node* node = graph.CreateOpNode(op);
ir::Node* d1 = graph.CreateVarNode(prog.MutableBlock(0)->Var("d1"));
ir::Node* b = find_node_in_graph("b");
ir::Node* c = find_node_in_graph("c");
node->outputs.emplace_back(d1);
node->inputs.emplace_back(b);
node->inputs.emplace_back(c);
d1->inputs.emplace_back(node);
b->outputs.emplace_back(node);
c->outputs.emplace_back(node);
}
op_descs.emplace_back(op);
// replace op_desc inplace
auto nodes = graph.Nodes();
ir::Node* found_node = nullptr;
for (auto node : nodes) {
if (node->IsOp() && node->Op() && node->Name() == "assign") {
if (node->outputs.size() == 1 && node->outputs[0]->Name() == "e") {
found_node = node;
break;
}
}
}
{
ir::Node* d = find_node_in_graph("d");
ir::Node* e = find_node_in_graph("e");
d->outputs.erase(std::remove(d->outputs.begin(), d->outputs.end(), found_node),
d->outputs.end());
e->inputs.erase(std::remove(e->inputs.begin(), e->inputs.end(), found_node),
e->inputs.end());
graph.RemoveNode(found_node);
}
op_descs.erase(op_descs.begin() + 3);
auto replace_op = prog.MutableBlock(0)->AppendOp();
replace_op->SetType("sum");
replace_op->SetInput("X", {"d", "d1"});
replace_op->SetOutput("Out", {"e"});
{
ir::Node* sum2 = graph.CreateOpNode(replace_op);
ir::Node* e = find_node_in_graph("e");
ir::Node* d = find_node_in_graph("d");
ir::Node* d1 = find_node_in_graph("d1");
sum2->inputs.emplace_back(d);
sum2->inputs.emplace_back(d1);
sum2->outputs.emplace_back(e);
e->inputs.emplace_back(sum2);
d->outputs.emplace_back(sum2);
d1->outputs.emplace_back(sum2);
}
op_descs.emplace_back(replace_op);
// compare op order
auto graph_nodes = SortOpLikeDescOrder(graph);
for (size_t i = 0; i < graph_nodes.size(); ++i) {
auto node = graph_nodes[i];
auto op_desc = op_descs[i];
ASSERT_TRUE(IsSameDesc(node->Op(), op_desc));
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include <algorithm>
#include <atomic>
#include <deque>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
CollectSkipVarsSet(graph);
cfg_.reset(new ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis();
InitSSAGraphNodes();
int reuse_id = 0;
for (size_t idx = 0; idx < cfg_->Ops().size(); ++idx) {
auto& op = cfg_->Ops()[idx];
auto* op_desc = op->Op();
// some ops in the graph have no op desc
if (op_desc == nullptr) continue;
for (auto& var : op->outputs) {
if (var->IsVar() && !var->IsCtrlVar() && skip_set_.count(var->Name())) {
VLOG(3) << "Skip set contains variable of " << var->Name()
<< "disable reuse on it. skipped";
continue;
}
if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) {
ir::Node* cache = pool_.FindBestFitNode(var);
while (cache != nullptr && var->Name() == cache->Name()) {
VLOG(3) << "The same cache variable is cascade reused. "
<< cache->Name() << " is re-filled to the pool after "
<< "the reused op is finished. Current op can not "
<< "replace it again. Skip this candidate.";
cache = pool_.FindNextBestFitNode(var, cache);
}
if (cache != nullptr) {
int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
VLOG(3) << string::Sprintf(
"!!! %s, %s => %s, cache idx %d, pool size %d",
std::to_string(reuse_id++), DebugString(var), DebugString(cache),
node_idx_in_pool, static_cast<int>(pool_.size()));
// NOTE(dzhwinter): update the ProgramDesc/IR Graph
// and the CFG Graph on the fly.
//
// The IR Graph defines the dependence relationships between nodes.
//
// The ProgramDesc defines the input/output vars. It is used by
// CreateOp and CreateVar when execution happens.
//
// The CFG Graph stores the liveness information; when reuse happens
// we also need to update the variable liveness.
const std::string var_name = var->Name();
const std::string cache_name = cache->Name();
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
RenameVarInGraphDesc(var_name, cache_name, idx);
RenameVarInGraphNode(var_name, cache_name, idx, graph);
pool_.Erase(cache_name);
}
}
}
// fill the pool
for (auto& var : cfg_->Unlived(op)) {
ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (var_node == nullptr || var_node->IsCtrlVar()) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node);
}
}
}
graph->ResolveHazard(var_nodes_);
}
void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
// fill skip_set_
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) {
skip_set_.emplace(var);
}
}
void MemoryOptimizePass::RenameVarInGraphDesc(const std::string& var,
const std::string& cache_var,
size_t idx) const {
for (size_t i = idx; i < cfg_->Ops().size(); ++i) {
auto* op = cfg_->Ops()[i];
PADDLE_ENFORCE(op->IsOp() && op->Op());
auto* op_desc = op->Op();
op_desc->RenameInput(var, cache_var);
op_desc->RenameOutput(var, cache_var);
if (op_desc->Block() != nullptr) {
op_desc->Block()->RemoveVar(var);
} else {
LOG(WARNING) << "op " << op->Name() << " not know its block."
<< "Is the op_desc created without block pointer? "
<< "Can not find " << var << " in Block(0)";
}
op_desc->Flush();
}
}
void MemoryOptimizePass::InitSSAGraphNodes() const {
std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars;
if (var_nodes_.empty()) {
for (auto* op : cfg_->Ops()) {
for (auto* node : op->inputs) {
if (all_vars[node->Name()].count(node) == 0) {
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
}
}
for (auto* node : op->outputs) {
if (all_vars[node->Name()].count(node) == 0) {
all_vars[node->Name()].emplace(node);
var_nodes_[node->Name()].emplace_back(node);
}
}
}
}
}
void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
const std::string& cache_var,
size_t idx,
ir::Graph* graph) const {
// if a replacement happens, we need to create a newer version of cache_var
// but with the same dims/data_type as var.
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
var_nodes_[var].at(0)->Var() != nullptr);
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
var_desc->SetName(cache_var);
for (size_t i = idx; i < cfg_->Ops().size(); ++i) {
auto* op = cfg_->Ops()[i];
// redirect the input to the latest version of cache_var
for (auto* node : op->inputs) {
if (node->Name() == var) {
ir::Node* cache_node = var_nodes_[cache_var].back();
// swap node to cache_node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
cache_node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
// erase unused node
auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
graph->RemoveNode(node);
}
}
// if we need to rename the output,
// always create a newer version of cache_var
for (auto* node : op->outputs) {
if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
var_nodes_[cache_var].emplace_back(cache_node);
// swap node to cache node
cache_node->outputs.insert(cache_node->outputs.end(),
node->outputs.begin(), node->outputs.end());
cache_node->inputs.emplace_back(op);
std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node);
}
// erase unused node
auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
graph->RemoveNode(node);
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(memory_optimize_pass, paddle::framework::ir::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class MemoryOptimizePass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;
private:
// update program descs
void RenameVarInGraphDesc(const std::string& var,
const std::string& cache_var, size_t idx) const;
// update ir nodes
void RenameVarInGraphNode(const std::string& var,
const std::string& cache_var, size_t idx,
ir::Graph* graph) const;
void SubGraphOptimize(OpDesc* op_desc) const;
// 1. scan ops with sub-blocks (while, while_grad, conditional_block)
//    and collect their input/output vars.
// 2. scan distributed ops and collect their input/output vars.
// 3. collect op_role_vars.
void CollectSkipVarsSet(ir::Graph* graph) const;
private:
// Reuse Node Pool, Owned.
mutable OrderedSet pool_;
// control flow graph
mutable std::unique_ptr<ControlFlowGraph> cfg_;
// skip set
mutable std::unordered_set<std::string> skip_set_;
// var nodes
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
graph->Set(kMemOptSkipVars, new MemOptSkipVars);
auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
std::vector<ir::Node*> op_nodes;
for (auto& node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
if (node->IsOp() && node->Op()) {
op_nodes.emplace_back(node);
}
}
// Insert kEmptyVarName to avoid optimizing the empty variable
skip_vars.insert(framework::kEmptyVarName);
// NOTE(zcd): Insert OpRoleVars into SkipVarSet to prevent the vars from being
// renamed in the memory optimize pass.
InsertOpRoleVarsToSkipVarSet(op_nodes, &skip_vars);
InsertSkipMemOptOpInOutToSkipVarSet(op_nodes, &skip_vars);
}
private:
static void InsertOpRoleVarsToSkipVarSet(const std::vector<ir::Node*>& ops,
MemOptSkipVars* skip_vars) {
for (auto& node : ops) {
try {
auto op_role_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
auto& g_name = op_role_vars[i + 1];
skip_vars->insert(g_name);
}
} catch (boost::bad_get& e) {
}
}
}
static void UpdateSkipVarSet(
MemOptSkipVars* skip_vars,
const std::vector<std::vector<std::string>>& var_names) {
for (auto& var_name : var_names) {
skip_vars->insert(var_name.begin(), var_name.end());
}
}
static std::vector<std::string> ToGradVarName(
const std::vector<std::string>& names) {
std::vector<std::string> ret;
ret.reserve(names.size());
for (auto& name : names) {
if (name != framework::kEmptyVarName) {
ret.emplace_back(framework::GradVarName(name));
}
}
return ret;
}
static void InsertSkipMemOptOpInOutToSkipVarSet(
const std::vector<ir::Node*>& ops, MemOptSkipVars* skip_vars) {
static std::unordered_set<std::string> kSkipMemOptOps{
"send", "recv", "prefetch", "send_barrier", "fetch_barrier"};
for (auto& node : ops) {
auto* op_desc = node->Op();
// Some ops (while, conditional_block, recurrent, etc.) have sub-blocks.
// These ops often use variables from their parent or forward blocks.
// Optimizing the inputs/outputs of such ops would make these variables
// impossible to find when running the sub-block ops.
if (OpHasSubBlock(op_desc)) {
UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
op_desc->OutputArgumentNames()});
}
// Skip ops that are related to the parameter server.
// In distributed mode, trainers and the parameter server use the same
// variable names to track the same variables. We cannot change the
// names of these variables, otherwise trainers or the parameter
// server would not find them.
if (kSkipMemOptOps.count(op_desc->Type()) > 0) {
UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
op_desc->OutputArgumentNames()});
}
// FIXME(zjl): some ops use variables that are not from their
// inputs or outputs. We do not have a nice method to solve this
// issue yet. Currently, we should skip these variables when
// memory optimization is enabled.
auto op_type = op_desc->Type();
if (op_type == "while_grad") {
// In while_grad, framework::GradVarName(Input("X")) is visited
// without being any input/output of while_grad. while_grad uses
// these variables to accumulate the gradient of X across time steps.
UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("X"))});
} else if (op_type == "conditional_block_grad") {
// In conditional_block_grad, framework::GradVarName(Input("Input",
// "Cond")) is visited without being any input/output of
// conditional_block_grad. conditional_block_grad uses these
// variables to accumulate the gradients of Input/Cond across time steps.
UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("Input")),
ToGradVarName(op_desc->Input("Cond"))});
} else if (op_type == "recurrent" || op_type == "recurrent_grad") {
// Recurrent and recurrent_grad ops are implemented in a very tricky
// way. Attr("states", "ex_states") is visited without being any
// input/output of the op. This is because these variables come from
// sub-blocks, not the main block. Adding these variables to the inputs
// would make recurrent fail since "states" and "ex_states" cannot be
// found in the main block. When memory optimization is enabled,
// "states", "ex_states" and their gradients should be skipped.
auto ex_states =
boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
auto states =
boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
if (op_type == "recurrent") {
UpdateSkipVarSet(skip_vars, {ex_states, states});
} else {
// In recurrent_grad, framework::GradVarName(Input("parameters",
// "input")) is visited without being any input/output of recurrent_grad.
// recurrent_grad uses these variables to accumulate the gradients of
// parameters/input across time steps.
UpdateSkipVarSet(
skip_vars,
{ToGradVarName(op_desc->Input("parameters")),
ToGradVarName(op_desc->Input("inputs")), ex_states, states,
ToGradVarName(ex_states), ToGradVarName(states)});
}
}
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::ir::RecordSkipMemoryOptVarsPass);
......@@ -17,7 +17,6 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......
......@@ -252,7 +252,22 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
VLOG(10) << "buffer_shared_inplace_pass Applied";
}
if (build_strategy_.memory_optimize_) {
/**
* NOTE(zengjinle): If BuildStrategy.memory_optimize = None in Python,
* set BuildStrategy.memory_optimize according to whether gc is enabled.
* If gc is enabled, BuildStrategy.memory_optimize = False.
* If gc is disabled, BuildStrategy.memory_optimize = True.
* This is because gc+memory_optimize is worse than gc only.
*
* As an option, users can forcibly enable BuildStrategy.memory_optimize
* by setting it to True, and forcibly disable it by setting it to False.
*/
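// A short summary (restating the note above, no new logic) of the effective
// value computed below:
//   memory_optimize_ == None  && GC enabled  -> false (cross-op reuse off)
//   memory_optimize_ == None  && GC disabled -> true  (cross-op reuse on)
//   memory_optimize_ == True / False         -> kept as the user set it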
bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0);
if (!build_strategy_.memory_optimize_) {
build_strategy_.memory_optimize_ = !is_gc_enabled;
}
if (build_strategy_.memory_optimize_.get()) {
auto cross_op_memory_reuse_pass = ir::PassRegistry::Instance().Get(
"buffer_shared_cross_op_memory_reuse_pass");
cross_op_memory_reuse_pass->SetNotOwned(ir::kMemOptVarInfoMapList,
......@@ -265,7 +280,7 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
VLOG(10) << "buffer_shared_cross_op_memory_reuse_pass Applied";
}
if (GetEagerDeletionThreshold() < 0) {
if (!is_gc_enabled) {
return graph;
}
size_t max_memory_size = static_cast<size_t>(GetEagerDeletionThreshold());
......@@ -313,6 +328,9 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(graph);
VLOG(10) << "EagerDeletionPass Applied";
LOG(INFO) << "Garbage collection strategy is enabled, when "
<< "FLAGS_eager_delete_tensor_gb = "
<< (static_cast<double>(GetEagerDeletionThreshold()) / (1 << 30));
}
return graph;
}
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
......@@ -34,7 +33,6 @@ void BindConstValue(pybind11::module* m) {
m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
m->def("kMemOptSkipVars", [] { return framework::ir::kMemOptSkipVars; });
auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
......
......@@ -1548,17 +1548,31 @@ All parameter, weight, gradient are variables in Paddle.
)DOC")
.def_property(
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
[](BuildStrategy &self, bool b) { self.memory_optimize_ = b; },
R"DOC(The type is BOOL, memory opitimize aims to save total memory
[](const BuildStrategy &self) -> py::object {
if (self.memory_optimize_) {
return py::cast(self.memory_optimize_.get());
} else {
return py::cast(nullptr);
}
},
[](BuildStrategy &self, const py::handle &value) {
auto *py_obj = value.ptr();
if (py_obj == nullptr || py_obj == Py_None) {
self.memory_optimize_ = boost::none;
} else if (PyBool_Check(py_obj)) {
self.memory_optimize_ = (py_obj == Py_True);
} else {
PADDLE_THROW(
"BuildStrategy.memory_optimize must be None, False or True");
}
},
R"DOC(The type is BOOL or None, memory opitimize aims to save total memory
consumption, set to True to enable it.
Memory optimize is an experimental feature; some variables
may be reused or removed by the optimization strategy. If you need to
fetch some variable values when using this feature, please
set the persistable property of those variables to True.
Default False)DOC")
Default None. None means the framework chooses whether to use
this strategy automatically. Currently, None means that it is
enabled when GC is disabled, and disabled when GC is enabled.
True means enabling and False means disabling.)DOC")
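// A minimal usage sketch of the tri-state semantics documented above
// (Python side, using the fluid.BuildStrategy API shown elsewhere in this
// change; illustrative only):
//   build_strategy = fluid.BuildStrategy()
//   build_strategy.memory_optimize = None   # framework decides based on GC
//   build_strategy.memory_optimize = True   # force cross-op memory reuse on
//   build_strategy.memory_optimize = False  # force it off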
.def_property(
"is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; },
......@@ -1578,13 +1592,6 @@ All parameter, weight, gradient are variables in Paddle.
"enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; },
[](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
.def_property("_use_legacy_memory_optimize_strategy",
[](const BuildStrategy &self) {
return self.use_legacy_memory_optimize_strategy_;
},
[](BuildStrategy &self, bool b) {
self.use_legacy_memory_optimize_strategy_ = b;
})
.def_property(
"fuse_all_reduce_ops",
[](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
......
......@@ -206,7 +206,7 @@ def __bootstrap__():
'cudnn_exhaustive_search', 'selected_gpus', 'sync_nccl_allreduce',
'limit_of_tmp_allocation',
'times_excess_than_required_tmp_allocation',
'enable_inplace_whitelist', 'cudnn_batchnorm_spatial_persistent'
'cudnn_batchnorm_spatial_persistent'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
......
......@@ -533,36 +533,6 @@ class Executor(object):
return as_numpy(arr)
return [arr[i] for i in range(len(arr))]
def _check_fetch_vars_persistable(self, program, fetch_list):
for var in fetch_list:
if isinstance(var, Variable):
persistable = var.persistable
else:
block_num = program.desc.num_blocks()
persistable = None
var_name = cpt.to_bytes(var)
for i in six.moves.range(block_num):
var_desc = program.desc.block(i).find_var(var_name)
if var_desc:
persistable = var_desc.persistable()
break
assert persistable is not None, "Variable {} is not found".format(
var)
if not persistable:
logging.warn("""
Detected that build_strategy.memory_optimize = True, but some variables in the fetch
list are not persistable. You may get wrong fetched values, or an exception may be
thrown because a variable in the fetch list cannot be found.
TO FIX this:
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
""")
def run(self,
program=None,
feed=None,
......@@ -667,10 +637,6 @@ class Executor(object):
scope=scope,
return_numpy=return_numpy,
use_program_cache=use_program_cache)
else:
if fetch_list and program._is_data_parallel and program._program and \
program._build_strategy._use_legacy_memory_optimize_strategy:
self._check_fetch_vars_persistable(program._program, fetch_list)
program._compile(scope, self.place)
if program._is_data_parallel:
......
......@@ -61,16 +61,13 @@ class TestSoftmaxWithXe(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = inplace
if inplace:
build_strategy._use_legacy_memory_optimize_strategy = True
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
build_strategy=build_strategy, places=place)
if inplace:
fetch_list = [z_d.name, x_d.name]
else:
fetch_list = [z_d.name, s_d.name]
fetch_list = [z_d.name, s_d.name]
print('Inplace is {}'.format("ON" if inplace else "OFF"))
z, s = exe.run(prog,
feed={x_d.name: x,
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import logging
import six
import sys
from collections import defaultdict, MutableSet
......@@ -550,8 +551,14 @@ def memory_optimize(input_program,
fluid.memory_optimize(main_prog)
"""
sys.stderr.write('memory_optimize is deprecated. '
'Use CompiledProgram and Executor\n')
logging.warn(
'Caution! paddle.fluid.memory_optimize() is deprecated '
'and not maintained any more, since it is not stable!\n'
'Please use the newest and stable memory optimization strategies!\n'
' 1. Enable garbage collection strategy by exporting environment '
'variable FLAGS_eager_delete_tensor_gb=0\n'
' 2. Set build_strategy.enable_inplace=True (True is the default '
'value) when using CompiledProgram or ParallelExecutor.\n')
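# A minimal sketch of the recommended replacement mentioned in the warning
# above (illustrative only; variable names such as exe/feed_dict/fetch_list are
# placeholders, and the CompiledProgram usage mirrors what is shown elsewhere
# in this change):
#   # 1. enable garbage collection: export FLAGS_eager_delete_tensor_gb=0
#   #    in the environment before launching the program.
#   build_strategy = fluid.BuildStrategy()
#   build_strategy.enable_inplace = True  # 2. keep inplace reuse on (default)
#   prog = fluid.CompiledProgram(fluid.default_main_program(
#   )).with_data_parallel(build_strategy=build_strategy)
#   exe.run(prog, feed=feed_dict, fetch_list=fetch_list)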
def to_name_str(var):
if isinstance(var, Variable):
......