未验证 提交 ed61d67c 编写于 作者: C chengduo 提交者: GitHub

Fix the interface of Pass::Apply (#16484)

* modify the interface of Pass::Allay
test=develop

* Polish code
test=develop

* Fix Travis CI
test=develop

* fix Pass::Apply interface
test=develop

* Fix Travis CI
test=develop
上级 59f75ec7
......@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
return nullptr;
}
std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
......@@ -131,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
VLOG(10) << "pre_op:" << pre_op->DebugString()
<< ", op:" << op->DebugString();
}
return graph;
}
} // namespace details
......
......@@ -24,8 +24,7 @@ namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -46,8 +46,7 @@ static framework::proto::VarType::Type kDefaultDtype =
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -65,7 +64,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return std::move(graph);
return;
}
std::unordered_map<std::string, ir::Node *> vars;
......@@ -124,8 +123,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
return std::move(graph);
}
template <typename AttrType>
......
......@@ -204,15 +204,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const {
#else
const bool use_cuda) const {
const bool use_cuda) const {
#endif
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
......@@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
}
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(std::move(graph));
graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type();
}
return graph;
......
......@@ -120,16 +120,15 @@ struct BuildStrategy {
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
#else
const bool use_cuda) const;
const bool use_cuda) const;
#endif
// If set true, ParallelExecutor would build the main_program into multiple
......
......@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
......@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details
......
......@@ -28,8 +28,7 @@ namespace details {
class FuseAllReduceOpPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
if (all_reduce_ops.size() == 0) {
return std::move(graph);
return;
}
PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
......@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
group_all_reduce_ops, &result);
#endif
}
return std::move(graph);
}
void InsertFusedAllReduce(const std::vector<platform::Place> &places,
......
......@@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
}
}
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void InplacePass::ApplyImpl(ir::Graph* graph) const {
var_nodes_.clear();
view_.Build(graph.get());
view_.Build(graph);
InitSSAGraphNodes();
auto cnt = 0;
......@@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
TryInplaceOpInputOutput(op, graph.get());
TryInplaceOpInputOutput(op, graph);
}
// graph->ResolveHazard(var_nodes_);
return graph;
}
void InplacePass::InplaceModifyDesc(const std::string& var,
......
......@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
InplacePass();
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;
......
......@@ -44,8 +44,7 @@ namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
auto nodes = graph->Nodes();
CollectSkipVarsSet(nodes);
......@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
RenameVarInGraphDesc(var_name, cache_name, idx);
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
RenameVarInGraphNode(var_name, cache_name, idx, graph);
pool_.Erase(cache_name);
}
}
......@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
graph->ResolveHazard(var_nodes_);
return graph;
}
void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
......
......@@ -21,6 +21,7 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -35,8 +36,7 @@ namespace details {
class MemoryOptimizePass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;
......
......@@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
......@@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -23,10 +23,8 @@ namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
void ApplyImpl(ir::Graph *graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph));
}
bool IsValidGraph(const ir::Graph *graph) const {
......
......@@ -153,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
Init();
CheckGraph(*graph);
std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
......@@ -236,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph;
}
void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
......
......@@ -36,8 +36,7 @@ namespace details {
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
virtual void Init() const;
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <fstream>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
......@@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
}
};
......
......@@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
}
// set the correct size of thread pool to each device.
......
......@@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
......@@ -342,8 +341,6 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Just skip this corner case
}
}
return graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
......@@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
......@@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
......
......@@ -29,10 +29,9 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);
std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
IR_NODE_LINK_TO(scale_op, elementwise_mul_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{fill_constant, fill_constant_out, elementwise_mul});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -26,8 +26,7 @@ class AnakinFillconstantElementwisemulFuse : public FusePassBase {
virtual ~AnakinFillconstantElementwisemulFuse() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -253,8 +254,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AttentionLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
PDPattern external_pattern, subblock_pattern;
// Use the following variables to tell whether this model is RNN1.
......@@ -269,12 +269,11 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
}
}
if (count < specified_vars.size()) {
return graph;
return;
}
// Continue to fuse.
FindWhileOp(graph.get());
return graph;
FindWhileOp(graph);
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class AttentionLSTMFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -77,10 +77,9 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
weights_array_2d.colwise() *= scale_array;
}
std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -139,7 +138,7 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
......@@ -147,16 +146,14 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
found_conv_ac_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_ac_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -199,7 +196,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{ac_scale, ac_bias, affine_channel, eltwise_out});
IR_NODE_LINK_TO(eltwise, ac_out);
......@@ -207,9 +204,8 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
found_conv_ac_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_ac_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
virtual ~ConvAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_affine_channel_fuse"};
};
......@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
};
......
......@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope,
weights_array_2d.colwise() *= variance_array;
}
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
......@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance});
GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
batch_norm, bn_mean_out, bn_variance_out,
bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
......@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
}
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bn_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
......@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
found_conv_bn_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bn_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase {
virtual ~ConvBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_bn_fuse"};
};
......@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
virtual ~ConvEltwiseAddBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
};
......
......@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
......@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
......
......@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add2_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
......@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
// Delete the unneeded nodes.
GraphSafeRemoveNodes(
graph.get(),
{conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
graph, {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase {
virtual ~ConvElementwiseAdd2ActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase {
virtual ~ConvElementwiseAddActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -30,10 +30,9 @@ namespace ir {
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out);
std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, elementwise_add_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op});
GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAddFusePass : public FusePassBase {
virtual ~ConvElementwiseAddFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
......@@ -201,7 +203,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of lookup table
std::unordered_set<const Node*> marked_nodes(
//{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
// {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
{mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
......@@ -224,15 +226,13 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -32,8 +32,7 @@ class EmbeddingFCLSTMFusePass : public FusePassBase {
virtual ~EmbeddingFCLSTMFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"embedding_fc_lstm_fuse"};
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -22,10 +23,9 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc_fuse", graph.get());
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("fc_fuse", graph);
std::unordered_set<Node*> nodes2delete;
......@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
PADDLE_ENFORCE(subgraph.count(x));
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
......@@ -72,10 +72,9 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
found_fc_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_fc_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class FCFusePass : public FusePassBase {
virtual ~FCFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -73,7 +73,7 @@ TEST(FCFusePass, basic) {
int pre_nodes = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int after_nodes = graph->Nodes().size();
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -39,7 +40,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Create New OpDesc
auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* fc_bias) {
OpDesc op_desc;
op_desc.SetType("fusion_gru");
......@@ -155,26 +155,22 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -30,8 +30,7 @@ class FCGRUFusePass : public FusePassBase {
virtual ~FCGRUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_gru_fuse"};
};
......@@ -42,8 +41,7 @@ class MulGRUFusePass : public FusePassBase {
virtual ~MulGRUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_nobias_gru_fuse"};
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -157,26 +158,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void MulLstmFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -32,8 +32,7 @@ class FCLstmFusePass : public FusePassBase {
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_lstm_fuse"};
};
......@@ -43,8 +42,7 @@ class MulLstmFusePass : public FusePassBase {
virtual ~MulLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_nobias_lstm_fuse"};
};
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,29 +25,25 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
graph = FuseActElewiseAdd(std::move(graph), act_types);
graph = FuseElewiseAddAct(std::move(graph), act_types);
graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types);
// backward
{
std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
}
// Remove the removable intermediate_out.
RemoveIntermediateOut(graph.get());
return graph;
RemoveIntermediateOut(graph);
}
// ele_add(x, act(y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("elewise_add_act", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
......@@ -86,18 +84,17 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
// act(ele_add(x,y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("act_elewise_add", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("act_elewise_add", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
......@@ -137,7 +134,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
......@@ -146,11 +143,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
// the backward of act(ele_add(x,y))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act_grad", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("elewise_add_act_grad", graph);
GraphPatternDetector gpd;
auto *d_act_out = gpd.mutable_pattern()
......@@ -217,7 +213,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
......
......@@ -14,6 +14,8 @@
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -32,20 +34,16 @@ class FuseElewiseAddActPass : public FusePassBase {
virtual ~FuseElewiseAddActPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
std::unique_ptr<ir::Graph> FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
/**
* Remove the removable intermediate_out.
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,20 +24,18 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
graph = FuseReluDepthwiseConv(std::move(graph), true);
graph = FuseReluDepthwiseConv(std::move(graph), false);
return graph;
void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
graph = FuseReluDepthwiseConv(graph, true);
graph = FuseReluDepthwiseConv(graph, false);
}
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const {
PADDLE_ENFORCE(graph.get());
ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
ir::Graph *graph, bool only_forward) const {
PADDLE_ENFORCE(graph);
if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
else
FusePassBase::Init("relu_depthwise_conv", graph.get());
FusePassBase::Init("relu_depthwise_conv", graph);
/*
x ---act--> y ---layer-> z
+----------+
......@@ -144,10 +143,9 @@ std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
}
count++;
};
gpd(graph.get(), handler);
GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
gpd(graph, handler);
GraphSafeRemoveNodes(graph, need_removed_nodes);
AddStatis(count);
return graph;
}
......
......@@ -32,10 +32,8 @@ class FuseReluDepthwiseConvPass : public FusePassBase {
virtual ~FuseReluDepthwiseConvPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const;
void ApplyImpl(ir::Graph* graph) const override;
ir::Graph* FuseReluDepthwiseConv(ir::Graph* graph, bool only_forward) const;
};
} // namespace ir
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -26,8 +28,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph->Has(kGraphToProgramVarsToRemove)) {
......@@ -73,7 +74,6 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
}
program.CopyFrom(*program_pb);
return graph;
}
} // namespace ir
......
......@@ -26,7 +26,7 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class GraphToProgramPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -84,7 +86,7 @@ TEST(GraphToProgramPass, Basic) {
ProgramDesc compiled_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
pass->Apply(std::move(g));
pass->Apply(g.get());
std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
EXPECT_EQ(ops[0]->Type(), "op1");
EXPECT_EQ(ops[1]->Type(), "op2");
......
......@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
......@@ -38,8 +38,7 @@ std::string FormatName(const Node* node) {
}
} // namespace
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
......@@ -82,7 +81,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
{Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "yellow")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
auto marked_nodes = ConsumeMarkedNodes(graph);
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
......@@ -115,8 +114,6 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
}
sout << dot.Build();
return graph;
}
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
......@@ -135,4 +132,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
\ No newline at end of file
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -34,8 +35,7 @@ class GraphVizPass : public Pass {
using marked_nodes_t = std::unordered_set<const Node*>;
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
......
......@@ -20,9 +20,8 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("identity_scale_op_clean", graph.get());
void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("identity_scale_op_clean", graph);
// pre_op -> scale_in -> scale_op -> scale_out
// ->
......@@ -72,8 +71,7 @@ std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
IR_NODE_LINK_TO(pre_op_var, scale_out_var);
};
detector(graph.get(), handler);
return graph;
detector(graph, handler);
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class IdentityScaleOpCleanPass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
private:
virtual ~IdentityScaleOpCleanPass() = default;
......
......@@ -26,9 +26,9 @@ class InferCleanGraphPass : public FusePassBase {
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("original_graph", graph.get());
PADDLE_ENFORCE(graph.get());
void ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("original_graph", graph);
PADDLE_ENFORCE(graph);
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
......@@ -46,11 +46,9 @@ class InferCleanGraphPass : public FusePassBase {
}
}
GraphSafeRemoveNodes(graph.get(), invalid_nodes);
GraphSafeRemoveNodes(graph, invalid_nodes);
AddStatis(valid_op);
return graph;
}
void CleanEdges(std::vector<Node*>* nodes,
......
......@@ -20,8 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void IsTestPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
"for activations and pooling.";
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
......@@ -47,7 +46,6 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
}
}
}
return graph;
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class IsTestPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -97,7 +97,7 @@ TEST(IsTestPass, basic) {
auto pass = PassRegistry::Instance().Get("is_test_pass");
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
......
......@@ -32,9 +32,8 @@ const char kSumGradOpName[] = "sum";
// other optimizers later.
const char kOptimizerType[] = "sgd";
std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
// We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0)
......@@ -92,14 +91,14 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
// find the forward op related to the backward op
ir::Node* forward_op =
FindForwardOpViaBackwardOp(graph.get(), backward_op);
FindForwardOpViaBackwardOp(graph, backward_op);
VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op);
Node* new_optimizer_node = CreateNewSGDNode(
graph.get(), forward_op, backward_op, node, opt_node);
graph, forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node);
}
......@@ -140,8 +139,6 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
}
}
}
return graph;
}
ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
......
......@@ -60,8 +60,7 @@ class LockFreeOptimizePass : public Pass {
virtual ~LockFreeOptimizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
private:
// Create a new sgd node via current optimizer node
......
......@@ -38,10 +38,9 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -99,7 +98,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
GraphSafeRemoveNodes(graph, {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
......@@ -123,14 +122,13 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
GraphSafeRemoveNodes(graph, {conv, eltwise, conv_out});
}
found_conv_bias_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bias_count);
return graph;
}
} // namespace ir
} // namespace framework
......
......@@ -29,8 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
virtual bool is_conv3d() const { return false; }
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
/*
......
......@@ -13,10 +13,10 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
......@@ -103,7 +103,7 @@ void MainTest(bool convWithExistingBias) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -16,8 +16,8 @@
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
......@@ -327,17 +327,15 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
get_node_from_elementwise_add);
}
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph);
auto fused_graph_with_stats = FuseConvAsY(
name_scope_,
FuseConvAsX(
name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
FuseConvAsX(name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph, 0))));
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
AddStatis(fused_graph_with_stats.second);
return graph;
}
} // namespace ir
} // namespace framework
......
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <utility>
......@@ -27,7 +28,7 @@ namespace paddle {
namespace framework {
namespace ir {
using graph_ptr = std::unique_ptr<ir::Graph>;
using graph_ptr = ir::Graph*;
using GraphWithStats = std::pair<ir::Graph*, int>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
......@@ -124,7 +125,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
virtual ~ResidualConnectionMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
void ApplyImpl(graph_ptr graph) const;
const std::string name_scope_{"residual_connection_fuse_pass"};
};
......
......@@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)(from, to));
......@@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "g"));
......
......@@ -21,10 +21,9 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("conv_relu_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
......@@ -56,7 +55,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
desc->SetAttr("fuse_relu", true);
GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
GraphSafeRemoveNodes(graph, {relu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(conv, relu_out);
......@@ -64,10 +63,9 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
found_conv_relu_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_relu_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvReLUFusePass : public FusePassBase {
virtual ~ConvReLUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -88,7 +88,7 @@ TEST(ConvReLUFusePass, basic) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -216,19 +216,16 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
}
std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE(param_scope());
QuantizeConv(graph.get(), false /* with_residual_data */);
QuantizeConv(graph.get(), true /* with_residual_data */);
QuantizePool(graph.get());
return graph;
QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */);
QuantizePool(graph);
}
} // namespace ir
......
......@@ -42,8 +42,7 @@ class CPUQuantizePass : public FusePassBase {
virtual ~CPUQuantizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
......
......@@ -139,7 +139,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -20,8 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Marks operators which are to be quantized.";
const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
......@@ -43,7 +42,6 @@ std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
}
}
}
return graph;
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ namespace ir {
*/
class CPUQuantizePlacementPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -94,7 +94,7 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
pass->Set("quantize_excluded_op_ids",
new std::unordered_set<int>(quantize_excluded_op_ids));
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
unsigned use_quantizer_true_count = 0;
......
......@@ -126,16 +126,13 @@ void CPUQuantizeSquashPass::Squash(
found_squash_count);
}
std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("cpu_quantize_squash_pass", graph);
std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph.get(), &nodes_keep_counter);
Squash(graph.get(), &nodes_keep_counter);
return graph;
FindNodesToKeep(graph, &nodes_keep_counter);
Squash(graph, &nodes_keep_counter);
}
} // namespace ir
......
......@@ -34,8 +34,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
virtual ~CPUQuantizeSquashPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
/*
* For each dequantize's output find the number of operators it is an input to
......
......@@ -125,7 +125,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -25,10 +25,9 @@ namespace ir {
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph);
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
......@@ -45,9 +44,8 @@ std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
found_depthwise_conv_mkldnn_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_depthwise_conv_mkldnn_count);
return graph;
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class DepthwiseConvMKLDNNPass : public FusePassBase {
virtual ~DepthwiseConvMKLDNNPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -86,7 +86,7 @@ TEST(DepthwiseConvMKLDNNPass, basic) {
counters before{1, 1, 1, 1};
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
// initialize counters before loop
counters after{0, 0, 0, 0};
......
......@@ -14,13 +14,13 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include <string>
#include <unordered_set>
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies MKL-DNN placement strategy.";
const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
......@@ -37,7 +37,6 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
}
}
}
return graph;
}
} // namespace ir
......
......@@ -26,8 +26,7 @@ namespace ir {
*/
class MKLDNNPlacementPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -97,7 +97,7 @@ void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types));
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
unsigned use_mkldnn_true_count = 0;
......
......@@ -16,8 +16,9 @@
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......@@ -68,8 +69,7 @@ VarDesc UpdateGradVarDesc(
return *var_desc;
}
std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
int num_repeats = Get<const int>(kNumRepeats);
std::vector<Node*> forward_backward_ops;
std::vector<Node*> optimize_ops;
......@@ -325,7 +325,6 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
}
result.ResolveHazard(created);
return graph;
}
} // namespace ir
......
......@@ -36,7 +36,7 @@ class BatchMergePass : public Pass {
virtual ~BatchMergePass() {}
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
......
......@@ -18,8 +18,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
"Required pass atrribute %s not set.", attr);
......@@ -28,16 +28,16 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
attr);
}
auto* native_graph = graph.get();
auto applied_graph = ApplyImpl(std::move(graph));
auto* native_graph = graph;
ApplyImpl(graph);
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
PADDLE_ENFORCE(!HasCircle(*graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
PADDLE_ENFORCE(applied_graph.get() == native_graph,
PADDLE_ENFORCE(graph == native_graph,
"Pass::Apply() cannot delete the passed graph and shouldn't "
"return a new graph.(For the need of pybind11)");
applied_ = true;
return applied_graph;
return graph;
}
PassRegistry& PassRegistry::Instance() {
......
......@@ -16,8 +16,10 @@ limitations under the License. */
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -44,7 +46,7 @@ class Pass {
std::string Type() const { return type_; }
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
Graph *Apply(Graph *graph) const;
// Get a reference to the attributed previously set.
template <typename AttrType>
......@@ -98,9 +100,8 @@ class Pass {
}
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
virtual void ApplyImpl(Graph *graph) const {
LOG(FATAL) << "Calling virtual Pass not implemented.";
return graph;
}
private:
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <memory>
#include <string>
#include <utility>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -39,7 +41,7 @@ void BuildCircleGraph(Graph* g) {
class TestPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
void ApplyImpl(ir::Graph* graph) const {
graph->Set<int>("copy_test_pass_attr", new int);
graph->Set<int>("copy_test_graph_attr", new int);
......@@ -48,7 +50,6 @@ class TestPass : public Pass {
int test_graph_attr = graph->Get<int>("test_graph_attr");
graph->Get<int>("copy_test_graph_attr") = test_graph_attr + 1;
return graph;
}
};
......@@ -58,7 +59,7 @@ TEST(PassTest, TestPassAttrCheck) {
std::unique_ptr<Graph> graph(new Graph(prog));
std::string exception;
try {
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
......@@ -69,7 +70,7 @@ TEST(PassTest, TestPassAttrCheck) {
pass->SetNotOwned<int>("test_pass_attr", &val);
try {
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
......@@ -78,14 +79,14 @@ TEST(PassTest, TestPassAttrCheck) {
graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 1;
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);
// Allow apply more than once.
graph.reset(new Graph(prog));
graph->Set<int>("test_graph_attr", new int);
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
pass = PassRegistry::Instance().Get("test_pass");
pass->SetNotOwned<int>("test_pass_attr", &val);
......@@ -94,7 +95,7 @@ TEST(PassTest, TestPassAttrCheck) {
graph->Set<int>("test_graph_attr", new int);
graph->Get<int>("test_graph_attr") = 2;
try {
auto tmp = pass->Apply(std::move(graph));
pass->Apply(graph.release());
} catch (paddle::platform::EnforceNotMet e) {
exception = std::string(e.what());
}
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <algorithm> // for max
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -365,17 +366,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = 0;
for (int i = MAX_NUM_FC; i > 1; --i) {
fusion_count +=
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
BuildFusion(graph, name_scope_ + "/" + std::to_string(i), i);
}
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class RepeatedFCReluFusePass : public FusePassBase {
virtual ~RepeatedFCReluFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"repeated_fc_relu_fuse"};
};
......
......@@ -20,15 +20,13 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies Runtime Context Cache strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
}
}
return graph;
}
} // namespace ir
......
......@@ -23,8 +23,7 @@ namespace ir {
class RuntimeContextCachePass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include <set>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -178,9 +178,8 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
return fc_out;
}
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("seq_concat_fc_fuse", graph.get());
void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("seq_concat_fc_fuse", graph);
GraphPatternDetector detector;
auto* pattern = detector.mutable_pattern();
auto* concat_out = BuildSeqExpandConcatPattern(pattern);
......@@ -194,8 +193,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
int fuse_count{0};
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "get one concat pattern";
// fc
GET_NODE(fc_w, detector.pattern());
......@@ -246,8 +245,6 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
});
AddStatis(fuse_count);
return graph;
}
} // namespace ir
......
......@@ -27,8 +27,7 @@ class SeqConcatFcFusePass : public FusePassBase {
virtual ~SeqConcatFcFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册