未验证 提交 a2be4b4d 编写于 作者: C chengduo 提交者: GitHub

Add fuse momenutum ops (#16745)

* Add fuse momenutum ops
上级 03d469ad
......@@ -14,6 +14,7 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
......@@ -126,4 +127,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass)
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
record_skip_memory_opt_vars_pass)
......@@ -57,7 +57,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("record_skip_memory_opt_vars_pass");
if (strategy_.enable_sequential_execution_) {
VLOG(10) << "Add sequential_execution_pass";
VLOG(5) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass");
}
......@@ -68,7 +68,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
VLOG(10) << "Add fuse_relu_depthwise_conv_pass";
VLOG(5) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass");
}
......@@ -80,19 +80,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add automatically inplace.
if (strategy_.enable_inplace_) {
VLOG(10) << "Add inplace_pass";
VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass");
}
if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass";
VLOG(5) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
VLOG(5) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}
......@@ -107,10 +107,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(10) << "Add fuse_adam_op_pass";
VLOG(5) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(10) << "Add fuse_sgd_op_pass";
VLOG(5) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
VLOG(5) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
}
}
......@@ -139,7 +141,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy_.memory_optimize_) {
VLOG(10) << "Add memory_optimize_pass";
VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
......@@ -147,7 +149,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass";
VLOG(5) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass");
}
......@@ -161,7 +163,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass";
VLOG(5) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass");
}
......@@ -182,12 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(10) << "Add all_reduce_deps_pass";
VLOG(5) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
if (strategy_.remove_unnecessary_lock_) {
VLOG(10) << "Add modify_op_lock_and_record_event_pass";
VLOG(5) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass");
}
......@@ -202,16 +204,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(10)
VLOG(5)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(10) << "Add all_reduce_mode_multi_devices_pass";
VLOG(5) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(10) << "Add reduce_mode_multi_devices_pass";
VLOG(5) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
......@@ -277,6 +279,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......@@ -341,6 +344,7 @@ USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass);
......
......@@ -11,9 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_adam_op_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -21,13 +27,15 @@ namespace paddle {
namespace framework {
namespace details {
const std::string FuseAdamOpPass::GetOpType() const { return "adam"; }
class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
const std::string GetOpType() const { return "adam"; }
const std::vector<std::string> FuseAdamOpPass::GetAuxiliaryVarNames() const {
const std::vector<std::string> GetAuxiliaryVarNames() const {
return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
}
}
void FuseAdamOpPass::FuseOptimizerOps(
void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
......@@ -37,9 +45,9 @@ void FuseAdamOpPass::FuseOptimizerOps(
adam_ops, graph);
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph);
}
}
void FuseAdamOpPass::FuseAdamOps(
void FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
......@@ -67,14 +75,15 @@ void FuseAdamOpPass::FuseAdamOps(
PADDLE_ENFORCE_EQ(min_row_size_to_use_multithread,
boost::get<int64_t>(adam_op->Op()->GetAttr(
"min_row_size_to_use_multithread")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(adam_op->Op()->GetAttr(
PADDLE_ENFORCE_EQ(op_role,
boost::get<int>(adam_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
// NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node.
VLOG(10) << "Insert adam to graph ";
VLOG(7) << "Insert adam to graph ";
OpDesc adam_desc(adam_ops[0]->Op()->Block());
adam_desc.SetType("adam");
adam_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
......@@ -100,9 +109,9 @@ void FuseAdamOpPass::FuseAdamOps(
auto adam_node = graph->CreateOpNode(&adam_desc);
InserInputAndOutputForOptOps(adam_ops, adam_node);
}
}
void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
void FuseScaleOps(const std::vector<std::string> &beta_name,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const {
......@@ -117,7 +126,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
auto beta_pow_iter = std::find_if(
adam_ops[i]->inputs.begin(), adam_ops[i]->inputs.end(),
[&beta_name, &beta_1_pow_name](ir::Node *var_node) -> bool {
return var_node->Var() && var_node->Var()->Name() == beta_1_pow_name;
return var_node->Var() &&
var_node->Var()->Name() == beta_1_pow_name;
});
PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end());
......@@ -144,18 +154,20 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for (auto &scale_op : scale_ops) {
PADDLE_ENFORCE_EQ(scale,
boost::get<float>(scale_op->Op()->GetAttr("scale")));
PADDLE_ENFORCE_EQ(bias, boost::get<float>(scale_op->Op()->GetAttr("bias")));
PADDLE_ENFORCE_EQ(bias,
boost::get<float>(scale_op->Op()->GetAttr("bias")));
PADDLE_ENFORCE_EQ(
bias_after_scale,
boost::get<bool>(scale_op->Op()->GetAttr("bias_after_scale")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(scale_op->Op()->GetAttr(
PADDLE_ENFORCE_EQ(op_role,
boost::get<int>(scale_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
// NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node.
VLOG(10) << "Insert fused scale to graph.";
VLOG(7) << "Insert fused scale to graph.";
OpDesc scale_desc(scale_ops[0]->Op()->Block());
scale_desc.SetType("scale");
scale_desc.SetInput("X", {fused_var_name});
......@@ -169,7 +181,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for (auto scale_op : scale_ops) {
// set inputs
scale_node->inputs.insert(scale_node->inputs.begin(),
scale_op->inputs.begin(), scale_op->inputs.end());
scale_op->inputs.begin(),
scale_op->inputs.end());
for (auto &input : scale_op->inputs) {
std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
scale_node);
......@@ -188,8 +201,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for (auto &scale_op : scale_ops) {
graph->RemoveNode(scale_op);
}
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
......
......@@ -12,44 +12,83 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
class FuseAdamOpPass : public FuseOptimizerOpPass {
class FuseMomentumOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::string GetOpType() const { return "momentum"; }
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const {
return {"Velocity"};
}
// Fuse Adam Ops and Scale Ops which are used to update "Beta1Pow", "Beta2Pow"
// Fuse Momentum Ops
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;
const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(momentum_ops.size(), static_cast<size_t>(0));
void FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;
// Check attributions
// NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>(momentum_ops[0]->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()));
float mu = boost::get<float>(momentum_ops[0]->Op()->GetAttr("mu"));
bool use_nesterov =
boost::get<bool>(momentum_ops[0]->Op()->GetAttr("use_nesterov"));
for (auto &momentum_op : momentum_ops) {
PADDLE_ENFORCE_EQ(mu,
boost::get<float>(momentum_op->Op()->GetAttr("mu")));
PADDLE_ENFORCE_EQ(
use_nesterov,
boost::get<bool>(momentum_op->Op()->GetAttr("use_nesterov")));
PADDLE_ENFORCE_EQ(op_role,
boost::get<int>(momentum_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node.
void FuseScaleOps(const std::vector<std::string> &aux_var_set,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const;
VLOG(7) << "Insert momentum to graph ";
OpDesc momentum_desc(momentum_ops[0]->Op()->Block());
momentum_desc.SetType("momentum");
momentum_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
momentum_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)});
momentum_desc.SetInput("Velocity", {fused_vars_name.at("Velocity")});
// TODO(zcd): The LearningRate should be equal.
momentum_desc.SetInput(kLearningRate,
momentum_ops[0]->Op()->Input(kLearningRate));
momentum_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
momentum_desc.SetOutput("VelocityOut", {fused_vars_name.at("Velocity")});
momentum_desc.SetAttr("mu", mu);
momentum_desc.SetAttr("use_nesterov", use_nesterov);
momentum_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto momentum_node = graph->CreateOpNode(&momentum_desc);
InserInputAndOutputForOptOps(momentum_ops, momentum_node);
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass,
paddle::framework::details::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -42,14 +42,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
&aux_var_set);
}
VLOG(10) << "Find " << fuse_op_type << " operators: " << opt_ops.size();
VLOG(6) << "Find " << fuse_op_type << " operators: " << opt_ops.size();
if (opt_ops.size() == 0) {
return;
}
if (result.Has(kFusedOptType)) {
VLOG(10)
<< "Currently only support fusing one type optimizer op. Has fused "
VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType);
return;
} else {
......@@ -70,7 +69,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
aux_var_set[var_name][0];
VLOG(10) << fused_var_name;
VLOG(6) << var_name << ": " << fused_var_name;
fused_vars_name.emplace(var_name, fused_var_name);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
fused_var_set.insert(fused_var_name);
......@@ -151,7 +150,7 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
// Init Grads
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
VLOG(10) << "Init " << fused_grad_name;
VLOG(6) << "Init: " << fused_grad_name;
PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
"%s has existed in scope.", fused_grad_name);
scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
......@@ -211,13 +210,12 @@ void FuseOptimizerOpPass::RunInitOps(const std::vector<platform::Place> &places,
void FuseOptimizerOpPass::InitVars(const std::vector<Scope *> &local_scopes,
const std::string &fused_var_name) const {
VLOG(10) << "Init FusedVars.";
// Alloc parameters and auxiliary vars in the respective scope.
size_t idx = local_scopes.size();
for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend();
++iter, --idx) {
auto &scope = *iter;
VLOG(10) << "Init " << fused_var_name;
VLOG(6) << "Init: " << fused_var_name;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has exist in scope[%d]", fused_var_name, idx);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
......@@ -253,7 +251,7 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
for (auto &var_name : aux_vars.second) {
out << var_name << " ";
}
VLOG(10) << aux_vars.first << ": " << out.str();
VLOG(6) << aux_vars.first << ": " << out.str();
}
std::vector<ir::Node *> sorted_ops;
......@@ -271,12 +269,14 @@ void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
const {
if (node->Op()->Type() != op_type) return;
std::stringstream out;
for (auto &var_n : aux_vars_name) {
auto arg_names = node->Op()->Input(var_n);
PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1));
(*aux_args_name)[var_n].emplace_back(arg_names[0]);
VLOG(10) << var_n << ", " << arg_names[0];
out << var_n << ", " << arg_names[0] << "; ";
}
VLOG(7) << out.str();
ops->emplace_back(node);
}
......
......@@ -11,42 +11,43 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_sgd_op_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
const std::string FuseSgdOpPass::GetOpType() const { return "sgd"; }
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const { return "sgd"; }
const std::vector<std::string> FuseSgdOpPass::GetAuxiliaryVarNames() const {
virtual const std::vector<std::string> GetAuxiliaryVarNames() const {
return {};
}
}
void FuseSgdOpPass::FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
FuseSgdOps(aux_var_set, fused_vars_name, sgd_ops, graph);
}
void FuseSgdOpPass::FuseSgdOps(
// Fuse Sgd Ops
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0));
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
// NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node.
int op_role = boost::get<int>(
sgd_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
VLOG(10) << "Insert sgd to graph ";
VLOG(7) << "Insert sgd to graph ";
// Add fused scale
OpDesc Sgd_desc(sgd_ops[0]->Op()->Block());
Sgd_desc.SetType("sgd");
......@@ -54,7 +55,7 @@ void FuseSgdOpPass::FuseSgdOps(
Sgd_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)});
Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
// TODO(zcd): The LearningRate should be equal.
Sgd_desc.SetInput(kLearningRate, sgd_ops[0]->Op()->Input(kLearningRate));
// NOTE: multi_devices_pass requires that every op should have a role.
......@@ -63,8 +64,8 @@ void FuseSgdOpPass::FuseSgdOps(
auto sgd_node = graph->CreateOpNode(&Sgd_desc);
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
// Fuse Sgd Ops
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
void FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -31,18 +31,17 @@ class TestFuseAdamOps(TestParallelExecutorBase):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = init_data()
feed_dict = {"image": img, "label": label}
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
feed_dict=feed_dict,
use_cuda=use_cuda,
fuse_all_optimizer_ops=False,
memory_opt=False, # avoid the gradient's name changed in Python side.
optimizer=optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
feed_dict=feed_dict,
use_cuda=use_cuda,
fuse_all_optimizer_ops=True,
memory_opt=False, # avoid the gradient's name changed in Python side.
......@@ -63,7 +62,7 @@ class TestFuseAdamOps(TestParallelExecutorBase):
class TestFuseSGDOps(TestFuseAdamOps):
def sgd_optimizer(self, learning_rate=1e-4):
def sgd_optimizer(self, learning_rate=1e-3):
return fluid.optimizer.SGD(learning_rate=learning_rate)
def test_simple_fc_with_fuse_op(self):
......@@ -79,5 +78,23 @@ class TestFuseSGDOps(TestFuseAdamOps):
fc_with_batchnorm, False, optimizer=self.sgd_optimizer)
class TestFuseMomentumOps(TestFuseAdamOps):
def momentum_optimizer(self, learning_rate=1e-3):
return fluid.optimizer.Momentum(
learning_rate=learning_rate, momentum=0.1)
def test_simple_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(
simple_fc_net, True, optimizer=self.momentum_optimizer)
self._compare_fused_optimizer_ops(
simple_fc_net, False, optimizer=self.momentum_optimizer)
def test_batchnorm_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(
fc_with_batchnorm, True, optimizer=self.momentum_optimizer)
self._compare_fused_optimizer_ops(
fc_with_batchnorm, False, optimizer=self.momentum_optimizer)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册