Unverified commit b6e84806, authored by Ruibiao Chen and committed by GitHub

Merge kProgramDescs in GraphToProgram (#44526)

Parent 9841b308
@@ -517,13 +517,13 @@ cc_test(
   DEPS op_call_stack)
 
 cc_library(
-  program_processing
-  SRCS program_processing.cc
+  program_utils
+  SRCS program_utils.cc
   DEPS proto_desc)
 
 cc_test(
-  program_processing_test
-  SRCS program_processing_test.cc
-  DEPS proto_desc program_processing)
+  program_utils_test
+  SRCS program_utils_test.cc
+  DEPS proto_desc program_utils)
 
 if(WITH_GPU)
   nv_test(
......
@@ -67,7 +67,7 @@ cc_library(
 cc_library(
   graph_helper
   SRCS graph_helper.cc
-  DEPS graph scale_loss_grad_op_handle)
+  DEPS graph program_utils scale_loss_grad_op_handle)
 cc_library(
   pass
   SRCS pass.cc
......
@@ -19,7 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
 #include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/program_utils.h"
 
 DECLARE_bool(convert_all_blocks);
 
 PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
@@ -559,20 +561,27 @@ static void GraphToBlock(const Graph &graph,
             << vars2remove.size() << " nodes";
   }
 
+  std::vector<proto::VarDesc> vars_in_graph;
+  for (Node *node : graph.Nodes()) {
+    if (node->IsVar() && node->Var() &&
+        node->GetVarNodeBlockId() == graph.GetBlockId()) {
+      vars_in_graph.emplace_back(*node->Var()->Proto());
+    }
+  }
+
+  // add vars_in_graph to block
   block->clear_vars();
   std::unordered_set<std::string> visited_vars;
-  for (Node *n : graph.Nodes()) {
-    if (n->IsVar()) {
-      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
-          !vars2remove.count(n->Var()->Name()) &&
-          n->GetVarNodeBlockId() == graph.GetBlockId()) {
-        visited_vars.insert(n->Var()->Name());
-        block->add_vars()->MergeFrom(*n->Var()->Proto());
-      }
+  for (proto::VarDesc &var : vars_in_graph) {
+    const std::string &var_name = var.name();
+    if (visited_vars.find(var_name) == visited_vars.end() &&
+        vars2remove.find(var_name) == vars2remove.end()) {
+      block->add_vars()->MergeFrom(var);
+      visited_vars.insert(var_name);
     }
   }
 
   block->clear_ops();
   std::vector<Node *> nodes;
   if (sort_kind != nullptr) {
     // Inference Memory Optimize relies on this branch.
@@ -630,6 +639,13 @@ void GraphToProgram(const Graph &graph,
   }
 
   program->CopyFrom(program_pb);
+
+  if (graph.Has(details::kProgramDescs)) {
+    details::ProgramDescs program_descs =
+        graph.Get<details::ProgramDescs>(details::kProgramDescs);
+    VLOG(8) << "Merge main programs";
+    MergePrograms(program, program_descs, /*append=*/false);
+  }
 }
 
 static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
......
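Editor's note: the GraphToBlock change above splits var emission into two passes — first snapshot every VarDesc owned by the current block, then copy each name into the block at most once while skipping names in vars2remove. A minimal standalone sketch of that collect-then-filter pattern (plain strings stand in for the graph nodes and proto::VarDesc; this is illustrative, not Paddle code):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Pass 1: snapshot every var owned by the current block. Duplicates are
  // expected, since graph.Nodes() may yield the same name more than once.
  std::vector<std::string> vars_in_graph = {"x", "y", "x", "tmp", "y"};
  std::unordered_set<std::string> vars2remove = {"tmp"};

  // Pass 2: emit each name at most once, skipping names marked for removal.
  std::unordered_set<std::string> visited_vars;
  std::vector<std::string> block_vars;  // stands in for block->add_vars()
  for (const std::string &name : vars_in_graph) {
    if (visited_vars.count(name) == 0 && vars2remove.count(name) == 0) {
      block_vars.push_back(name);
      visited_vars.insert(name);
    }
  }

  for (const std::string &name : block_vars) std::cout << name << " ";
  std::cout << "\n";  // prints: x y
  return 0;
}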
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/program_utils.h"
 
 namespace paddle {
 namespace framework {
@@ -78,62 +79,6 @@ Graph *Pass::Apply(Graph *graph) const {
   return graph;
 }
 
-template <typename Container, typename Visitor>
-static void VisitAllElements(Container &&container,
-                             Visitor &&visitor,
-                             bool reverse) {
-  if (reverse) {
-    std::for_each(container.rbegin(), container.rend(), visitor);
-  } else {
-    std::for_each(container.begin(), container.end(), visitor);
-  }
-}
-
-static void MergePrograms(ProgramDesc *dst,
-                          const details::ProgramDescs &srcs,
-                          bool append) {
-  PADDLE_ENFORCE_NOT_NULL(
-      dst, platform::errors::InvalidArgument("Dst program must be provided."));
-  bool reverse = !append;
-
-  auto create_var_visitor = [dst](const ProgramDesc &src) {
-    PADDLE_ENFORCE_EQ(
-        src.Size(),
-        1,
-        platform::errors::Unimplemented("MergePrograms can only support to "
-                                        "merge program with only one block."));
-    const auto &src_block = src.Block(0);
-    auto *dst_block = dst->MutableBlock(0);
-    for (const auto *src_new_var : src_block.AllVars()) {
-      if (dst_block->FindVar(src_new_var->Name())) continue;
-      auto *dst_new_var = dst_block->Var(src_new_var->Name());
-      *dst_new_var = *src_new_var;
-      VLOG(10) << "Create new variable " << dst_new_var->Name();
-    }
-  };
-  VisitAllElements(srcs, create_var_visitor, reverse);
-
-  auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
-    auto ops = src.Block(0).AllOps();
-    auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) {
-      auto *dst_block = dst->MutableBlock(0);
-      auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp();
-      op->CopyFrom(*src_op);
-      VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type();
-      // FIXME(zjl): some passes do not add VarDesc to the program;
-      // we should fix this bug later...
-      for (const auto &in_var_name : op->InputArgumentNames()) {
-        dst_block->Var(in_var_name);
-      }
-      for (const auto &out_var_name : op->OutputArgumentNames()) {
-        dst_block->Var(out_var_name);
-      }
-    };
-    VisitAllElements(ops, copy_op_visitor, reverse);
-  };
-  VisitAllElements(srcs, create_op_visitor, reverse);
-}
-
 static void FillNotSpecifiedOpRole(const ProgramDesc &main_program) {
   for (size_t block_idx = 0; block_idx < main_program.Size(); ++block_idx) {
     auto ops = main_program.Block(block_idx).AllOps();
......
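Editor's note: the two helpers removed from pass.cc above are not deleted outright; they move essentially verbatim into program_utils.cc below so MergePrograms can be reused outside the pass machinery. VisitAllElements itself is just a std::for_each whose direction is chosen at the call site; a self-contained illustration (compiles on its own):

#include <algorithm>
#include <iostream>
#include <vector>

template <typename Container, typename Visitor>
inline void VisitAllElements(Container &&container,
                             Visitor &&visitor,
                             bool reverse) {
  if (reverse) {
    std::for_each(container.rbegin(), container.rend(), visitor);
  } else {
    std::for_each(container.begin(), container.end(), visitor);
  }
}

int main() {
  std::vector<int> ops = {1, 2, 3};
  auto print = [](int op) { std::cout << op << " "; };
  VisitAllElements(ops, print, /*reverse=*/false);  // 1 2 3
  std::cout << "\n";
  VisitAllElements(ops, print, /*reverse=*/true);   // 3 2 1
  std::cout << "\n";
  return 0;
}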
@@ -12,13 +12,69 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/framework/program_processing.h"
+#include "paddle/fluid/framework/program_utils.h"
 
 #include "paddle/fluid/framework/block_desc.h"
 
 namespace paddle {
 namespace framework {
 
+template <typename Container, typename Visitor>
+inline void VisitAllElements(Container &&container,
+                             Visitor &&visitor,
+                             bool reverse) {
+  if (reverse) {
+    std::for_each(container.rbegin(), container.rend(), visitor);
+  } else {
+    std::for_each(container.begin(), container.end(), visitor);
+  }
+}
+
+void MergePrograms(ProgramDesc *dst,
+                   const std::vector<ProgramDesc> &srcs,
+                   bool append) {
+  PADDLE_ENFORCE_NOT_NULL(
+      dst, platform::errors::InvalidArgument("Dst program must be provided."));
+  bool reverse = !append;
+
+  auto create_var_visitor = [dst](const ProgramDesc &src) {
+    PADDLE_ENFORCE_EQ(
+        src.Size(),
+        1,
+        platform::errors::Unimplemented("MergePrograms can only support to "
+                                        "merge program with only one block."));
+    const auto &src_block = src.Block(0);
+    auto *dst_block = dst->MutableBlock(0);
+    for (const auto *src_new_var : src_block.AllVars()) {
+      if (dst_block->FindVar(src_new_var->Name())) continue;
+      auto *dst_new_var = dst_block->Var(src_new_var->Name());
+      *dst_new_var = *src_new_var;
+      VLOG(10) << "Create new variable " << dst_new_var->Name();
+    }
+  };
+  VisitAllElements(srcs, create_var_visitor, reverse);
+
+  auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
+    auto ops = src.Block(0).AllOps();
+    auto copy_op_visitor = [dst, reverse](const OpDesc *src_op) {
+      auto *dst_block = dst->MutableBlock(0);
+      auto *op = reverse ? dst_block->PrependOp() : dst_block->AppendOp();
+      op->CopyFrom(*src_op);
+      VLOG(10) << (reverse ? "Prepend" : "Append") << " op " << op->Type();
+      // FIXME(zjl): some passes do not add VarDesc to the program;
+      // we should fix this bug later...
+      for (const auto &in_var_name : op->InputArgumentNames()) {
+        dst_block->Var(in_var_name);
+      }
+      for (const auto &out_var_name : op->OutputArgumentNames()) {
+        dst_block->Var(out_var_name);
+      }
+    };
+    VisitAllElements(ops, copy_op_visitor, reverse);
+  };
+  VisitAllElements(srcs, create_op_visitor, reverse);
+}
+
 void ProgramProcessor::GetInputsOutputsInBlock(
     const BlockDesc &current_block,
     std::set<std::string> *inner_inputs,
......
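Editor's note: one subtlety worth spelling out is why MergePrograms sets reverse = !append. When prepending (append=false), it walks the source programs and each program's ops back-to-front; every PrependOp pushes onto the front, so visiting in reverse leaves the ops in their original relative order ahead of dst's existing ops. A toy demonstration of that invariant (std::deque stands in for the op list; not Paddle code):

#include <deque>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::deque<std::string> dst_ops = {"existing_op"};   // dst block's ops
  std::vector<std::string> src_ops = {"a", "b", "c"};  // one src program

  // append == false, so reverse == true: visit back-to-front and
  // push_front (the PrependOp stand-in) each op.
  for (auto it = src_ops.rbegin(); it != src_ops.rend(); ++it) {
    dst_ops.push_front(*it);
  }

  for (const auto &op : dst_ops) std::cout << op << " ";
  std::cout << "\n";  // prints: a b c existing_op
  return 0;
}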
@@ -18,7 +18,9 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
 
 class ProgramDesc;
 
+void MergePrograms(ProgramDesc *dst,
+                   const std::vector<ProgramDesc> &srcs,
+                   bool append);
+
 class ProgramProcessor {
  public:
@@ -30,5 +32,6 @@ class ProgramProcessor {
   void AddDepToBlockOp(const BlockDesc &block);
 };
 
 }  // namespace framework
 }  // namespace paddle
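Editor's note: with the declaration now public in program_utils.h, any translation unit can reuse the merge. A hedged caller-side sketch follows — ApplyPatches is an illustrative name, not part of this commit, and it builds only inside the Paddle source tree:

#include <vector>

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_utils.h"

namespace paddle {
namespace framework {

// Prepend block 0 of every program in `patches` onto `main_program`,
// mirroring the GraphToProgram call site above.
void ApplyPatches(ProgramDesc *main_program,
                  const std::vector<ProgramDesc> &patches) {
  MergePrograms(main_program, patches, /*append=*/false);
}

}  // namespace framework
}  // namespace paddle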
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/framework/program_processing.h"
+#include "paddle/fluid/framework/program_utils.h"
 
 #include "gtest/gtest-message.h"
 #include "gtest/gtest-test-part.h"
......
@@ -1445,9 +1445,11 @@ class Executor(object):
         if key not in self._executor_cache._cached_executors:
             # To apply IR pass, compile the Program to IrGraph and convert it back to Program
             if isinstance(program, compiler.CompiledProgram):
+                # print(f"Program before convert:\n {inner_program}", flush=True)
                 program._compile(scope, self.place)
                 ir_graph = framework.IrGraph(program._graph)
                 inner_program = ir_graph.to_program()
+                # print(f"Program after convert:\n {inner_program}", flush=True)
             else:
                 from paddle.incubate.autograd import prim_enabled, prim2orig
                 if prim_enabled() and program == default_main_program():
......
@@ -470,7 +470,7 @@ HIGH_PARALLEL_JOB_NEW = [
     'cipher_utils_test',
     'test_program_code',
     'test_save_model_without_var',
-    'program_processing_test',
+    'program_utils_test',
     'test_fleet_distributed_strategy',
     'test_hybrid_parallel_topology',
     'test_ascend_trigger',
@@ -1719,7 +1719,7 @@ CPU_PARALLEL_JOB = [
     'test_sum_api',
     'test_op_compat_sensible_pass',
     'test_generate_pass_cc',
-    'program_processing_test',
+    'program_utils_test',
     'build_strategy_test',
     'test_fc_rnn_mkldnn_fuse_pass',
     'scope_guard_test',
......