Unverified commit b6e84806, authored by Ruibiao Chen, committed by GitHub

Merge kProgramDescs in GraphToProgram (#44526)

Parent 9841b308
@@ -517,13 +517,13 @@ cc_test(
   DEPS op_call_stack)
 cc_library(
-  program_processing
-  SRCS program_processing.cc
+  program_utils
+  SRCS program_utils.cc
   DEPS proto_desc)
 cc_test(
-  program_processing_test
-  SRCS program_processing_test.cc
-  DEPS proto_desc program_processing)
+  program_utils_test
+  SRCS program_utils_test.cc
+  DEPS proto_desc program_utils)
 if(WITH_GPU)
   nv_test(
...
@@ -67,7 +67,7 @@ cc_library(
 cc_library(
   graph_helper
   SRCS graph_helper.cc
-  DEPS graph scale_loss_grad_op_handle)
+  DEPS graph program_utils scale_loss_grad_op_handle)
 cc_library(
   pass
   SRCS pass.cc
...
@@ -19,7 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/program_utils.h"

 DECLARE_bool(convert_all_blocks);

 PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
@@ -559,20 +561,27 @@ static void GraphToBlock(const Graph &graph,
             << vars2remove.size() << " nodes";
   }

+  std::vector<proto::VarDesc> vars_in_graph;
+  for (Node *node : graph.Nodes()) {
+    if (node->IsVar() && node->Var() &&
+        node->GetVarNodeBlockId() == graph.GetBlockId()) {
+      vars_in_graph.emplace_back(*node->Var()->Proto());
+    }
+  }
+
+  // add vars_in_graph to block
   block->clear_vars();
   std::unordered_set<std::string> visited_vars;
-  for (Node *n : graph.Nodes()) {
-    if (n->IsVar()) {
-      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
-          !vars2remove.count(n->Var()->Name()) &&
-          n->GetVarNodeBlockId() == graph.GetBlockId()) {
-        visited_vars.insert(n->Var()->Name());
-        block->add_vars()->MergeFrom(*n->Var()->Proto());
-      }
-    }
-  }
+  for (proto::VarDesc &var : vars_in_graph) {
+    const std::string &var_name = var.name();
+    if (visited_vars.find(var_name) == visited_vars.end() &&
+        vars2remove.find(var_name) == vars2remove.end()) {
+      block->add_vars()->MergeFrom(var);
+      visited_vars.insert(var_name);
+    }
+  }

   block->clear_ops();

   std::vector<Node *> nodes;
   if (sort_kind != nullptr) {
     // Inference Memory Optimize relays on this branch.
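The rewritten GraphToBlock above splits variable export into two passes: first it collects every proto::VarDesc belonging to the current block, then it copies the collected descs into the block while skipping duplicates and the names in vars2remove. A minimal standalone sketch of that dedup pass, using illustrative stand-in types rather than the Paddle API:

    #include <string>
    #include <unordered_set>
    #include <vector>

    // Stand-in for proto::VarDesc; only the name matters for dedup.
    struct VarDescStub { std::string name; };

    // Keep the first occurrence of each name; drop names marked for removal.
    std::vector<VarDescStub> FilterVars(
        const std::vector<VarDescStub> &vars_in_graph,
        const std::unordered_set<std::string> &vars2remove) {
      std::unordered_set<std::string> visited;
      std::vector<VarDescStub> kept;
      for (const auto &var : vars_in_graph) {
        if (visited.count(var.name) == 0 && vars2remove.count(var.name) == 0) {
          kept.push_back(var);
          visited.insert(var.name);
        }
      }
      return kept;
    }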
@@ -630,6 +639,13 @@ void GraphToProgram(const Graph &graph,
   }

   program->CopyFrom(program_pb);
+
+  if (graph.Has(details::kProgramDescs)) {
+    details::ProgramDescs program_descs =
+        graph.Get<details::ProgramDescs>(details::kProgramDescs);
+    VLOG(8) << "Merge main programs";
+    MergePrograms(program, program_descs, /*append=*/false);
+  }
 }

 static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
...
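This hunk is the core of the change: after the graph is serialized into program_pb and copied into the output program, any ProgramDescs that passes attached to the graph under details::kProgramDescs are folded into the result instead of being dropped. A hedged sketch of the intended call flow; ApplyPasses is a hypothetical placeholder for whatever pass pipeline sets kProgramDescs:

    // Sketch, assuming the Paddle types shown in this diff.
    ir::Graph graph(main_program);
    ApplyPasses(&graph);                // hypothetical: may set kProgramDescs
    ProgramDesc converted;
    ir::GraphToProgram(graph, &converted);
    // Ops from any attached programs are merged into block 0 of `converted`
    // with append=false, i.e. prepended ahead of the graph's own ops.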
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/program_utils.h"

 namespace paddle {
 namespace framework {
@@ -78,62 +79,6 @@ Graph *Pass::Apply(Graph *graph) const {
   return graph;
 }

-template <typename Container, typename Visitor>
-static void VisitAllElements(Container &&container,
-                             Visitor &&visitor,
-                             bool reverse) {
-  if (reverse) {
-    std::for_each(container.rbegin(), container.rend(), visitor);
-  } else {
-    std::for_each(container.begin(), container.end(), visitor);
-  }
-}
-
-static void MergePrograms(ProgramDesc *dst,
-                          const details::ProgramDescs &srcs,
-                          bool append) {
-  PADDLE_ENFORCE_NOT_NULL(
-      dst, platform::errors::InvalidArgument("Dst program must be provided."));
-  bool reverse = !append;
-
-  auto create_var_visitor = [dst](const ProgramDesc &src) {
-    PADDLE_ENFORCE_EQ(
-        src.Size(),
-        1,
-        platform::errors::Unimplemented("MergePrograms can only support to "
-                                        "merge program with only one block."));
-    const auto &src_block = src.Block(0);
-    auto *dst_block = dst->MutableBlock(0);
-    for (const auto *src_new_var : src_block.AllVars()) {
-      if (dst_block->FindVar(src_new_var->Name())) continue;
-      auto *dst_new_var = dst_block->Var(src_new_var->Name());
-      *dst_new_var = *src_new_var;
-      VLOG(10) << "Create new variable " << dst_new_var->Name();
-    }
-  };
-  VisitAllElements(srcs, create_var_visitor, reverse);
-
-  auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
-    auto ops = src.Block(0).AllOps();
-    auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) {
-      auto *dst_block = dst->MutableBlock(0);
-      auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp();
-      op->CopyFrom(*src_op);
-      VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type();
-      // FIXME(zjl): some passes does not add VarDesc to program,
-      // we should fix this bug later...
-      for (const auto &in_var_name : op->InputArgumentNames()) {
-        dst_block->Var(in_var_name);
-      }
-      for (const auto &out_var_name : op->OutputArgumentNames()) {
-        dst_block->Var(out_var_name);
-      }
-    };
-    VisitAllElements(ops, copy_op_visitor, reverse);
-  };
-  VisitAllElements(srcs, create_op_visitor, reverse);
-}
-
 static void FillNotSpecifiedOpRole(const ProgramDesc &main_program) {
   for (size_t block_idx = 0; block_idx < main_program.Size(); ++block_idx) {
     auto ops = main_program.Block(block_idx).AllOps();
...
@@ -12,13 +12,69 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/framework/program_processing.h"
+#include "paddle/fluid/framework/program_utils.h"

 #include "paddle/fluid/framework/block_desc.h"

 namespace paddle {
 namespace framework {

+template <typename Container, typename Visitor>
+inline void VisitAllElements(Container &&container,
+                             Visitor &&visitor,
+                             bool reverse) {
+  if (reverse) {
+    std::for_each(container.rbegin(), container.rend(), visitor);
+  } else {
+    std::for_each(container.begin(), container.end(), visitor);
+  }
+}
+
+void MergePrograms(ProgramDesc *dst,
+                   const std::vector<ProgramDesc> &srcs,
+                   bool append) {
+  PADDLE_ENFORCE_NOT_NULL(
+      dst, platform::errors::InvalidArgument("Dst program must be provided."));
+  bool reverse = !append;
+
+  auto create_var_visitor = [dst](const ProgramDesc &src) {
+    PADDLE_ENFORCE_EQ(
+        src.Size(),
+        1,
+        platform::errors::Unimplemented("MergePrograms can only support to "
+                                        "merge program with only one block."));
+    const auto &src_block = src.Block(0);
+    auto *dst_block = dst->MutableBlock(0);
+    for (const auto *src_new_var : src_block.AllVars()) {
+      if (dst_block->FindVar(src_new_var->Name())) continue;
+      auto *dst_new_var = dst_block->Var(src_new_var->Name());
+      *dst_new_var = *src_new_var;
+      VLOG(10) << "Create new variable " << dst_new_var->Name();
+    }
+  };
+  VisitAllElements(srcs, create_var_visitor, reverse);
+
+  auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
+    auto ops = src.Block(0).AllOps();
+    auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) {
+      auto *dst_block = dst->MutableBlock(0);
+      auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp();
+      op->CopyFrom(*src_op);
+      VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type();
+      // FIXME(zjl): some passes does not add VarDesc to program,
+      // we should fix this bug later...
+      for (const auto &in_var_name : op->InputArgumentNames()) {
+        dst_block->Var(in_var_name);
+      }
+      for (const auto &out_var_name : op->OutputArgumentNames()) {
+        dst_block->Var(out_var_name);
+      }
+    };
+    VisitAllElements(ops, copy_op_visitor, reverse);
+  };
+  VisitAllElements(srcs, create_op_visitor, reverse);
+}
+
 void ProgramProcessor::GetInputsOutputsInBlock(
     const BlockDesc &current_block,
     std::set<std::string> *inner_inputs,
...
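With MergePrograms promoted to a public utility, its append/prepend semantics can be shown in isolation. A hedged example; the two source programs are built ad hoc for illustration and carry a single op each:

    // Sketch, assuming paddle::framework::ProgramDesc from this diff.
    ProgramDesc dst;
    std::vector<ProgramDesc> srcs(2);
    srcs[0].MutableBlock(0)->AppendOp()->SetType("fill_constant");
    srcs[1].MutableBlock(0)->AppendOp()->SetType("scale");

    // append=false visits srcs in reverse and prepends each op, so block 0
    // of dst ends up as: fill_constant, scale, then dst's original ops.
    MergePrograms(&dst, srcs, /*append=*/false);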
@@ -18,7 +18,9 @@ limitations under the License. */

 namespace paddle {
 namespace framework {

-class ProgramDesc;
+void MergePrograms(ProgramDesc *dst,
+                   const std::vector<ProgramDesc> &srcs,
+                   bool append);

 class ProgramProcessor {
  public:
...
@@ -30,5 +32,6 @@ class ProgramProcessor {
   void AddDepToBlockOp(const BlockDesc &block);
 };

 }  // namespace framework
 }  // namespace paddle
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/framework/program_processing.h"
+#include "paddle/fluid/framework/program_utils.h"

 #include "gtest/gtest-message.h"
 #include "gtest/gtest-test-part.h"
...
@@ -1445,9 +1445,11 @@ class Executor(object):
         if key not in self._executor_cache._cached_executors:
             # To apply IR pass, compile the Program to IrGraph and convert it back to Program
             if isinstance(program, compiler.CompiledProgram):
+                # print(f"Program before convert:\n {inner_program}", flush=True)
                 program._compile(scope, self.place)
                 ir_graph = framework.IrGraph(program._graph)
                 inner_program = ir_graph.to_program()
+                # print(f"Program after convert:\n {inner_program}", flush=True)
             else:
                 from paddle.incubate.autograd import prim_enabled, prim2orig
                 if prim_enabled() and program == default_main_program():
...
@@ -470,7 +470,7 @@ HIGH_PARALLEL_JOB_NEW = [
     'cipher_utils_test',
     'test_program_code',
     'test_save_model_without_var',
-    'program_processing_test',
+    'program_utils_test',
     'test_fleet_distributed_strategy',
     'test_hybrid_parallel_topology',
     'test_ascend_trigger',
...
@@ -1719,7 +1719,7 @@ CPU_PARALLEL_JOB = [
     'test_sum_api',
     'test_op_compat_sensible_pass',
     'test_generate_pass_cc',
-    'program_processing_test',
+    'program_utils_test',
     'build_strategy_test',
     'test_fc_rnn_mkldnn_fuse_pass',
     'scope_guard_test',
...