未验证 提交 4140fe11 编写于 作者: C chengduo 提交者: GitHub

Open fuse optimization ops (#18741)

* open fuse optimization ops
test=develop
上级 582cc297
...@@ -21,12 +21,12 @@ limitations under the License. */ ...@@ -21,12 +21,12 @@ limitations under the License. */
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
...@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public: public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) { : ir::PassBuilder(), strategy_(strategy) {
// Add a graph viz pass to record a graph. ResolveOptionConfliction();
if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add graph_viz_pass";
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass. AppendPrintGraphPass("graph_viz_pass", "_original_graph");
VLOG(1) << "Add record_skip_memory_opt_vars_pass"; // Note(zcd): record_skip_memory_opt_vars_pass should
// be the first pass.
AppendPass("record_skip_memory_opt_vars_pass"); AppendPass("record_skip_memory_opt_vars_pass");
AppendPassWithCheck(strategy_.enable_sequential_execution_,
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
AppendOpFusePasses();
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
// TODO(dev-paddle): memory optimize pass should be placed last.
AppendMemoryOptimizePasses();
AppendMultiDevPass();
AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
AppendPassWithCheck(strategy_.cache_runtime_context_,
"runtime_context_cache_pass");
AppendPassWithCheck(strategy_.remove_unnecessary_lock_,
"modify_op_lock_and_record_event_pass");
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");
#ifdef PADDLE_WITH_MKLDNN SetCollectiveContext();
if (FLAGS_use_mkldnn) { }
VLOG(1) << "Add mkldnn_placement_pass";
AppendPass("mkldnn_placement_pass");
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
if (strategy_.enable_sequential_execution_) { void ResolveOptionConfliction() {
VLOG(1) << "Add sequential_execution_pass"; // Specifies the restrictions between different pass.
AppendPass("sequential_execution_pass"); if (strategy_.enable_parallel_graph_) {
VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
<< "Currently, fuse_all_optimizer_ops doesn't works under "
"parallel_graph.";
strategy_.fuse_all_optimizer_ops_ = false;
} }
if (strategy_.is_distribution_) {
// Add op fusion. VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
if (strategy.sync_batch_norm_) { << "Currently, fuse_all_optimizer_ops only works under "
AppendPass("sync_batch_norm_pass"); "Non-distributed mode.";
strategy_.fuse_all_optimizer_ops_ = false;
} }
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
// Add op fusion. VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
if (strategy.fuse_relu_depthwise_conv_) { << "Currently, fuse_all_optimizer_ops only works under AllReduce "
VLOG(1) << "Add fuse_relu_depthwise_conv_pass"; "mode.";
AppendPass("fuse_relu_depthwise_conv_pass"); strategy_.fuse_all_optimizer_ops_ = false;
VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
<< "fuse_all_optimizer_ops only work in Reducer mode.";
strategy_.fuse_all_reduce_ops_ = false;
} }
}
// TODO(zjl): refactor MemoryOptimizePass to fit void AppendMultiGraphOptPasses() {
// new strategy, which does not need to set // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// var.persistable = True // first, if the number is zero, fuse_all_reduce_ops will do nothing.
if (strategy_.use_legacy_memory_optimize_strategy_) { AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
if (strategy_.enable_inplace_) { "fuse_all_reduce_op_pass");
VLOG(5) << "Add inplace_pass"; AppendPrintGraphPass("multi_devices_print_pass", "_multi_devices_graph");
AppendPass("inplace_pass");
}
}
if (strategy_.fuse_elewise_add_act_ops_) { // experimental shows that the program will be faster if append
VLOG(1) << "Add fuse_elewise_add_act_pass"; // all_reduce_deps_pass here.
AppendPass("fuse_elewise_add_act_pass"); bool append_all_reduce_deps_pass =
} !strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce);
AppendPassWithCheck(append_all_reduce_deps_pass, "all_reduce_deps_pass");
bool append_backward_optimizer_op_deps_pass =
strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
!strategy_.is_distribution_ &&
strategy_.enable_backward_optimizer_op_deps_;
AppendPassWithCheck(append_backward_optimizer_op_deps_pass,
"backward_optimizer_op_deps_pass");
}
void AppendOpFusePasses() {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
// for single card training, fuse_all_reduce_ops is unnecessary. // for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass. // coalesce_grad_tensor_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) { AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
VLOG(1) << "Add coalesce_grad_tensor_pass"; "coalesce_grad_tensor_pass");
AppendPass("coalesce_grad_tensor_pass");
}
// Fuse all the optimization operators. // Fuse all the optimization operators.
if (strategy_.is_distribution_) { // NOTE: fuse_all_xx_ops will count the number of xx operator first,
VLOG(3) << "Currently, fuse_all_optimizer_ops only works under " // if the number is zero, fuse_all_reduce_ops will do nothing.
"Non-distributed mode."; // Currently, only one type of optimization algorithm can be fused.
strategy_.fuse_all_optimizer_ops_ = false;
}
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
if (strategy_.fuse_all_optimizer_ops_) { if (strategy_.fuse_all_optimizer_ops_) {
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(1) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass"); AppendPass("fuse_adam_op_pass");
VLOG(1) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass"); AppendPass("fuse_sgd_op_pass");
VLOG(1) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass"); AppendPass("fuse_momentum_op_pass");
} }
}
// Add a graph viz pass to record a graph. void AppendMemoryOptimizePasses() { // Append Memory Optimize Pass
if (!strategy.debug_graphviz_path_.empty()) { // TODO(zjl): refactor MemoryOptimizePass to fit
auto viz_pass = AppendPass("graph_viz_pass"); // new strategy, which does not need to set
const std::string graph_path = string::Sprintf( // var.persistable = True
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph"); if (strategy_.use_legacy_memory_optimize_strategy_) {
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path)); AppendPassWithCheck(strategy_.enable_inplace_, "inplace_pass");
}
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
} }
VLOG(1) << "CollectiveContext:" << context->String();
// NOTE(dzh): memory optimize should be a runtime pass. // NOTE(dzh): memory optimize should be a runtime pass.
// However, after multi_devices_pass, VarHandle, OpHandle is // However, after multi_devices_pass, VarHandle, OpHandle is
// the de-fact IR, any reuse on Graph is meaningless. // the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars // A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface. // , so fetchlist should be set persistable before call the Run interface.
if (strategy_.use_legacy_memory_optimize_strategy_) { if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.memory_optimize_) { AppendPassWithCheck(strategy_.memory_optimize_, "memory_optimize_pass");
VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
}
// runtime_context_cache pass should be the last pass to enable the attr of
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(1) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(1) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass");
}
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add multi_devices_print_pass";
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
"graph_printer", new ir::GraphvizSSAGraphPrinter);
}
// experimental shows that the program will be faster if append
// all_reduce_deps_pass here.
if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(1) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
if (strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
!strategy_.is_distribution_ &&
strategy_.enable_backward_optimizer_op_deps_) {
VLOG(1) << "Add backward_op_deps_pass";
AppendPass("backward_optimizer_op_deps_pass");
} }
}
if (strategy_.remove_unnecessary_lock_) { void SetCollectiveContext() const {
VLOG(1) << "Add modify_op_lock_and_record_event_pass"; CollectiveContext *context = CollectiveContext::GetInstance();
AppendPass("modify_op_lock_and_record_event_pass"); context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE_GE(strategy_.trainer_id_, 0, "trainer_id_ >= 0");
if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
PADDLE_ENFORCE_LT(static_cast<size_t>(strategy_.trainer_id_),
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
} }
VLOG(1) << "CollectiveContext:" << context->String();
// Verify that the graph is correct for multi-device executor.
VLOG(1) << "Add multi_devices_check_pass";
AppendPass("multi_devices_check_pass");
} }
// Convert graph to run on multi-devices. // Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) { void AppendMultiDevPass() {
ir::Pass *multi_devices_pass = nullptr; ir::Pass *multi_devices_pass = nullptr;
if (strategy_.async_mode_) { if (strategy_.async_mode_) {
VLOG(1) << "Add async_multi_devices_pass";
multi_devices_pass = AppendPass("async_multi_devices_pass").get(); multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) { } else if (strategy_.is_distribution_) {
VLOG(1)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { switch (strategy_.reduce_) {
VLOG(1) << "Add all_reduce_mode_multi_devices_pass"; case BuildStrategy::ReduceStrategy::kAllReduce:
multi_devices_pass = multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get(); AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { break;
VLOG(1) << "Add reduce_mode_multi_devices_pass"; case BuildStrategy::ReduceStrategy::kReduce:
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); multi_devices_pass =
} else { AppendPass("reduce_mode_multi_devices_pass").get();
PADDLE_THROW("Unknown reduce strategy."); break;
default:
PADDLE_THROW("Unknown reduce strategy.");
} }
} }
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_); &strategy_);
} }
void AppendPrintGraphPass(const std::string &pass_name,
const std::string &debug_file_suffix) {
if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass(pass_name);
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), debug_file_suffix);
viz_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
}
}
void AppendPassWithCheck(bool append_pass, const std::string &pass_name) {
if (append_pass) {
AppendPass(pass_name);
}
}
void AppendPassToSetMkldnnAttr(const std::string &pass_name) {
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
AppendPass(pass_name);
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specify the operator type list to "
"use MKLDNN acceleration. It is null in default, means "
"that all the operators supported by MKLDNN will be "
"accelerated. And it should not be set when "
"FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
}
private: private:
BuildStrategy strategy_; BuildStrategy strategy_;
}; };
...@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif #endif
} else if (pass->Type() == "coalesce_grad_tensor_pass" || } else if (pass->Type() == "fuse_all_reduce_op_pass") {
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes); pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes, pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes); &local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce); pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce, pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_)); new bool(use_hierarchical_allreduce_));
#endif #endif
}
} else if (pass->Type() == "coalesce_grad_tensor_pass") { } else if (pass->Type() == "coalesce_grad_tensor_pass") {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......
...@@ -88,7 +88,7 @@ struct BuildStrategy { ...@@ -88,7 +88,7 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false}; bool fuse_elewise_add_act_ops_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types // should not be sparse types
bool fuse_all_optimizer_ops_{false}; bool fuse_all_optimizer_ops_{true};
bool fuse_all_reduce_ops_{false}; bool fuse_all_reduce_ops_{false};
// fuse_relu_depthwise_conv can fuse the `relu -> // fuse_relu_depthwise_conv can fuse the `relu ->
// depthwise_conv` // depthwise_conv`
......
...@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
} // namespace paddle } // namespace paddle
REGISTER_PASS(coalesce_grad_tensor_pass, REGISTER_PASS(coalesce_grad_tensor_pass,
paddle::framework::ir::CoalesceGradTensorPass) paddle::framework::ir::CoalesceGradTensorPass);
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass) REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass);
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass { ...@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass) REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass);
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass { ...@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass) REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass);
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -26,7 +26,7 @@ namespace paddle { ...@@ -26,7 +26,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
constexpr char kGraphvizPath[] = "debug_graphviz_path"; constexpr char kGraphvizPath[] = "graph_viz_path";
class SSAGraphPrinter { class SSAGraphPrinter {
public: public:
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -25,8 +26,6 @@ namespace framework { ...@@ -25,8 +26,6 @@ namespace framework {
namespace ir { namespace ir {
using inference::analysis::Dot; using inference::analysis::Dot;
namespace { namespace {
const char kGraphVizPath[] = "graph_viz_path";
std::string FormatName(const Node* node) { std::string FormatName(const Node* node) {
if (!node->IsOp() || !node->Op() || if (!node->IsOp() || !node->Op() ||
!node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) { !node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) {
...@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) { ...@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
} // namespace } // namespace
void GraphVizPass::ApplyImpl(ir::Graph* graph) const { void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath); const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path; VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
...@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes( ...@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
} // namespace paddle } // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass) REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath); .RequirePassAttr(paddle::framework::ir::kGraphvizPath);
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass { ...@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath))); new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout); if (Has("graph_printer")) {
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
} else {
GraphvizSSAGraphPrinter printer;
printer.Print(*graph, *fout);
}
} }
}; };
......
...@@ -24,6 +24,7 @@ namespace framework { ...@@ -24,6 +24,7 @@ namespace framework {
namespace ir { namespace ir {
Graph* Pass::Apply(Graph* graph) const { Graph* Pass::Apply(Graph* graph) const {
CheckPrevPass();
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty."); PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) { for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
...@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const { ...@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE(VarDescIsConsistency(*graph), PADDLE_ENFORCE(VarDescIsConsistency(*graph),
"The VarDescs of persistable variable are not consistency."); "The VarDescs of persistable variable are not consistency.");
applied_ = true; applied_ = true;
if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
graph->Get<PassRecorder>(kPassRecorder).insert(Type());
return graph; return graph;
} }
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -31,6 +32,9 @@ namespace ir { ...@@ -31,6 +32,9 @@ namespace ir {
template <typename PassType> template <typename PassType>
struct PassRegistrar; struct PassRegistrar;
typedef std::unordered_set<std::string> PassRecorder;
constexpr char kPassRecorder[] = "pass_recorder";
class Pass { class Pass {
public: public:
Pass() = default; Pass() = default;
...@@ -104,6 +108,10 @@ class Pass { ...@@ -104,6 +108,10 @@ class Pass {
LOG(FATAL) << "Calling virtual Pass not implemented."; LOG(FATAL) << "Calling virtual Pass not implemented.";
} }
// Some Pass must be placed before this Pass, and some
// Pass must be placed after this Pass.
virtual void CheckPrevPass() const {}
private: private:
template <typename PassType> template <typename PassType>
friend struct PassRegistrar; friend struct PassRegistrar;
......
...@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
#include <memory>
#include <utility>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) { std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
VLOG(3) << "Append " << pass_type;
auto pass = ir::PassRegistry::Instance().Get(pass_type); auto pass = ir::PassRegistry::Instance().Get(pass_type);
passes_.emplace_back(pass.release()); passes_.emplace_back(pass.release());
return passes_.back(); return passes_.back();
......
...@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass { ...@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm"; VLOG(3) << "Use synchronous batch norm";
for (const Node *n : graph->Nodes()) { for (const Node *n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp() && n->Op()) {
auto *op = n->Op(); auto *op = n->Op();
if (op->Type() == "batch_norm") { if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm"); op->SetType("sync_batch_norm");
......
...@@ -32,6 +32,7 @@ feed_dict = { ...@@ -32,6 +32,7 @@ feed_dict = {
class InplaceTestBase(unittest.TestCase): class InplaceTestBase(unittest.TestCase):
def initParameter(self): def initParameter(self):
self.use_cuda = True self.use_cuda = True
self.fuse_all_optimizer_ops = False
def setUp(self): def setUp(self):
self.initParameter() self.initParameter()
...@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase): ...@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
self.device_count = fluid.core.get_cuda_device_count() self.device_count = fluid.core.get_cuda_device_count()
else: else:
self.device_count = 4 self.device_count = 4
assert batch_size % self.device_count == 0 assert batch_size % self.device_count == 0
def build_program_and_scope(self): def build_program_and_scope(self):
...@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase): ...@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = memory_optimize build_strategy.memory_optimize = memory_optimize
build_strategy.enable_inplace = enable_inplace build_strategy.enable_inplace = enable_inplace
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
compiled_prog = fluid.CompiledProgram(prog).with_data_parallel( compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
...@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase): ...@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = memory_optimize build_strategy.memory_optimize = memory_optimize
build_strategy.enable_inplace = enable_inplace build_strategy.enable_inplace = enable_inplace
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
compiled_program = fluid.CompiledProgram( compiled_program = fluid.CompiledProgram(
prog).with_data_parallel( prog).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
...@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase): ...@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
class CPUInplaceTest(InplaceTestBase): class CPUInplaceTest(InplaceTestBase):
def initParameter(self): def initParameter(self):
self.use_cuda = False self.use_cuda = False
self.fuse_all_optimizer_ops = False
class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = True
class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = True
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册