Unverified commit 4140fe11, authored by chengduo, committed by GitHub

Open fuse optimization ops (#18741)

* open fuse optimization ops
test=develop
Parent: 582cc297
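This commit turns the optimizer-op fusion passes on by default (fuse_all_optimizer_ops_ now defaults to true in BuildStrategy) and restructures how ParallelExecutorPassBuilder appends its passes. As a user-facing illustration, the flag can still be set explicitly through the Python BuildStrategy API; the following is a minimal sketch based on the fluid calls exercised by the updated unit test in this diff, not code taken from the commit itself:

import paddle.fluid as fluid

# A trivial network so the optimizer has gradients to fuse.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.Adam(learning_rate=1e-3).minimize(loss)

build_strategy = fluid.BuildStrategy()
# After this commit the default is already True; setting it explicitly
# documents the intent (or opts out when set to False).
build_strategy.fuse_all_optimizer_ops = True

compiled_prog = fluid.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy)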
......@@ -21,12 +21,12 @@ limitations under the License. */
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
DECLARE_bool(use_mkldnn);
......@@ -48,212 +48,195 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
// Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add graph_viz_pass";
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
ResolveOptionConfliction();
// Note(zcd): record_skip_memory_opt_vars_pass should be the first pass.
VLOG(1) << "Add record_skip_memory_opt_vars_pass";
AppendPrintGraphPass("graph_viz_pass", "_original_graph");
// Note(zcd): record_skip_memory_opt_vars_pass should
// be the first pass.
AppendPass("record_skip_memory_opt_vars_pass");
AppendPassWithCheck(strategy_.enable_sequential_execution_,
"sequential_execution_pass");
AppendPassWithCheck(strategy_.sync_batch_norm_, "sync_batch_norm_pass");
AppendOpFusePasses();
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
// TODO(dev-paddle): memory optimize pass should be placed last.
AppendMemoryOptimizePasses();
AppendMultiDevPass();
AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
// runtime_context_cache_pass should be the last pass, so that the attr can be
// enabled for all original and fused operators. No operator can have this
// attr enabled if the pass is placed after MultiDevPass.
AppendPassWithCheck(strategy_.cache_runtime_context_,
"runtime_context_cache_pass");
AppendPassWithCheck(strategy_.remove_unnecessary_lock_,
"modify_op_lock_and_record_event_pass");
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
VLOG(1) << "Add mkldnn_placement_pass";
AppendPass("mkldnn_placement_pass");
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specifies the operator type list that "
"uses MKLDNN acceleration. It is empty by default, which means "
"that all operators supported by MKLDNN will be accelerated. "
"It should not be set when FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
SetCollectiveContext();
}
if (strategy_.enable_sequential_execution_) {
VLOG(1) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
void ResolveOptionConfliction() {
// Specifies the restrictions between different passes.
if (strategy_.enable_parallel_graph_) {
VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
<< "Currently, fuse_all_optimizer_ops doesn't work under "
"parallel_graph.";
strategy_.fuse_all_optimizer_ops_ = false;
}
// Add op fusion.
if (strategy.sync_batch_norm_) {
AppendPass("sync_batch_norm_pass");
if (strategy_.is_distribution_) {
VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
<< "Currently, fuse_all_optimizer_ops only works under "
"Non-distributed mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
VLOG(1) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass");
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
<< "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
<< "fuse_all_optimizer_ops only work in Reducer mode.";
strategy_.fuse_all_reduce_ops_ = false;
}
}
// TODO(zjl): refactor MemoryOptimizePass to fit
// new strategy, which does not need to set
// var.persistable = True
if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.enable_inplace_) {
VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass");
}
}
void AppendMultiGraphOptPasses() {
// NOTE: fuse_all_reduce_ops first counts the number of all_reduce operators;
// if the number is zero, fuse_all_reduce_ops will do nothing.
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
"fuse_all_reduce_op_pass");
AppendPrintGraphPass("multi_devices_print_pass", "_multi_devices_graph");
if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(1) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// Experiments show that the program will be faster if all_reduce_deps_pass
// is appended here.
bool append_all_reduce_deps_pass =
!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce);
AppendPassWithCheck(append_all_reduce_deps_pass, "all_reduce_deps_pass");
bool append_backward_optimizer_op_deps_pass =
strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
!strategy_.is_distribution_ &&
strategy_.enable_backward_optimizer_op_deps_;
AppendPassWithCheck(append_backward_optimizer_op_deps_pass,
"backward_optimizer_op_deps_pass");
}
void AppendOpFusePasses() {
AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
// For single-card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be placed before MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) {
VLOG(1) << "Add coalesce_grad_tensor_pass";
AppendPass("coalesce_grad_tensor_pass");
}
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
"coalesce_grad_tensor_pass");
// Fuse all the optimization operators.
if (strategy_.is_distribution_) {
VLOG(3) << "Currently, fuse_all_optimizer_ops only works under "
"Non-distributed mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3) << "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
}
// NOTE: fuse_all_xx_ops first counts the number of xx operators;
// if the number is zero, the pass will do nothing.
// Currently, only one type of optimization algorithm can be fused.
if (strategy_.fuse_all_optimizer_ops_) {
// NOTE: fuse_all_xx_ops first counts the number of xx operators;
// if the number is zero, the pass will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(1) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(1) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
VLOG(1) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
}
}
// Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
void AppendMemoryOptimizePasses() { // Append Memory Optimize Pass
// TODO(zjl): refactor MemoryOptimizePass to fit
// new strategy, which does not need to set
// var.persistable = True
if (strategy_.use_legacy_memory_optimize_strategy_) {
AppendPassWithCheck(strategy_.enable_inplace_, "inplace_pass");
}
VLOG(1) << "CollectiveContext:" << context->String();
// NOTE(dzh): memory optimize should be a runtime pass.
// However, after multi_devices_pass, VarHandle and OpHandle are
// the de facto IR, so any reuse on the Graph is meaningless.
// A side effect is that memory optimize cannot foresee the fetched vars,
// so the fetch list should be set persistable before calling the Run interface.
if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.memory_optimize_) {
VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
}
// runtime_context_cache_pass should be the last pass, so that the attr can be
// enabled for all original and fused operators. No operator can have this
// attr enabled if the pass is placed after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(1) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
AppendMultiDevPass(strategy_);
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops first counts the number of all_reduce operators;
// if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(1) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass");
}
// Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) {
VLOG(1) << "Add multi_devices_print_pass";
auto multi_devices_print_pass = AppendPass("multi_devices_print_pass");
const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
"graph_printer", new ir::GraphvizSSAGraphPrinter);
}
// Experiments show that the program will be faster if all_reduce_deps_pass
// is appended here.
if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(1) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
if (strategy_.num_trainers_ > 1 && !strategy_.async_mode_ &&
!strategy_.is_distribution_ &&
strategy_.enable_backward_optimizer_op_deps_) {
VLOG(1) << "Add backward_op_deps_pass";
AppendPass("backward_optimizer_op_deps_pass");
AppendPassWithCheck(strategy_.memory_optimize_, "memory_optimize_pass");
}
}
if (strategy_.remove_unnecessary_lock_) {
VLOG(1) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass");
void SetCollectiveContext() const {
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE_GE(strategy_.trainer_id_, 0, "trainer_id_ >= 0");
if (strategy_.trainer_id_ > 0 && strategy_.trainers_endpoints_.size() > 0) {
PADDLE_ENFORCE_LT(static_cast<size_t>(strategy_.trainer_id_),
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
}
// Verify that the graph is correct for multi-device executor.
VLOG(1) << "Add multi_devices_check_pass";
AppendPass("multi_devices_check_pass");
VLOG(1) << "CollectiveContext:" << context->String();
}
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
void AppendMultiDevPass() {
ir::Pass *multi_devices_pass = nullptr;
if (strategy_.async_mode_) {
VLOG(1) << "Add async_multi_devices_pass";
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(1)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(1) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(1) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kAllReduce:
multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get();
break;
case BuildStrategy::ReduceStrategy::kReduce:
multi_devices_pass =
AppendPass("reduce_mode_multi_devices_pass").get();
break;
default:
PADDLE_THROW("Unknown reduce strategy.");
}
}
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_);
}
void AppendPrintGraphPass(const std::string &pass_name,
const std::string &debug_file_suffix) {
if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass(pass_name);
const std::string graph_path = string::Sprintf(
"%s%s", strategy_.debug_graphviz_path_.c_str(), debug_file_suffix);
viz_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
}
}
void AppendPassWithCheck(bool append_pass, const std::string &pass_name) {
if (append_pass) {
AppendPass(pass_name);
}
}
void AppendPassToSetMkldnnAttr(const std::string &pass_name) {
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) {
AppendPass(pass_name);
} else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
LOG(WARNING)
<< "mkldnn_enabled_op_types specifies the operator type list that "
"uses MKLDNN acceleration. It is empty by default, which means "
"that all operators supported by MKLDNN will be accelerated. "
"It should not be set when FLAGS_use_mkldnn=false.";
}
#else
PADDLE_ENFORCE(!FLAGS_use_mkldnn,
"Please compile with MKLDNN first to use MKLDNN");
#endif
}
private:
BuildStrategy strategy_;
};
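Several of the helpers above (AppendPrintGraphPass and the multi_devices_print_pass setup) only take effect when strategy_.debug_graphviz_path_ is non-empty. For reference, a hedged Python sketch of how that path is supplied from user code; the attribute is the standard BuildStrategy one, and the prefix is a hypothetical example:

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# With a non-empty prefix the builder appends graph_viz_pass and
# multi_devices_print_pass, writing <prefix>_original_graph,
# <prefix>_fused_graph and <prefix>_multi_devices_graph dump files.
build_strategy.debug_graphviz_path = "/tmp/viz"  # hypothetical output prefix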
......@@ -307,26 +290,20 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "coalesce_grad_tensor_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
} else if (pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_));
#endif
}
} else if (pass->Type() == "coalesce_grad_tensor_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......
......@@ -88,7 +88,7 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
// fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// are not of sparse types.
bool fuse_all_optimizer_ops_{false};
bool fuse_all_optimizer_ops_{true};
bool fuse_all_reduce_ops_{false};
// fuse_relu_depthwise_conv can fuse the `relu ->
// depthwise_conv`
......
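Because the new default enables fusion, models whose gradients are sparse (the struct comment above notes that fusion requires dense gradients; sparse gradients typically come from lookup_table with sparse updates) should opt out explicitly. A minimal sketch, again assuming the standard fluid.BuildStrategy attributes shown in the test changes below:

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# Both fusion passes assume dense gradients; disable them when any
# parameter receives a sparse (selected-rows) gradient.
build_strategy.fuse_all_optimizer_ops = False
build_strategy.fuse_all_reduce_ops = False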
......@@ -483,6 +483,4 @@ class CoalesceGradTensorPass : public ir::Pass {
} // namespace paddle
REGISTER_PASS(coalesce_grad_tensor_pass,
paddle::framework::ir::CoalesceGradTensorPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
paddle::framework::ir::CoalesceGradTensorPass);
......@@ -204,6 +204,4 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass);
......@@ -87,6 +87,4 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass);
......@@ -65,6 +65,4 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass);
......@@ -26,7 +26,7 @@ namespace paddle {
namespace framework {
namespace ir {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
constexpr char kGraphvizPath[] = "graph_viz_path";
class SSAGraphPrinter {
public:
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
......@@ -25,8 +26,6 @@ namespace framework {
namespace ir {
using inference::analysis::Dot;
namespace {
const char kGraphVizPath[] = "graph_viz_path";
std::string FormatName(const Node* node) {
if (!node->IsOp() || !node->Op() ||
!node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) {
......@@ -39,7 +38,7 @@ std::string FormatName(const Node* node) {
} // namespace
void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
const std::string& graph_viz_path = Get<std::string>(kGraphvizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good());
......@@ -132,4 +131,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
.RequirePassAttr(paddle::framework::ir::kGraphvizPath);
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_printer.h"
namespace paddle {
namespace framework {
......@@ -29,7 +29,12 @@ class SSAGraghBuilderWithPrinterPass : public ir::Pass {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
if (Has("graph_printer")) {
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
} else {
GraphvizSSAGraphPrinter printer;
printer.Print(*graph, *fout);
}
}
};
......
......@@ -24,6 +24,7 @@ namespace framework {
namespace ir {
Graph* Pass::Apply(Graph* graph) const {
CheckPrevPass();
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
......@@ -41,6 +42,10 @@ Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE(VarDescIsConsistency(*graph),
"The VarDescs of persistable variable are not consistency.");
applied_ = true;
if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
graph->Get<PassRecorder>(kPassRecorder).insert(Type());
return graph;
}
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -31,6 +32,9 @@ namespace ir {
template <typename PassType>
struct PassRegistrar;
typedef std::unordered_set<std::string> PassRecorder;
constexpr char kPassRecorder[] = "pass_recorder";
class Pass {
public:
Pass() = default;
......@@ -104,6 +108,10 @@ class Pass {
LOG(FATAL) << "Calling virtual Pass not implemented.";
}
// Some Pass must be placed before this Pass, and some
// Pass must be placed after this Pass.
virtual void CheckPrevPass() const {}
private:
template <typename PassType>
friend struct PassRegistrar;
......
......@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass_builder.h"
#include <memory>
#include <utility>
namespace paddle {
namespace framework {
namespace ir {
std::shared_ptr<Pass> PassBuilder::AppendPass(const std::string& pass_type) {
VLOG(3) << "Append " << pass_type;
auto pass = ir::PassRegistry::Instance().Get(pass_type);
passes_.emplace_back(pass.release());
return passes_.back();
......
......@@ -26,7 +26,7 @@ class SyncBatchNormPass : public Pass {
void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm";
for (const Node *n : graph->Nodes()) {
if (n->IsOp()) {
if (n->IsOp() && n->Op()) {
auto *op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
......
......@@ -32,6 +32,7 @@ feed_dict = {
class InplaceTestBase(unittest.TestCase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = False
def setUp(self):
self.initParameter()
......@@ -39,7 +40,6 @@ class InplaceTestBase(unittest.TestCase):
self.device_count = fluid.core.get_cuda_device_count()
else:
self.device_count = 4
assert batch_size % self.device_count == 0
def build_program_and_scope(self):
......@@ -90,6 +90,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
......@@ -135,6 +136,7 @@ class InplaceTestBase(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = memory_optimize
build_strategy.enable_inplace = enable_inplace
build_strategy.fuse_all_optimizer_ops = self.fuse_all_optimizer_ops
compiled_program = fluid.CompiledProgram(
prog).with_data_parallel(
loss_name=loss.name,
......@@ -162,6 +164,19 @@ class InplaceTestBase(unittest.TestCase):
class CPUInplaceTest(InplaceTestBase):
def initParameter(self):
self.use_cuda = False
self.fuse_all_optimizer_ops = False
class CUDAInplaceTestWithFuseOptimizationOps(InplaceTestBase):
def initParameter(self):
self.use_cuda = True
self.fuse_all_optimizer_ops = True
class CPUInplaceTestWithFuseOptimizationOps(InplaceTestBase):
def initParameter(self):
self.use_cuda = False
self.fuse_all_optimizer_ops = True
if __name__ == '__main__':
......