提交 3e579812 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into imperative_lr_scheduler

test=develop
......@@ -134,7 +134,7 @@ paddle.fluid.layers.sampled_softmax_with_cross_entropy (ArgSpec(args=['logits',
paddle.fluid.layers.hsigmoid (ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)), ('document', '80641ee6810b1cdc3fd6e14fc89ecc9d'))
paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)), ('document', 'b350b9a30a18e7efd7e1bb740eef6996'))
paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '17485788fffe4e2d36dc58c2ac8d174e'))
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '013795af319e2e86d3506741941078ee'))
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23'))
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b'))
......
......@@ -134,6 +134,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
pool.Get(expected_kernel_type.place_));
auto& cpu_engine = dev_ctx->GetEngine();
std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
std::vector<int> out_tz = in_tz;
......@@ -142,25 +147,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
// output tensor has the same dims as input. Reorder don't change dims
out->Resize(in.dims());
// tempory mem pd fr out , to make reorder
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out->dims()),
mkldnn::memory::format::blocked, out_type);
if (in.get_mkldnn_prim_desc() != out_mem_pd) {
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
auto in_memory = memory(in.get_mkldnn_prim_desc(), in_data);
auto out_memory = memory(out_mem_pd, out_data);
auto in_memory =
memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::Reorder(in_memory, out_memory);
} else {
out->ShareDataWith(in);
}
out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef);
#endif
}
......
......@@ -51,31 +51,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
#ifdef PADDLE_WITH_MKLDNN
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
ToMKLDNNFormat(lin));
out.ShareDataWith(input_tensor);
// TODO(jczaja): Remove that once all mkldnn ops
// are modified to work with mkldnn_blocked
auto mkldnn_fmt = [&](int rank) {
switch (rank) {
case 5:
return mkldnn::memory::format::ncdhw;
case 4:
return mkldnn::memory::format::nchw;
case 3:
return mkldnn::memory::format::ncw;
case 2:
return mkldnn::memory::format::nc;
case 1:
return mkldnn::memory::format::x;
default:
return mkldnn::memory::format::blocked;
}
};
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out.dims()),
mkldnn_fmt(out.dims().size()));
out.set_mkldnn_prim_desc(out_mem_pd);
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
#endif
} else {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
......
......@@ -10,7 +10,10 @@ cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framewor
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -104,5 +107,7 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass)
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass)
......@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
return nullptr;
}
std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
......@@ -131,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
VLOG(10) << "pre_op:" << pre_op->DebugString()
<< ", op:" << op->DebugString();
}
return graph;
}
} // namespace details
......
......@@ -24,8 +24,7 @@ namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint32(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is up limited memory size "
"of one group parameters' gradient which is the input "
......@@ -46,8 +47,7 @@ static framework::proto::VarType::Type kDefaultDtype =
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -65,7 +65,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return std::move(graph);
return;
}
std::unordered_map<std::string, ir::Node *> vars;
......@@ -106,26 +106,33 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype);
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype);
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create the fused variable name.
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
const std::string prefix(kFusedVarNamePrefix);
// The fused_var_name should be unique.
auto fused_var_name = prefix + "GRAD@" + params_grads[0].second;
// the kFusedGrads is used be fuse_optimizer_op_pass.
result.Set(kFusedGrads, new FusedGrads);
// the fused_var_name should be unique, so it appends
// params_grads.begin()->second.
auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
return std::move(graph);
}
template <typename AttrType>
......@@ -298,17 +305,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
return type == proto::VarType::LOD_TENSOR;
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
void RecordParamsAndGrads(ir::Node *node,
ParamsAndGrads *params_grads) const {
try {
......@@ -361,6 +357,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
......@@ -373,7 +370,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
// Run Only Once Programs
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
......@@ -381,6 +377,17 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
}
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
};
} // namespace details
......
......@@ -27,20 +27,17 @@ void BroadcastOpHandle::RunImpl() {
if (places_.size() == 1) return;
// The input and output may have dummy vars.
VarHandle *in_var_handle;
{
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
"The number of input should be one.");
in_var_handle = in_var_handles[0];
}
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
"The number of input should be one.");
PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
VarHandle *in_var_handle = in_var_handles[0];
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
......@@ -82,23 +81,43 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("inplace_pass");
}
if (strategy.fuse_elewise_add_act_ops_) {
if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy.fuse_all_reduce_ops_) {
if (strategy_.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}
if (strategy_.fuse_all_optimizer_ops_) {
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3)
<< "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
} else {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(10) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(10) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
}
}
// Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
......@@ -118,14 +137,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy.memory_optimize_) {
if (strategy_.memory_optimize_) {
VLOG(10) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
AppendMultiDevPass(strategy);
AppendMultiDevPass(strategy_);
if (strategy.fuse_all_reduce_ops_) {
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass";
......@@ -151,7 +170,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("all_reduce_deps_pass");
}
if (SeqOnlyAllReduceOps(strategy)) {
if (SeqOnlyAllReduceOps(strategy_)) {
VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
......@@ -165,7 +184,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass = nullptr;
if (strategy_.is_distribution_) {
if (strategy.is_distribution_) {
VLOG(10) << "Add dist_multi_devices_pass";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
......@@ -204,15 +223,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const {
#else
const bool use_cuda) const {
const bool use_cuda) const {
#endif
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
......@@ -234,17 +254,22 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "fuse_all_reduce_op_pass") {
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
}
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......@@ -265,7 +290,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
}
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(std::move(graph));
graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type();
}
return graph;
......@@ -293,4 +318,6 @@ USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
......@@ -18,7 +18,6 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -76,6 +75,8 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool fuse_all_optimizer_ops_{false};
bool fuse_all_reduce_ops_{false};
bool fuse_relu_depthwise_conv_{false};
......@@ -120,16 +121,15 @@ struct BuildStrategy {
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
#else
const bool use_cuda) const;
const bool use_cuda) const;
#endif
// If set true, ParallelExecutor would build the main_program into multiple
......
......@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
......@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details
......
......@@ -31,9 +31,10 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
local_scopes_(local_scopes),
places_(places),
graph_(graph),
fetch_ctxs_(places),
pool_(strategy.num_threads_),
prepare_pool_(1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
// add one more thread for generate op_deps
prepare_pool_(1) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
......
......@@ -14,7 +14,9 @@
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
......@@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const ir::Graph &Graph() const override;
private:
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......@@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unordered_map<OpHandleBase *, int> op_deps_;
std::vector<OpHandleBase *> bootstrap_ops_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
platform::DeviceContextPool fetch_ctxs_;
std::atomic<int> remaining_;
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps();
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
};
} // namespace details
} // namespace framework
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_adam_op_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
const std::string FuseAdamOpPass::GetOpType() const { return "adam"; }
const std::vector<std::string> FuseAdamOpPass::GetAuxiliaryVarNames() const {
return {"Param", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
}
void FuseAdamOpPass::FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph);
FuseScaleOps(aux_var_set.at("Beta1Pow"), fused_vars_name.at("Beta1Pow"),
adam_ops, graph);
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph);
}
void FuseAdamOpPass::FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(adam_ops.size(), static_cast<size_t>(0));
// Check attributions
// NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>(
adam_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
float beta1 = boost::get<float>(adam_ops[0]->Op()->GetAttr("beta1"));
float beta2 = boost::get<float>(adam_ops[0]->Op()->GetAttr("beta2"));
float epsilon = boost::get<float>(adam_ops[0]->Op()->GetAttr("epsilon"));
bool lazy_mode = boost::get<bool>(adam_ops[0]->Op()->GetAttr("lazy_mode"));
int64_t min_row_size_to_use_multithread = boost::get<int64_t>(
adam_ops[0]->Op()->GetAttr("min_row_size_to_use_multithread"));
for (auto &adam_op : adam_ops) {
PADDLE_ENFORCE_EQ(beta1,
boost::get<float>(adam_op->Op()->GetAttr("beta1")));
PADDLE_ENFORCE_EQ(beta2,
boost::get<float>(adam_op->Op()->GetAttr("beta2")));
PADDLE_ENFORCE_EQ(epsilon,
boost::get<float>(adam_op->Op()->GetAttr("epsilon")));
PADDLE_ENFORCE_EQ(lazy_mode,
boost::get<bool>(adam_op->Op()->GetAttr("lazy_mode")));
PADDLE_ENFORCE_EQ(min_row_size_to_use_multithread,
boost::get<int64_t>(adam_op->Op()->GetAttr(
"min_row_size_to_use_multithread")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(adam_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
VLOG(10) << "Insert adam to graph ";
OpDesc adam_desc(adam_ops[0]->Op()->Block());
adam_desc.SetType("adam");
adam_desc.SetInput("Param", {fused_vars_name.at("Param")});
adam_desc.SetInput("Grad", {fused_vars_name.at("Grad")});
adam_desc.SetInput("Moment1", {fused_vars_name.at("Moment1")});
adam_desc.SetInput("Moment2", {fused_vars_name.at("Moment2")});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
adam_desc.SetInput("LearningRate", adam_ops[0]->Op()->Input("LearningRate"));
adam_desc.SetInput("Beta1Pow", adam_ops[0]->Op()->Input("Beta1Pow"));
adam_desc.SetInput("Beta2Pow", adam_ops[0]->Op()->Input("Beta2Pow"));
adam_desc.SetOutput("ParamOut", {fused_vars_name.at("Param")});
adam_desc.SetOutput("Moment1Out", {fused_vars_name.at("Moment1")});
adam_desc.SetOutput("Moment2Out", {fused_vars_name.at("Moment2")});
adam_desc.SetAttr("beta1", beta1);
adam_desc.SetAttr("beta2", beta2);
adam_desc.SetAttr("epsilon", epsilon);
adam_desc.SetAttr("lazy_mode", lazy_mode);
adam_desc.SetAttr("min_row_size_to_use_multithread",
min_row_size_to_use_multithread);
adam_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto adam_node = graph->CreateOpNode(&adam_desc);
InserInputAndOutputForOptOps(adam_ops, adam_node);
}
void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const {
PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size());
const std::string scale_op_name = "scale";
// Get the scale_ops of dealing the adam's beta var.
std::vector<ir::Node *> scale_ops;
scale_ops.reserve(beta_name.size());
for (size_t i = 0; i < adam_ops.size(); ++i) {
auto &beta_1_pow_name = beta_name[i];
auto beta_pow_iter = std::find_if(
adam_ops[i]->inputs.begin(), adam_ops[i]->inputs.end(),
[&beta_name, &beta_1_pow_name](ir::Node *var_node) -> bool {
return var_node->Var() && var_node->Var()->Name() == beta_1_pow_name;
});
PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end());
auto beta_pow_node = *beta_pow_iter;
auto scale_op_iter = std::find_if(
beta_pow_node->outputs.begin(), beta_pow_node->outputs.end(),
[&scale_op_name](ir::Node *op_node) -> bool {
return op_node->Op() && op_node->Op()->Type() == scale_op_name;
});
PADDLE_ENFORCE(scale_op_iter != beta_pow_node->outputs.end());
scale_ops.emplace_back(*scale_op_iter);
}
PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());
// Check attributions
// NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>(
scale_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
float scale = boost::get<float>(scale_ops[0]->Op()->GetAttr("scale"));
float bias = boost::get<float>(scale_ops[0]->Op()->GetAttr("bias"));
bool bias_after_scale =
boost::get<bool>(scale_ops[0]->Op()->GetAttr("bias_after_scale"));
for (auto &scale_op : scale_ops) {
PADDLE_ENFORCE_EQ(scale,
boost::get<float>(scale_op->Op()->GetAttr("scale")));
PADDLE_ENFORCE_EQ(bias, boost::get<float>(scale_op->Op()->GetAttr("bias")));
PADDLE_ENFORCE_EQ(
bias_after_scale,
boost::get<bool>(scale_op->Op()->GetAttr("bias_after_scale")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(scale_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
VLOG(10) << "Insert fused scale to graph.";
OpDesc scale_desc(scale_ops[0]->Op()->Block());
scale_desc.SetType("scale");
scale_desc.SetInput("X", {fused_var_name});
scale_desc.SetOutput("Out", {fused_var_name});
scale_desc.SetAttr("scale", scale);
scale_desc.SetAttr("bias", bias);
scale_desc.SetAttr("bias_after_scale", bias_after_scale);
scale_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto scale_node = graph->CreateOpNode(&scale_desc);
for (auto scale_op : scale_ops) {
// set inputs
scale_node->inputs.insert(scale_node->inputs.begin(),
scale_op->inputs.begin(), scale_op->inputs.end());
for (auto &input : scale_op->inputs) {
std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
scale_node);
}
// set outputs
scale_node->outputs.insert(scale_node->outputs.begin(),
scale_op->outputs.begin(),
scale_op->outputs.end());
for (auto &output : scale_op->outputs) {
std::replace(output->inputs.begin(), output->inputs.end(), scale_op,
scale_node);
}
}
// Delete scale_ops
for (auto &scale_op : scale_ops) {
graph->RemoveNode(scale_op);
}
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
// Fuse Adam Ops and Scale Ops which are used to update "Beta1Pow", "Beta2Pow"
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;
void FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;
void FuseScaleOps(const std::vector<std::string> &aux_var_set,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -28,8 +28,7 @@ namespace details {
class FuseAllReduceOpPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
if (all_reduce_ops.size() == 0) {
return std::move(graph);
return;
}
PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
......@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
group_all_reduce_ops, &result);
#endif
}
return std::move(graph);
}
void InsertFusedAllReduce(const std::vector<platform::Place> &places,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
const std::string fuse_op_type = GetOpType();
const std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
// Step 1: Get the specified op and auxiliary variables.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
std::unordered_map<std::string, std::vector<std::string>> aux_var_set;
std::vector<ir::Node *> opt_ops;
for (auto &node : topo_nodes) {
GetSpecifiedOpsAndVars(fuse_op_type, aux_var_names, node, &opt_ops,
&aux_var_set);
}
VLOG(10) << "Find " << fuse_op_type << " operators: " << opt_ops.size();
if (opt_ops.size() == 0) {
return;
}
if (result.Has(kFusedOptType)) {
VLOG(10)
<< "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType);
return;
} else {
result.Set(kFusedOptType, new FusedOptType);
}
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
// Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
// initialized in scopes before execution.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size() + 1);
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
const std::string prefix(kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
aux_var_set[var_name][0];
VLOG(10) << fused_var_name;
fused_vars_name.emplace(var_name, fused_var_name);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
fused_var_set.insert(fused_var_name);
}
// Step 3: Get the fused Gradient's name
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
if (!result.Has(kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before this "
"pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
fused_vars_name.emplace("Grad", fused_grad);
// Step 4: Sort the parameters and auxiliary variables according
// to parameters' name to make variables' name correspond correctly.
PADDLE_ENFORCE(result.Has(kParamsAndGrads), "Does't find kParamsAndGrads.");
PADDLE_ENFORCE_EQ(params_grads.size(), aux_var_set.begin()->second.size(),
"The size of params_grads and aux_var_set are not equal.");
SortParametersAndAuxVars(params_grads, &aux_var_set, &opt_ops);
// Step 5: Alloc continuous space for Parameters and AuxiliaryVar(e.g.
// Moment1, Moment2, Beta1Pow, Beta2Pow) of all the optimizer ops separately.
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, aux_var_names,
aux_var_set, fused_vars_name);
// Step 6: Fuse optimizer Ops and Scale Ops
FuseOptimizerOps(aux_var_set, fused_vars_name, opt_ops, &result);
// Step 7: Remove optimizer Ops
for (auto &opt_op : opt_ops) {
graph->RemoveNode(opt_op);
}
}
void FuseOptimizerOpPass::InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &aux_var_names,
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name) const {
VLOG(10) << "Init FusedVars.";
// Alloc parameters and auxiliary vars in the respective scope.
size_t idx = local_scopes.size();
for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend();
++iter, --idx) {
auto &scope = *iter;
for (auto &var_name : aux_var_names) {
auto fused_var_name = fused_vars_name.at(var_name);
VLOG(10) << "Init " << fused_var_name;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has exist in scope[%d]", fused_var_name, idx);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
}
}
ProgramDesc program_desc;
auto *global_block = program_desc.MutableBlock(0);
for (auto &var_name : aux_var_names) {
AppendAllocContinuousSpace(aux_var_set.at(var_name),
fused_vars_name.at(var_name), true,
global_block);
}
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : global_block->AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void FuseOptimizerOpPass::SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_vars_set,
std::vector<ir::Node *> *ops) const {
PADDLE_ENFORCE_NE(aux_vars_set->count("Param"), static_cast<size_t>(0));
auto &param_vec = aux_vars_set->at("Param");
std::vector<size_t> param_sort_idx;
param_sort_idx.reserve(param_vec.size());
for (auto &p_g : params_grads) {
auto iter = std::find(param_vec.begin(), param_vec.end(), p_g.first);
PADDLE_ENFORCE(iter != param_vec.end());
auto idx = std::distance(param_vec.begin(), iter);
param_sort_idx.emplace_back(idx);
}
for (auto &aux_vars : *aux_vars_set) {
std::vector<std::string> sorted_vars;
sorted_vars.reserve(aux_vars.second.size());
for (size_t i = 0; i < aux_vars.second.size(); ++i) {
sorted_vars.emplace_back(aux_vars.second.at(param_sort_idx[i]));
}
std::swap(aux_vars.second, sorted_vars);
std::stringstream out;
for (auto &var_name : aux_vars.second) {
out << var_name << " ";
}
VLOG(10) << aux_vars.first << ": " << out.str();
}
std::vector<ir::Node *> sorted_ops;
sorted_ops.reserve(ops->size());
for (size_t i = 0; i < ops->size(); ++i) {
sorted_ops.emplace_back(ops->at(param_sort_idx[i]));
}
std::swap(*ops, sorted_ops);
}
void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
const std::string &op_type, const std::vector<std::string> &aux_vars_name,
ir::Node *node, std::vector<ir::Node *> *ops,
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const {
if (node->Op()->Type() != op_type) return;
for (auto &var_n : aux_vars_name) {
auto arg_names = node->Op()->Input(var_n);
PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1));
(*aux_args_name)[var_n].emplace_back(arg_names[0]);
VLOG(10) << var_n << ", " << arg_names[0];
}
ops->emplace_back(node);
}
void FuseOptimizerOpPass::AppendAllocContinuousSpace(
const std::vector<std::string> &args, const std::string &out_arg,
bool copy_data, BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", args);
op_desc->SetOutput("Output", args);
op_desc->SetOutput("FusedOutput", {out_arg});
op_desc->SetAttr("copy_data", copy_data);
op_desc->SetAttr("check_name", true);
}
void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
const std::vector<ir::Node *> &opt_ops, ir::Node *opt_node) const {
std::unordered_set<ir::Node *> inputs;
std::unordered_set<ir::Node *> outputs;
for (auto opt_op : opt_ops) {
// set inputs
inputs.insert(opt_op->inputs.begin(), opt_op->inputs.end());
for (auto &input : opt_op->inputs) {
replace(input->outputs.begin(), input->outputs.end(), opt_op, opt_node);
}
// set outputs
outputs.insert(opt_op->outputs.begin(), opt_op->outputs.end());
for (auto &output : opt_op->outputs) {
replace(output->inputs.begin(), output->inputs.end(), opt_op, opt_node);
}
}
opt_node->inputs.insert(opt_node->inputs.begin(), inputs.begin(),
inputs.end());
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseOptimizerOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
protected:
virtual void SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set,
std::vector<ir::Node *> *ops) const;
void InserInputAndOutputForOptOps(const std::vector<ir::Node *> &opt_ops,
ir::Node *opt_node) const;
private:
virtual const std::string GetOpType() const = 0;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const = 0;
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0;
void GetSpecifiedOpsAndVars(
const std::string &op_type, const std::vector<std::string> &aux_vars_name,
ir::Node *node, std::vector<ir::Node *> *ops,
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const;
void AppendAllocContinuousSpace(const std::vector<std::string> &args,
const std::string &out_arg, bool copy_data,
BlockDesc *global_block) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &aux_var_names,
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name)
const;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_sgd_op_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
const std::string FuseSgdOpPass::GetOpType() const { return "sgd"; }
const std::vector<std::string> FuseSgdOpPass::GetAuxiliaryVarNames() const {
return {"Param"};
}
void FuseSgdOpPass::FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
FuseSgdOps(aux_var_set, fused_vars_name, sgd_ops, graph);
}
void FuseSgdOpPass::FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0));
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
int op_role = boost::get<int>(
sgd_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
VLOG(10) << "Insert sgd to graph ";
// Add fused scale
OpDesc Sgd_desc(sgd_ops[0]->Op()->Block());
Sgd_desc.SetType("sgd");
Sgd_desc.SetInput("Param", {fused_vars_name.at("Param")});
Sgd_desc.SetInput("Grad", {fused_vars_name.at("Grad")});
Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at("Param")});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
Sgd_desc.SetInput("LearningRate", sgd_ops[0]->Op()->Input("LearningRate"));
// NOTE: multi_devices_pass requires that every op should have a role.
Sgd_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto sgd_node = graph->CreateOpNode(&Sgd_desc);
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
// Fuse Sgd Ops
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
void FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -24,6 +24,19 @@ namespace paddle {
namespace framework {
namespace details {
// Note(zcd): Addresses should be aligned, otherwise, the results may have
// diff.
static size_t Alignment(size_t size, const platform::Place &place) {
// Allow to allocate the minimum chunk size is 4 KB.
size_t alignment = 1 << 12;
if (platform::is_gpu_place(place)) {
// Allow to allocate the minimum chunk size is 256 B.
alignment = 1 << 8;
}
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
GradientAndLoDTensor;
......@@ -111,10 +124,11 @@ void FusedAllReduceOpHandle::RunImpl() {
return grad1.second->data<void>() < grad2.second->data<void>();
});
size_t size_of_dtype = framework::SizeOfType(dtype);
for (size_t k = 1; k < g_tensor.size(); ++k) {
const void *cur_address = g_tensor.at(k - 1).second->data<void>();
int64_t len = g_tensor.at(k - 1).second->numel();
auto offset = len * framework::SizeOfType(dtype);
auto offset = Alignment(len * size_of_dtype, places_[0]);
void *infer_next_address = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(cur_address) + offset);
const void *next_address = g_tensor.at(k).second->data<void>();
......@@ -228,18 +242,21 @@ void FusedAllReduceOpHandle::GetDTypeAndNumel(
const std::vector<std::pair<std::string, const LoDTensor *>> &grad_tensor,
proto::VarType::Type *dtype, int64_t *numel) const {
*numel = 0;
size_t size_of_dtype = 0;
for (size_t i = 0; i < grad_tensor.size(); ++i) {
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
*numel += len;
// Get dtype
auto ele_type = grad_tensor.at(i).second->type();
if (i == 0) {
*dtype = ele_type;
size_of_dtype = framework::SizeOfType(ele_type);
}
PADDLE_ENFORCE_EQ(ele_type, *dtype);
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
// Alignment(len)
*numel += Alignment(len * size_of_dtype, places_[0]) / size_of_dtype;
}
}
......
......@@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
}
}
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void InplacePass::ApplyImpl(ir::Graph* graph) const {
var_nodes_.clear();
view_.Build(graph.get());
view_.Build(graph);
InitSSAGraphNodes();
auto cnt = 0;
......@@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
TryInplaceOpInputOutput(op, graph.get());
TryInplaceOpInputOutput(op, graph);
}
// graph->ResolveHazard(var_nodes_);
return graph;
}
void InplacePass::InplaceModifyDesc(const std::string& var,
......
......@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
InplacePass();
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;
......
......@@ -44,8 +44,7 @@ namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
auto nodes = graph->Nodes();
CollectSkipVarsSet(nodes);
......@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
RenameVarInGraphDesc(var_name, cache_name, idx);
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
RenameVarInGraphNode(var_name, cache_name, idx, graph);
pool_.Erase(cache_name);
}
}
......@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
graph->ResolveHazard(var_nodes_);
return graph;
}
void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
......
......@@ -21,6 +21,7 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -35,8 +36,7 @@ namespace details {
class MemoryOptimizePass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;
......
......@@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
......@@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -23,10 +23,8 @@ namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
void ApplyImpl(ir::Graph *graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph));
}
bool IsValidGraph(const ir::Graph *graph) const {
......
......@@ -153,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
Init();
CheckGraph(*graph);
std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
......@@ -236,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph;
}
void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
......
......@@ -20,7 +20,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -34,10 +33,13 @@ namespace framework {
class Scope;
namespace details {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
virtual void Init() const;
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <fstream>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
......@@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
}
};
......
......@@ -20,7 +20,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
......@@ -41,22 +40,25 @@ namespace details {
// `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
constexpr char kGraphVars[] = "vars";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
constexpr char kGraphDepVars[] = "dep_vars";
typedef std::unordered_set<std::string> FusedVars;
constexpr char kFusedVars[] = "fused_vars";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
typedef std::string FusedOptType;
constexpr char kFusedOptType[] = "fused_opt_type";
typedef std::string FusedGrads;
constexpr char kFusedGrads[] = "fused_gradients";
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
......@@ -65,8 +67,6 @@ typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupGradsAndParams;
constexpr char kGroupGradsAndParams[] = "group_grads_params";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
}
// set the correct size of thread pool to each device.
......
......@@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
......@@ -335,14 +334,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
var_name);
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break;
}
// Seldomly, all preceding trying failed.
// Just skip this corner case
}
}
return graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
......@@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -24,13 +24,13 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: graph_(graph),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
prepare_pool_(1),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
strategy_(strategy) {
strategy_(strategy),
prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) {
PrepareOpDeps();
CopyOpDeps();
}
......
......@@ -63,13 +63,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op);
private:
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ir::Graph *graph_;
std::unique_ptr<::ThreadPool> pool_;
::ThreadPool prepare_pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_;
std::unique_ptr<OpDependentData> op_deps_;
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
......@@ -88,14 +95,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void PrepareOpDeps();
void CopyOpDeps();
private:
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
std::unique_ptr<OpDependentData> op_deps_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
};
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
......@@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
......
......@@ -29,10 +29,9 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);
std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
IR_NODE_LINK_TO(scale_op, elementwise_mul_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{fill_constant, fill_constant_out, elementwise_mul});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -26,8 +26,7 @@ class AnakinFillconstantElementwisemulFuse : public FusePassBase {
virtual ~AnakinFillconstantElementwisemulFuse() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -253,8 +254,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AttentionLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
PDPattern external_pattern, subblock_pattern;
// Use the following variables to tell whether this model is RNN1.
......@@ -269,12 +269,11 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
}
}
if (count < specified_vars.size()) {
return graph;
return;
}
// Continue to fuse.
FindWhileOp(graph.get());
return graph;
FindWhileOp(graph);
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class AttentionLSTMFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -77,10 +77,9 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
weights_array_2d.colwise() *= scale_array;
}
std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -139,7 +138,7 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
......@@ -147,16 +146,14 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
found_conv_ac_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_ac_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -199,7 +196,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{ac_scale, ac_bias, affine_channel, eltwise_out});
IR_NODE_LINK_TO(eltwise, ac_out);
......@@ -207,9 +204,8 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
found_conv_ac_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_ac_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
virtual ~ConvAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_affine_channel_fuse"};
};
......@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
};
......
......@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope,
weights_array_2d.colwise() *= variance_array;
}
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
......@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance});
GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
batch_norm, bn_mean_out, bn_variance_out,
bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
......@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
}
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bn_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
......@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
found_conv_bn_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bn_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase {
virtual ~ConvBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_bn_fuse"};
};
......@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
virtual ~ConvEltwiseAddBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
};
......
......@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
......@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
......
......@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add2_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
......@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
// Delete the unneeded nodes.
GraphSafeRemoveNodes(
graph.get(),
{conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
graph, {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase {
virtual ~ConvElementwiseAdd2ActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase {
virtual ~ConvElementwiseAddActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -30,10 +30,9 @@ namespace ir {
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out);
std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, elementwise_add_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op});
GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAddFusePass : public FusePassBase {
virtual ~ConvElementwiseAddFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
......@@ -201,7 +203,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of lookup table
std::unordered_set<const Node*> marked_nodes(
//{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
// {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
{mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
......@@ -224,15 +226,13 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -32,8 +32,7 @@ class EmbeddingFCLSTMFusePass : public FusePassBase {
virtual ~EmbeddingFCLSTMFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"embedding_fc_lstm_fuse"};
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -22,10 +23,9 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc_fuse", graph.get());
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("fc_fuse", graph);
std::unordered_set<Node*> nodes2delete;
......@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
PADDLE_ENFORCE(subgraph.count(x));
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
......@@ -72,10 +72,9 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
found_fc_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_fc_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class FCFusePass : public FusePassBase {
virtual ~FCFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -73,7 +73,7 @@ TEST(FCFusePass, basic) {
int pre_nodes = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int after_nodes = graph->Nodes().size();
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -39,7 +40,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Create New OpDesc
auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* fc_bias) {
OpDesc op_desc;
op_desc.SetType("fusion_gru");
......@@ -155,26 +155,22 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -30,8 +30,7 @@ class FCGRUFusePass : public FusePassBase {
virtual ~FCGRUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_gru_fuse"};
};
......@@ -42,8 +41,7 @@ class MulGRUFusePass : public FusePassBase {
virtual ~MulGRUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_nobias_gru_fuse"};
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -157,26 +158,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void MulLstmFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -32,8 +32,7 @@ class FCLstmFusePass : public FusePassBase {
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_lstm_fuse"};
};
......@@ -43,8 +42,7 @@ class MulLstmFusePass : public FusePassBase {
virtual ~MulLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_nobias_lstm_fuse"};
};
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,29 +25,25 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
graph = FuseActElewiseAdd(std::move(graph), act_types);
graph = FuseElewiseAddAct(std::move(graph), act_types);
graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types);
// backward
{
std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
}
// Remove the removable intermediate_out.
RemoveIntermediateOut(graph.get());
return graph;
RemoveIntermediateOut(graph);
}
// ele_add(x, act(y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("elewise_add_act", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
......@@ -86,18 +84,17 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
// act(ele_add(x,y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("act_elewise_add", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("act_elewise_add", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
......@@ -137,7 +134,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
......@@ -146,11 +143,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
// the backward of act(ele_add(x,y))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act_grad", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("elewise_add_act_grad", graph);
GraphPatternDetector gpd;
auto *d_act_out = gpd.mutable_pattern()
......@@ -217,7 +213,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
......
......@@ -14,6 +14,8 @@
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -32,20 +34,16 @@ class FuseElewiseAddActPass : public FusePassBase {
virtual ~FuseElewiseAddActPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
std::unique_ptr<ir::Graph> FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
/**
* Remove the removable intermediate_out.
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,20 +24,18 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
graph = FuseReluDepthwiseConv(std::move(graph), true);
graph = FuseReluDepthwiseConv(std::move(graph), false);
return graph;
void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
graph = FuseReluDepthwiseConv(graph, true);
graph = FuseReluDepthwiseConv(graph, false);
}
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const {
PADDLE_ENFORCE(graph.get());
ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
ir::Graph *graph, bool only_forward) const {
PADDLE_ENFORCE(graph);
if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
else
FusePassBase::Init("relu_depthwise_conv", graph.get());
FusePassBase::Init("relu_depthwise_conv", graph);
/*
x ---act--> y ---layer-> z
+----------+
......@@ -144,10 +143,9 @@ std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
}
count++;
};
gpd(graph.get(), handler);
GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
gpd(graph, handler);
GraphSafeRemoveNodes(graph, need_removed_nodes);
AddStatis(count);
return graph;
}
......
......@@ -32,10 +32,8 @@ class FuseReluDepthwiseConvPass : public FusePassBase {
virtual ~FuseReluDepthwiseConvPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const;
void ApplyImpl(ir::Graph* graph) const override;
ir::Graph* FuseReluDepthwiseConv(ir::Graph* graph, bool only_forward) const;
};
} // namespace ir
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -26,8 +28,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph->Has(kGraphToProgramVarsToRemove)) {
......@@ -73,7 +74,6 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
}
program.CopyFrom(*program_pb);
return graph;
}
} // namespace ir
......
......@@ -26,7 +26,7 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class GraphToProgramPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -84,7 +86,7 @@ TEST(GraphToProgramPass, Basic) {
ProgramDesc compiled_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
pass->Apply(std::move(g));
pass->Apply(g.get());
std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
EXPECT_EQ(ops[0]->Type(), "op1");
EXPECT_EQ(ops[1]->Type(), "op2");
......
......@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
......@@ -38,8 +38,7 @@ std::string FormatName(const Node* node) {
}
} // namespace
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
......@@ -82,7 +81,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
{Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "yellow")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
auto marked_nodes = ConsumeMarkedNodes(graph);
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
......@@ -115,8 +114,6 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
}
sout << dot.Build();
return graph;
}
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
......@@ -135,4 +132,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
\ No newline at end of file
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -34,8 +35,7 @@ class GraphVizPass : public Pass {
using marked_nodes_t = std::unordered_set<const Node*>;
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
......
......@@ -20,9 +20,8 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("identity_scale_op_clean", graph.get());
void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("identity_scale_op_clean", graph);
// pre_op -> scale_in -> scale_op -> scale_out
// ->
......@@ -72,8 +71,7 @@ std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
IR_NODE_LINK_TO(pre_op_var, scale_out_var);
};
detector(graph.get(), handler);
return graph;
detector(graph, handler);
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class IdentityScaleOpCleanPass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
private:
virtual ~IdentityScaleOpCleanPass() = default;
......
......@@ -26,9 +26,9 @@ class InferCleanGraphPass : public FusePassBase {
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("original_graph", graph.get());
PADDLE_ENFORCE(graph.get());
void ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("original_graph", graph);
PADDLE_ENFORCE(graph);
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
......@@ -46,11 +46,9 @@ class InferCleanGraphPass : public FusePassBase {
}
}
GraphSafeRemoveNodes(graph.get(), invalid_nodes);
GraphSafeRemoveNodes(graph, invalid_nodes);
AddStatis(valid_op);
return graph;
}
void CleanEdges(std::vector<Node*>* nodes,
......
......@@ -20,8 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void IsTestPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
"for activations and pooling.";
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
......@@ -47,7 +46,6 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
}
}
}
return graph;
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class IsTestPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -97,7 +97,7 @@ TEST(IsTestPass, basic) {
auto pass = PassRegistry::Instance().Get("is_test_pass");
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
......
......@@ -32,9 +32,8 @@ const char kSumGradOpName[] = "sum";
// other optimizers later.
const char kOptimizerType[] = "sgd";
std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
// We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0)
......@@ -92,14 +91,14 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
// find the forward op related to the backward op
ir::Node* forward_op =
FindForwardOpViaBackwardOp(graph.get(), backward_op);
FindForwardOpViaBackwardOp(graph, backward_op);
VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op);
Node* new_optimizer_node = CreateNewSGDNode(
graph.get(), forward_op, backward_op, node, opt_node);
graph, forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node);
}
......@@ -140,8 +139,6 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
}
}
}
return graph;
}
ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
......
......@@ -60,8 +60,7 @@ class LockFreeOptimizePass : public Pass {
virtual ~LockFreeOptimizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
private:
// Create a new sgd node via current optimizer node
......
......@@ -38,10 +38,9 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -99,7 +98,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
GraphSafeRemoveNodes(graph, {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
......@@ -123,14 +122,13 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
GraphSafeRemoveNodes(graph, {conv, eltwise, conv_out});
}
found_conv_bias_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bias_count);
return graph;
}
} // namespace ir
} // namespace framework
......
......@@ -29,8 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
virtual bool is_conv3d() const { return false; }
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
/*
......
......@@ -13,10 +13,10 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
......@@ -103,7 +103,7 @@ void MainTest(bool convWithExistingBias) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -16,8 +16,8 @@
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
......@@ -327,17 +327,15 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
get_node_from_elementwise_add);
}
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph);
auto fused_graph_with_stats = FuseConvAsY(
name_scope_,
FuseConvAsX(
name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
FuseConvAsX(name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph, 0))));
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
AddStatis(fused_graph_with_stats.second);
return graph;
}
} // namespace ir
} // namespace framework
......
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <utility>
......@@ -27,7 +28,7 @@ namespace paddle {
namespace framework {
namespace ir {
using graph_ptr = std::unique_ptr<ir::Graph>;
using graph_ptr = ir::Graph*;
using GraphWithStats = std::pair<ir::Graph*, int>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
......@@ -124,7 +125,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
virtual ~ResidualConnectionMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
void ApplyImpl(graph_ptr graph) const;
const std::string name_scope_{"residual_connection_fuse_pass"};
};
......
......@@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)(from, to));
......@@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "g"));
......
......@@ -21,10 +21,9 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("conv_relu_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
......@@ -56,7 +55,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
desc->SetAttr("fuse_relu", true);
GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
GraphSafeRemoveNodes(graph, {relu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(conv, relu_out);
......@@ -64,10 +63,9 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
found_conv_relu_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_relu_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvReLUFusePass : public FusePassBase {
virtual ~ConvReLUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -88,7 +88,7 @@ TEST(ConvReLUFusePass, basic) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -216,19 +216,16 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
}
std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE(param_scope());
QuantizeConv(graph.get(), false /* with_residual_data */);
QuantizeConv(graph.get(), true /* with_residual_data */);
QuantizePool(graph.get());
return graph;
QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */);
QuantizePool(graph);
}
} // namespace ir
......
......@@ -42,8 +42,7 @@ class CPUQuantizePass : public FusePassBase {
virtual ~CPUQuantizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
......
......@@ -139,7 +139,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -20,8 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Marks operators which are to be quantized.";
const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
......@@ -43,7 +42,6 @@ std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
}
}
}
return graph;
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ namespace ir {
*/
class CPUQuantizePlacementPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -94,7 +94,7 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
pass->Set("quantize_excluded_op_ids",
new std::unordered_set<int>(quantize_excluded_op_ids));
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
unsigned use_quantizer_true_count = 0;
......
......@@ -126,16 +126,13 @@ void CPUQuantizeSquashPass::Squash(
found_squash_count);
}
std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("cpu_quantize_squash_pass", graph);
std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph.get(), &nodes_keep_counter);
Squash(graph.get(), &nodes_keep_counter);
return graph;
FindNodesToKeep(graph, &nodes_keep_counter);
Squash(graph, &nodes_keep_counter);
}
} // namespace ir
......
......@@ -34,8 +34,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
virtual ~CPUQuantizeSquashPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
/*
* For each dequantize's output find the number of operators it is an input to
......
......@@ -125,7 +125,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册