未验证 提交 a2be4b4d 编写于 作者: C chengduo 提交者: GitHub

Add fuse momenutum ops (#16745)

* Add fuse momenutum ops
上级 03d469ad
...@@ -14,6 +14,7 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc ...@@ -14,6 +14,7 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper) cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper) cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper) cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper) cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
...@@ -126,4 +127,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS ...@@ -126,4 +127,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
fuse_relu_depthwise_conv_pass fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass record_skip_memory_opt_vars_pass) fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
record_skip_memory_opt_vars_pass)
...@@ -57,7 +57,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -57,7 +57,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("record_skip_memory_opt_vars_pass"); AppendPass("record_skip_memory_opt_vars_pass");
if (strategy_.enable_sequential_execution_) { if (strategy_.enable_sequential_execution_) {
VLOG(10) << "Add sequential_execution_pass"; VLOG(5) << "Add sequential_execution_pass";
AppendPass("sequential_execution_pass"); AppendPass("sequential_execution_pass");
} }
...@@ -68,7 +68,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -68,7 +68,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add op fusion. // Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) { if (strategy.fuse_relu_depthwise_conv_) {
VLOG(10) << "Add fuse_relu_depthwise_conv_pass"; VLOG(5) << "Add fuse_relu_depthwise_conv_pass";
AppendPass("fuse_relu_depthwise_conv_pass"); AppendPass("fuse_relu_depthwise_conv_pass");
} }
...@@ -80,19 +80,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -80,19 +80,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Add automatically inplace. // Add automatically inplace.
if (strategy_.enable_inplace_) { if (strategy_.enable_inplace_) {
VLOG(10) << "Add inplace_pass"; VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass"); AppendPass("inplace_pass");
} }
if (strategy_.fuse_elewise_add_act_ops_) { if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass"; VLOG(5) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass"); AppendPass("fuse_elewise_add_act_pass");
} }
// for single card training, fuse_all_reduce_ops is unnecessary. // for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass. // alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass"; VLOG(5) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass"); AppendPass("alloc_continuous_space_for_grad_pass");
} }
...@@ -107,10 +107,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -107,10 +107,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// NOTE: fuse_all_xx_ops will count the number of xx operator first, // NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing. // if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused. // Currently, only one type of optimization algorithm can be fused.
VLOG(10) << "Add fuse_adam_op_pass"; VLOG(5) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass"); AppendPass("fuse_adam_op_pass");
VLOG(10) << "Add fuse_sgd_op_pass"; VLOG(5) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass"); AppendPass("fuse_sgd_op_pass");
VLOG(5) << "Add fuse_momentum_op_pass";
AppendPass("fuse_momentum_op_pass");
} }
} }
...@@ -139,7 +141,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -139,7 +141,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// A side-effect of that, memory optimize cannot forsee the fetched vars // A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface. // , so fetchlist should be set persistable before call the Run interface.
if (strategy_.memory_optimize_) { if (strategy_.memory_optimize_) {
VLOG(10) << "Add memory_optimize_pass"; VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass"); AppendPass("memory_optimize_pass");
} }
...@@ -147,7 +149,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -147,7 +149,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// all original and fused operators. But no operators can be enabled this // all original and fused operators. But no operators can be enabled this
// attr if putting it after MultiDevPass. // attr if putting it after MultiDevPass.
if (strategy_.cache_runtime_context_) { if (strategy_.cache_runtime_context_) {
VLOG(10) << "Add runtime_context_cache_pass"; VLOG(5) << "Add runtime_context_cache_pass";
AppendPass("runtime_context_cache_pass"); AppendPass("runtime_context_cache_pass");
} }
...@@ -161,7 +163,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -161,7 +163,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy_.fuse_all_reduce_ops_) { if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing. // first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass"; VLOG(5) << "Add fuse_all_reduce_op_pass";
AppendPass("fuse_all_reduce_op_pass"); AppendPass("fuse_all_reduce_op_pass");
} }
...@@ -182,12 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -182,12 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (!strategy_.enable_parallel_graph_ && if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) || (SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) { strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(10) << "Add all_reduce_deps_pass"; VLOG(5) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass"); AppendPass("all_reduce_deps_pass");
} }
if (strategy_.remove_unnecessary_lock_) { if (strategy_.remove_unnecessary_lock_) {
VLOG(10) << "Add modify_op_lock_and_record_event_pass"; VLOG(5) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass"); AppendPass("modify_op_lock_and_record_event_pass");
} }
...@@ -202,16 +204,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -202,16 +204,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
if (strategy_.async_mode_) { if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get(); multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) { } else if (strategy_.is_distribution_) {
VLOG(10) VLOG(5)
<< "Add dist_multi_devices_pass, multi device parameter server mode"; << "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(10) << "Add all_reduce_mode_multi_devices_pass"; VLOG(5) << "Add all_reduce_mode_multi_devices_pass";
multi_devices_pass = multi_devices_pass =
AppendPass("all_reduce_mode_multi_devices_pass").get(); AppendPass("all_reduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(10) << "Add reduce_mode_multi_devices_pass"; VLOG(5) << "Add reduce_mode_multi_devices_pass";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else { } else {
PADDLE_THROW("Unknown reduce strategy."); PADDLE_THROW("Unknown reduce strategy.");
...@@ -277,6 +279,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -277,6 +279,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" || } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" || pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" || pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_momentum_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") { pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
...@@ -341,6 +344,7 @@ USE_PASS(alloc_continuous_space_for_grad_pass); ...@@ -341,6 +344,7 @@ USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass); USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass); USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass); USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(expected_kernel_cache_pass); USE_PASS(expected_kernel_cache_pass);
......
...@@ -11,9 +11,15 @@ ...@@ -11,9 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fuse_adam_op_pass.h"
#include <algorithm> #include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -21,13 +27,15 @@ namespace paddle { ...@@ -21,13 +27,15 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
const std::string FuseAdamOpPass::GetOpType() const { return "adam"; } class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
const std::string GetOpType() const { return "adam"; }
const std::vector<std::string> FuseAdamOpPass::GetAuxiliaryVarNames() const { const std::vector<std::string> GetAuxiliaryVarNames() const {
return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"}; return {"Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
} }
void FuseAdamOpPass::FuseOptimizerOps( void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set, &aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
...@@ -37,9 +45,9 @@ void FuseAdamOpPass::FuseOptimizerOps( ...@@ -37,9 +45,9 @@ void FuseAdamOpPass::FuseOptimizerOps(
adam_ops, graph); adam_ops, graph);
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"), FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph); adam_ops, graph);
} }
void FuseAdamOpPass::FuseAdamOps( void FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
...@@ -67,14 +75,15 @@ void FuseAdamOpPass::FuseAdamOps( ...@@ -67,14 +75,15 @@ void FuseAdamOpPass::FuseAdamOps(
PADDLE_ENFORCE_EQ(min_row_size_to_use_multithread, PADDLE_ENFORCE_EQ(min_row_size_to_use_multithread,
boost::get<int64_t>(adam_op->Op()->GetAttr( boost::get<int64_t>(adam_op->Op()->GetAttr(
"min_row_size_to_use_multithread"))); "min_row_size_to_use_multithread")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(adam_op->Op()->GetAttr( PADDLE_ENFORCE_EQ(op_role,
boost::get<int>(adam_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))); OpProtoAndCheckerMaker::OpRoleAttrName())));
} }
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var // NOTE: fused_var is only exist in scope, so the graph doesn't have
// node. // fused_var node.
VLOG(10) << "Insert adam to graph "; VLOG(7) << "Insert adam to graph ";
OpDesc adam_desc(adam_ops[0]->Op()->Block()); OpDesc adam_desc(adam_ops[0]->Op()->Block());
adam_desc.SetType("adam"); adam_desc.SetType("adam");
adam_desc.SetInput(kParam, {fused_vars_name.at(kParam)}); adam_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
...@@ -100,9 +109,9 @@ void FuseAdamOpPass::FuseAdamOps( ...@@ -100,9 +109,9 @@ void FuseAdamOpPass::FuseAdamOps(
auto adam_node = graph->CreateOpNode(&adam_desc); auto adam_node = graph->CreateOpNode(&adam_desc);
InserInputAndOutputForOptOps(adam_ops, adam_node); InserInputAndOutputForOptOps(adam_ops, adam_node);
} }
void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name, void FuseScaleOps(const std::vector<std::string> &beta_name,
const std::string &fused_var_name, const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops, const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const { ir::Graph *graph) const {
...@@ -117,7 +126,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name, ...@@ -117,7 +126,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
auto beta_pow_iter = std::find_if( auto beta_pow_iter = std::find_if(
adam_ops[i]->inputs.begin(), adam_ops[i]->inputs.end(), adam_ops[i]->inputs.begin(), adam_ops[i]->inputs.end(),
[&beta_name, &beta_1_pow_name](ir::Node *var_node) -> bool { [&beta_name, &beta_1_pow_name](ir::Node *var_node) -> bool {
return var_node->Var() && var_node->Var()->Name() == beta_1_pow_name; return var_node->Var() &&
var_node->Var()->Name() == beta_1_pow_name;
}); });
PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end()); PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end());
...@@ -144,18 +154,20 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name, ...@@ -144,18 +154,20 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for (auto &scale_op : scale_ops) { for (auto &scale_op : scale_ops) {
PADDLE_ENFORCE_EQ(scale, PADDLE_ENFORCE_EQ(scale,
boost::get<float>(scale_op->Op()->GetAttr("scale"))); boost::get<float>(scale_op->Op()->GetAttr("scale")));
PADDLE_ENFORCE_EQ(bias, boost::get<float>(scale_op->Op()->GetAttr("bias"))); PADDLE_ENFORCE_EQ(bias,
boost::get<float>(scale_op->Op()->GetAttr("bias")));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
bias_after_scale, bias_after_scale,
boost::get<bool>(scale_op->Op()->GetAttr("bias_after_scale"))); boost::get<bool>(scale_op->Op()->GetAttr("bias_after_scale")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(scale_op->Op()->GetAttr( PADDLE_ENFORCE_EQ(op_role,
boost::get<int>(scale_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName()))); OpProtoAndCheckerMaker::OpRoleAttrName())));
} }
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var // NOTE: fused_var is only exist in scope, so the graph doesn't have
// node. // fused_var node.
VLOG(10) << "Insert fused scale to graph."; VLOG(7) << "Insert fused scale to graph.";
OpDesc scale_desc(scale_ops[0]->Op()->Block()); OpDesc scale_desc(scale_ops[0]->Op()->Block());
scale_desc.SetType("scale"); scale_desc.SetType("scale");
scale_desc.SetInput("X", {fused_var_name}); scale_desc.SetInput("X", {fused_var_name});
...@@ -169,7 +181,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name, ...@@ -169,7 +181,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for (auto scale_op : scale_ops) { for (auto scale_op : scale_ops) {
// set inputs // set inputs
scale_node->inputs.insert(scale_node->inputs.begin(), scale_node->inputs.insert(scale_node->inputs.begin(),
scale_op->inputs.begin(), scale_op->inputs.end()); scale_op->inputs.begin(),
scale_op->inputs.end());
for (auto &input : scale_op->inputs) { for (auto &input : scale_op->inputs) {
std::replace(input->outputs.begin(), input->outputs.end(), scale_op, std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
scale_node); scale_node);
...@@ -188,8 +201,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name, ...@@ -188,8 +201,8 @@ void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
for (auto &scale_op : scale_ops) { for (auto &scale_op : scale_ops) {
graph->RemoveNode(scale_op); graph->RemoveNode(scale_op);
} }
} }
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -12,44 +12,83 @@ ...@@ -12,44 +12,83 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #include <algorithm>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h" #include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
class FuseAdamOpPass : public FuseOptimizerOpPass { class FuseMomentumOpPass : public FuseOptimizerOpPass {
private: private:
virtual const std::string GetOpType() const; virtual const std::string GetOpType() const { return "momentum"; }
virtual const std::vector<std::string> GetAuxiliaryVarNames() const; virtual const std::vector<std::string> GetAuxiliaryVarNames() const {
return {"Velocity"};
}
// Fuse Adam Ops and Scale Ops which are used to update "Beta1Pow", "Beta2Pow" // Fuse Momentum Ops
virtual void FuseOptimizerOps( virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const; const std::vector<ir::Node *> &momentum_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(momentum_ops.size(), static_cast<size_t>(0));
void FuseAdamOps( // Check attributions
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, // NOTE: If new attribution is added, the following code maybe need change.
const std::unordered_map<std::string, std::string> &fused_vars_name, int op_role = boost::get<int>(momentum_ops[0]->Op()->GetAttr(
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const; OpProtoAndCheckerMaker::OpRoleAttrName()));
float mu = boost::get<float>(momentum_ops[0]->Op()->GetAttr("mu"));
bool use_nesterov =
boost::get<bool>(momentum_ops[0]->Op()->GetAttr("use_nesterov"));
for (auto &momentum_op : momentum_ops) {
PADDLE_ENFORCE_EQ(mu,
boost::get<float>(momentum_op->Op()->GetAttr("mu")));
PADDLE_ENFORCE_EQ(
use_nesterov,
boost::get<bool>(momentum_op->Op()->GetAttr("use_nesterov")));
PADDLE_ENFORCE_EQ(op_role,
boost::get<int>(momentum_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have
// fused_var node.
void FuseScaleOps(const std::vector<std::string> &aux_var_set, VLOG(7) << "Insert momentum to graph ";
const std::string &fused_var_name, OpDesc momentum_desc(momentum_ops[0]->Op()->Block());
const std::vector<ir::Node *> &adam_ops, momentum_desc.SetType("momentum");
ir::Graph *graph) const; momentum_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
momentum_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)});
momentum_desc.SetInput("Velocity", {fused_vars_name.at("Velocity")});
// TODO(zcd): The LearningRate should be equal.
momentum_desc.SetInput(kLearningRate,
momentum_ops[0]->Op()->Input(kLearningRate));
momentum_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
momentum_desc.SetOutput("VelocityOut", {fused_vars_name.at("Velocity")});
momentum_desc.SetAttr("mu", mu);
momentum_desc.SetAttr("use_nesterov", use_nesterov);
momentum_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto momentum_node = graph->CreateOpNode(&momentum_desc);
InserInputAndOutputForOptOps(momentum_ops, momentum_node);
}
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass,
paddle::framework::details::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -42,14 +42,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -42,14 +42,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
&aux_var_set); &aux_var_set);
} }
VLOG(10) << "Find " << fuse_op_type << " operators: " << opt_ops.size(); VLOG(6) << "Find " << fuse_op_type << " operators: " << opt_ops.size();
if (opt_ops.size() == 0) { if (opt_ops.size() == 0) {
return; return;
} }
if (result.Has(kFusedOptType)) { if (result.Has(kFusedOptType)) {
VLOG(10) VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
<< "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType); << result.Get<FusedOptType>(kFusedOptType);
return; return;
} else { } else {
...@@ -70,7 +69,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -70,7 +69,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
for (auto &var_name : aux_var_names) { for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" + auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
aux_var_set[var_name][0]; aux_var_set[var_name][0];
VLOG(10) << fused_var_name; VLOG(6) << var_name << ": " << fused_var_name;
fused_vars_name.emplace(var_name, fused_var_name); fused_vars_name.emplace(var_name, fused_var_name);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0); PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
fused_var_set.insert(fused_var_name); fused_var_set.insert(fused_var_name);
...@@ -151,7 +150,7 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads( ...@@ -151,7 +150,7 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
// Init Grads // Init Grads
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) { for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it; auto &scope = *it;
VLOG(10) << "Init " << fused_grad_name; VLOG(6) << "Init: " << fused_grad_name;
PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr, PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
"%s has existed in scope.", fused_grad_name); "%s has existed in scope.", fused_grad_name);
scope->Var(fused_grad_name)->GetMutable<LoDTensor>(); scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
...@@ -211,13 +210,12 @@ void FuseOptimizerOpPass::RunInitOps(const std::vector<platform::Place> &places, ...@@ -211,13 +210,12 @@ void FuseOptimizerOpPass::RunInitOps(const std::vector<platform::Place> &places,
void FuseOptimizerOpPass::InitVars(const std::vector<Scope *> &local_scopes, void FuseOptimizerOpPass::InitVars(const std::vector<Scope *> &local_scopes,
const std::string &fused_var_name) const { const std::string &fused_var_name) const {
VLOG(10) << "Init FusedVars.";
// Alloc parameters and auxiliary vars in the respective scope. // Alloc parameters and auxiliary vars in the respective scope.
size_t idx = local_scopes.size(); size_t idx = local_scopes.size();
for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend(); for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend();
++iter, --idx) { ++iter, --idx) {
auto &scope = *iter; auto &scope = *iter;
VLOG(10) << "Init " << fused_var_name; VLOG(6) << "Init: " << fused_var_name;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr, PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has exist in scope[%d]", fused_var_name, idx); "%s has exist in scope[%d]", fused_var_name, idx);
scope->Var(fused_var_name)->GetMutable<LoDTensor>(); scope->Var(fused_var_name)->GetMutable<LoDTensor>();
...@@ -253,7 +251,7 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars( ...@@ -253,7 +251,7 @@ void FuseOptimizerOpPass::SortParametersAndAuxVars(
for (auto &var_name : aux_vars.second) { for (auto &var_name : aux_vars.second) {
out << var_name << " "; out << var_name << " ";
} }
VLOG(10) << aux_vars.first << ": " << out.str(); VLOG(6) << aux_vars.first << ": " << out.str();
} }
std::vector<ir::Node *> sorted_ops; std::vector<ir::Node *> sorted_ops;
...@@ -271,12 +269,14 @@ void FuseOptimizerOpPass::GetSpecifiedOpsAndVars( ...@@ -271,12 +269,14 @@ void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
const { const {
if (node->Op()->Type() != op_type) return; if (node->Op()->Type() != op_type) return;
std::stringstream out;
for (auto &var_n : aux_vars_name) { for (auto &var_n : aux_vars_name) {
auto arg_names = node->Op()->Input(var_n); auto arg_names = node->Op()->Input(var_n);
PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1)); PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1));
(*aux_args_name)[var_n].emplace_back(arg_names[0]); (*aux_args_name)[var_n].emplace_back(arg_names[0]);
VLOG(10) << var_n << ", " << arg_names[0]; out << var_n << ", " << arg_names[0] << "; ";
} }
VLOG(7) << out.str();
ops->emplace_back(node); ops->emplace_back(node);
} }
......
...@@ -11,42 +11,43 @@ ...@@ -11,42 +11,43 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fuse_sgd_op_pass.h"
#include <algorithm> #include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
const std::string FuseSgdOpPass::GetOpType() const { return "sgd"; } class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const { return "sgd"; }
const std::vector<std::string> FuseSgdOpPass::GetAuxiliaryVarNames() const { virtual const std::vector<std::string> GetAuxiliaryVarNames() const {
return {}; return {};
} }
void FuseSgdOpPass::FuseOptimizerOps( // Fuse Sgd Ops
const std::unordered_map<std::string, std::vector<std::string>> virtual void FuseOptimizerOps(
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
FuseSgdOps(aux_var_set, fused_vars_name, sgd_ops, graph);
}
void FuseSgdOpPass::FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set, const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name, const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const { const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0)); PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0));
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var // NOTE: fused_var is only exist in scope, so the graph doesn't have
// node. // fused_var node.
int op_role = boost::get<int>( int op_role = boost::get<int>(
sgd_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())); sgd_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
VLOG(10) << "Insert sgd to graph "; VLOG(7) << "Insert sgd to graph ";
// Add fused scale // Add fused scale
OpDesc Sgd_desc(sgd_ops[0]->Op()->Block()); OpDesc Sgd_desc(sgd_ops[0]->Op()->Block());
Sgd_desc.SetType("sgd"); Sgd_desc.SetType("sgd");
...@@ -54,7 +55,7 @@ void FuseSgdOpPass::FuseSgdOps( ...@@ -54,7 +55,7 @@ void FuseSgdOpPass::FuseSgdOps(
Sgd_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)}); Sgd_desc.SetInput(kGrad, {fused_vars_name.at(kGrad)});
Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)}); Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at(kParam)});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal. // TODO(zcd): The LearningRate should be equal.
Sgd_desc.SetInput(kLearningRate, sgd_ops[0]->Op()->Input(kLearningRate)); Sgd_desc.SetInput(kLearningRate, sgd_ops[0]->Op()->Input(kLearningRate));
// NOTE: multi_devices_pass requires that every op should have a role. // NOTE: multi_devices_pass requires that every op should have a role.
...@@ -63,8 +64,8 @@ void FuseSgdOpPass::FuseSgdOps( ...@@ -63,8 +64,8 @@ void FuseSgdOpPass::FuseSgdOps(
auto sgd_node = graph->CreateOpNode(&Sgd_desc); auto sgd_node = graph->CreateOpNode(&Sgd_desc);
InserInputAndOutputForOptOps(sgd_ops, sgd_node); InserInputAndOutputForOptOps(sgd_ops, sgd_node);
} }
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
// Fuse Sgd Ops
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
void FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -31,18 +31,17 @@ class TestFuseAdamOps(TestParallelExecutorBase): ...@@ -31,18 +31,17 @@ class TestFuseAdamOps(TestParallelExecutorBase):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
img, label = init_data() img, label = init_data()
feed_dict = {"image": img, "label": label}
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence( not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model, model,
feed_dict={"image": img, feed_dict=feed_dict,
"label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
fuse_all_optimizer_ops=False, fuse_all_optimizer_ops=False,
memory_opt=False, # avoid the gradient's name changed in Python side. memory_opt=False, # avoid the gradient's name changed in Python side.
optimizer=optimizer) optimizer=optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence( fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model, model,
feed_dict={"image": img, feed_dict=feed_dict,
"label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
fuse_all_optimizer_ops=True, fuse_all_optimizer_ops=True,
memory_opt=False, # avoid the gradient's name changed in Python side. memory_opt=False, # avoid the gradient's name changed in Python side.
...@@ -63,7 +62,7 @@ class TestFuseAdamOps(TestParallelExecutorBase): ...@@ -63,7 +62,7 @@ class TestFuseAdamOps(TestParallelExecutorBase):
class TestFuseSGDOps(TestFuseAdamOps): class TestFuseSGDOps(TestFuseAdamOps):
def sgd_optimizer(self, learning_rate=1e-4): def sgd_optimizer(self, learning_rate=1e-3):
return fluid.optimizer.SGD(learning_rate=learning_rate) return fluid.optimizer.SGD(learning_rate=learning_rate)
def test_simple_fc_with_fuse_op(self): def test_simple_fc_with_fuse_op(self):
...@@ -79,5 +78,23 @@ class TestFuseSGDOps(TestFuseAdamOps): ...@@ -79,5 +78,23 @@ class TestFuseSGDOps(TestFuseAdamOps):
fc_with_batchnorm, False, optimizer=self.sgd_optimizer) fc_with_batchnorm, False, optimizer=self.sgd_optimizer)
class TestFuseMomentumOps(TestFuseAdamOps):
def momentum_optimizer(self, learning_rate=1e-3):
return fluid.optimizer.Momentum(
learning_rate=learning_rate, momentum=0.1)
def test_simple_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(
simple_fc_net, True, optimizer=self.momentum_optimizer)
self._compare_fused_optimizer_ops(
simple_fc_net, False, optimizer=self.momentum_optimizer)
def test_batchnorm_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(
fc_with_batchnorm, True, optimizer=self.momentum_optimizer)
self._compare_fused_optimizer_ops(
fc_with_batchnorm, False, optimizer=self.momentum_optimizer)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册