Unverified · commit 056fdedd authored by chengduo, committed by GitHub

Open fuse all reduce option (#19765)

* Open fuse all reduce op
test=develop

* Add Fuse optimization op log

* Add log in fuse_optimizer op pass and fuse all_reduce op pass

* replace with boost::optional<bool>
test=develop

* Polish code
test=develop

* fix code coverage
test=develop
Parent 8c7e4119
@@ -43,6 +43,12 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
          !strategy.enable_parallel_graph_;
 }
 
+static inline void ConvertDefaultValue(boost::optional<bool> *default_value) {
+  if (*default_value == boost::none) {
+    *default_value = true;
+  }
+}
+
 class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
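
For reference, a minimal self-contained sketch (not part of this commit, not Paddle code) of the tri-state flag pattern that the ConvertDefaultValue helper added above implements: boost::none means "the user never set the flag", so the builder may fill in a default, while an explicit true/false from the user is preserved.

#include <boost/optional.hpp>
#include <iostream>

// Same pattern as the helper above: only an unset flag receives a default.
static void ConvertDefaultValue(boost::optional<bool> *default_value) {
  if (*default_value == boost::none) {
    *default_value = true;
  }
}

int main() {
  boost::optional<bool> fuse_all_reduce_ops;         // unset by the user
  boost::optional<bool> fuse_broadcast_ops = false;  // explicitly disabled

  ConvertDefaultValue(&fuse_all_reduce_ops);
  ConvertDefaultValue(&fuse_broadcast_ops);

  std::cout << *fuse_all_reduce_ops << "\n";  // 1: unset flag defaulted to true
  std::cout << *fuse_broadcast_ops << "\n";   // 0: explicit false is preserved
  return 0;
}
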
@@ -79,39 +85,55 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void ResolveOptionConfliction() {
     // Specifies the restrictions between different pass.
     if (strategy_.enable_parallel_graph_) {
-      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+      LOG_IF(WARNING, strategy_.fuse_all_optimizer_ops_ == true)
           << "Currently, fuse_all_optimizer_ops doesn't work under "
              "parallel_graph.";
       strategy_.fuse_all_optimizer_ops_ = false;
-      VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
+      LOG_IF(WARNING, strategy_.fuse_all_reduce_ops_ == true)
           << "fuse_all_reduce_ops doesn't work under "
              "parallel_graph.";
       strategy_.fuse_all_reduce_ops_ = false;
     }
     if (strategy_.is_distribution_) {
-      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+      LOG_IF(WARNING, strategy_.fuse_all_optimizer_ops_ == true)
           << "Currently, fuse_all_optimizer_ops only works under "
              "Non-distributed mode.";
       strategy_.fuse_all_optimizer_ops_ = false;
-      VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
+      LOG_IF(WARNING, strategy_.fuse_all_reduce_ops_ == true)
           << "Currently, fuse_all_reduce_ops_ only works under "
              "Non-distributed mode.";
       strategy_.fuse_all_reduce_ops_ = false;
     }
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
-      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+      LOG_IF(WARNING, strategy_.fuse_all_optimizer_ops_ == true)
           << "Currently, fuse_all_optimizer_ops only works under AllReduce "
              "mode.";
       strategy_.fuse_all_optimizer_ops_ = false;
-      VLOG_IF(3, strategy_.fuse_all_reduce_ops_)
+      LOG_IF(WARNING, strategy_.fuse_all_reduce_ops_ == true)
           << "fuse_all_optimizer_ops only work in Reducer mode.";
       strategy_.fuse_all_reduce_ops_ = false;
     }
-    if (strategy_.async_mode_) {
-      VLOG_IF(3, strategy_.fuse_all_optimizer_ops_)
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      LOG_IF(WARNING, strategy_.fuse_broadcast_ops_ == true)
+          << "Currently, fuse_broadcast_ops only works under Reduce "
+             "mode.";
+      strategy_.fuse_broadcast_ops_ = false;
+    }
+
+    ConvertDefaultValue(&strategy_.fuse_all_optimizer_ops_);
+    ConvertDefaultValue(&strategy_.fuse_all_reduce_ops_);
+    ConvertDefaultValue(&strategy_.fuse_broadcast_ops_);
+
+    if (strategy_.fuse_all_optimizer_ops_ == true) {
+      LOG_IF(WARNING, strategy_.async_mode_)
           << "Currently, fuse_all_optimizer_ops doesn't work under "
              "async mode.";
-      strategy_.fuse_all_optimizer_ops_ = false;
+      strategy_.fuse_all_optimizer_ops_ = !strategy_.async_mode_;
+    }
+    if (strategy_.fuse_all_reduce_ops_ == true) {
+      LOG_IF(WARNING, strategy_.async_mode_)
+          << "fuse_all_optimizer_ops only work in Reducer mode.";
+      strategy_.fuse_all_reduce_ops_ = !strategy_.async_mode_;
     }
   }
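
The hunk above also switches these conflict messages from VLOG_IF(3, ...) to LOG_IF(WARNING, ...), which changes when users actually see them. A small standalone glog sketch (not Paddle code) of the difference:

#include <glog/logging.h>

// VLOG_IF(3, cond) is only emitted when verbose logging is enabled
// (e.g. running with GLOG_v=3), so users rarely saw these conflicts;
// LOG_IF(WARNING, cond) is emitted whenever cond holds.
int main(int argc, char *argv[]) {
  google::InitGoogleLogging(argv[0]);
  FLAGS_logtostderr = true;

  bool fuse_all_reduce_ops = true;
  bool enable_parallel_graph = true;

  VLOG_IF(3, fuse_all_reduce_ops) << "hidden unless GLOG_v >= 3";
  LOG_IF(WARNING, enable_parallel_graph && fuse_all_reduce_ops)
      << "fuse_all_reduce_ops doesn't work under parallel_graph.";
  return 0;
}
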
@@ -151,7 +173,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // NOTE: fuse_all_xx_ops will count the number of xx operator first,
     // if the number is zero, fuse_all_reduce_ops will do nothing.
     // Currently, only one type of optimization algorithm can be fused.
-    if (strategy_.fuse_all_optimizer_ops_) {
+    if (strategy_.fuse_all_optimizer_ops_ == true) {
       AppendPass("fuse_adam_op_pass");
       AppendPass("fuse_sgd_op_pass");
       AppendPass("fuse_momentum_op_pass");
@@ -207,6 +229,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     }
   }
 
+  void AppendPassWithCheck(const boost::optional<bool> &append_pass,
+                           const std::string &pass_name) {
+    AppendPassWithCheck(append_pass == true, pass_name);
+  }
+
   void AppendPassWithCheck(bool append_pass, const std::string &pass_name) {
     if (append_pass) {
       AppendPass(pass_name);
...
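
The new AppendPassWithCheck overload above just funnels a tri-state flag into the existing bool overload via `== true`. A standalone sketch of that forwarding (the names below are illustrative, not Paddle's actual builder), where an unset optional compares unequal to true and therefore skips the pass:

#include <boost/optional.hpp>
#include <iostream>
#include <string>

static void AppendPass(const std::string &name) {
  std::cout << "append " << name << "\n";
}

// bool overload: appends the pass only when the flag is set.
static void AppendPassWithCheck(bool append_pass, const std::string &name) {
  if (append_pass) AppendPass(name);
}

// boost::optional<bool> overload: forwards to the bool overload;
// boost::none == true evaluates to false, so an unset flag skips the pass.
static void AppendPassWithCheck(const boost::optional<bool> &append_pass,
                                const std::string &name) {
  AppendPassWithCheck(append_pass == true, name);
}

int main() {
  boost::optional<bool> unset;                              // pass skipped
  AppendPassWithCheck(unset, "fuse_all_reduce_op_pass");
  AppendPassWithCheck(true, "fuse_all_reduce_op_pass");     // pass appended
  return 0;
}
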
@@ -89,8 +89,8 @@ struct BuildStrategy {
   bool fuse_elewise_add_act_ops_{false};
   // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
   // should not be sparse types
-  bool fuse_all_optimizer_ops_{true};
-  bool fuse_all_reduce_ops_{false};
+  boost::optional<bool> fuse_all_optimizer_ops_{boost::none};
+  boost::optional<bool> fuse_all_reduce_ops_{boost::none};
   // fuse_relu_depthwise_conv can fuse the `relu ->
   // depthwise_conv`
   bool fuse_relu_depthwise_conv_{false};
@@ -98,7 +98,7 @@ struct BuildStrategy {
   // faster. Because fusing broadcast OP equals delaying the execution of all
   // broadcast Ops, in this case, all nccl streams are used only for reduce
   // operations for a period of time.
-  bool fuse_broadcast_ops_{true};
+  boost::optional<bool> fuse_broadcast_ops_{boost::none};
   // replace batch_norm with sync_batch_norm.
   bool sync_batch_norm_{false};
...
@@ -124,7 +124,7 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
     // fused_var node.
-    VLOG(7) << "Insert adam to graph ";
+    VLOG(6) << "Insert adam to graph ";
     OpDesc adam_desc(adam_ops[0]->Op()->Block());
     adam_desc.SetType("adam");
     adam_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
@@ -180,7 +180,7 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
       scale_ops.emplace_back(*scale_op_iter);
     }
     PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());
-    VLOG(7) << "The number of scale op is " << scale_ops.size() << ".";
+    VLOG(6) << "The number of scale op is " << scale_ops.size() << ".";
     // Check attributions
     // NOTE: If new attribution is added, the following code maybe need change.
     int op_role = boost::get<int>(
@@ -205,7 +205,7 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
     // fused_var node.
-    VLOG(7) << "Insert fused scale to graph.";
+    VLOG(6) << "Insert fused scale to graph.";
     OpDesc scale_desc(scale_ops[0]->Op()->Block());
     scale_desc.SetType("scale");
     scale_desc.SetInput("X", {fused_var_name});
...
@@ -61,7 +61,7 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
     // NOTE: fused_var is only exist in scope, so the graph doesn't have
     // fused_var node.
-    VLOG(7) << "Insert momentum to graph ";
+    VLOG(6) << "Insert momentum to graph ";
     OpDesc momentum_desc(momentum_ops[0]->Op()->Block());
     momentum_desc.SetType("momentum");
     momentum_desc.SetInput(kParam, {fused_vars_name.at(kParam)});
...
@@ -49,7 +49,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
   }
   VLOG(6) << "Find " << fuse_op_type << " operators : " << opt_ops_num
-          << ", and " << opt_nodes.size() << " for dense gradients ";
+          << ", and " << opt_nodes.size() << " for dense gradients.";
   if (opt_nodes.size() == 0 || result.Has(details::kFusedOptType)) {
     if (result.Has(details::kFusedOptType)) {
       auto &opt_type =
@@ -69,6 +69,11 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
     return;
   }
 
+  LOG(WARNING) << "Find " << fuse_op_type << " operators : " << opt_ops_num
+               << ", and " << opt_nodes.size() << " for dense gradients. "
+               << "To make the speed faster, those optimization are fused "
+                  "during training.";
+
   result.Set(details::kFusedOptType, new details::FusedOptType);
   result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
   if (!result.Has(details::kProgramDescs)) {
@@ -149,7 +154,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
                            &opt_nodes);
     grad_fused = true;
   } else {
-    VLOG(10) << "The number of new gradients is " << new_grad_idx.size();
+    VLOG(6) << "The number of new gradients is " << new_grad_idx.size();
     if (new_grad_idx.size() == 1) return;
     // NOTE(zcd): If the gradients of backward stage and optimization stage
     // have diff, Only take care of the the gradient of optimization stage.
...
@@ -42,7 +42,7 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
     int op_role = boost::get<int>(
         sgd_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
-    VLOG(7) << "Insert sgd to graph ";
+    VLOG(6) << "Insert sgd to graph.";
     // Add fused scale
     OpDesc Sgd_desc(sgd_ops[0]->Op()->Block());
     Sgd_desc.SetType("sgd");
...
@@ -56,7 +56,7 @@ class FuseAllReduceOpPass : public ir::Pass {
     std::unordered_map<std::string, Node *> all_reduce_ops =
         GetAllReduceOps(result, places, grads);
 
-    VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
+    VLOG(6) << "Find all_reduce_ops: " << all_reduce_ops.size();
     if (all_reduce_ops.size() == 0) {
       return;
     }
@@ -65,11 +65,16 @@ class FuseAllReduceOpPass : public ir::Pass {
                       "The number of all_reduce OpHandle is not equal to the "
                       "number of grads. Maybe some gradients are sparse type, "
                       "it is not supported currently.");
-    VLOG(10) << "Insert fused_all_reduce";
     auto &group_params_grads = graph->Get<details::GroupParamsAndGrads>(
         details::kGroupParamsAndDenseGrads);
+
+    LOG(WARNING) << string::Sprintf(
+        "Find all_reduce operators: %d. To make the speed faster, some "
+        "all_reduce ops are fused during training, after fusion, "
+        "the number of all_reduce ops is %d.",
+        all_reduce_ops.size(), group_params_grads.size());
+
     for (auto &group_p_g : group_params_grads) {
       size_t group_size = group_p_g.size();
       PADDLE_ENFORCE_GT(group_size, static_cast<size_t>(0));
...
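
The warning added above reports the op count before fusion (one all_reduce per dense gradient) and after fusion (one fused op per gradient group taken from details::kGroupParamsAndDenseGrads). A rough, self-contained sketch of that arithmetic; the fixed group size here is purely illustrative, since Paddle derives the groups from an earlier pass rather than a fixed size:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // One all_reduce per dense gradient before fusion.
  std::vector<std::string> dense_grads = {"w0@GRAD", "w1@GRAD", "b0@GRAD",
                                          "w2@GRAD", "b1@GRAD"};
  // Hypothetical grouping granularity, for illustration only.
  const std::size_t kGroupSize = 2;
  std::size_t num_groups = (dense_grads.size() + kGroupSize - 1) / kGroupSize;

  std::printf(
      "Find all_reduce operators: %zu. After fusion, the number of "
      "all_reduce ops is %zu.\n",
      dense_grads.size(), num_groups);
  return 0;
}
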
@@ -699,7 +699,7 @@ bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
 
 void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
   if (UseGPU()) {
-    if (strategy_.fuse_broadcast_ops_) {
+    if (strategy_.fuse_broadcast_ops_ == true) {
       CreateFusedBroadcastOp(result, bcast_var_name_set_);
     } else {
       for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
@@ -1068,7 +1068,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
       strategy_.reduce_ == details::BuildStrategy::ReduceStrategy::kReduce) {
     return;
   }
-  if (strategy_.fuse_broadcast_ops_) {
+  if (strategy_.fuse_broadcast_ops_ == true) {
     CreateFusedBroadcastOp(result, bcast_var_name_set_);
   } else {
     for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
...
@@ -123,7 +123,7 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
                          const std::string &g_name) const override {}
 
   bool NeedCollectiveForGrad(const std::string &grad_name,
-                             std::vector<ir::Node *> ops) const {
+                             std::vector<ir::Node *> ops) const override {
     return false;
   }
...
@@ -338,8 +338,8 @@ PYBIND11_MODULE(core_noavx, m) {
                       recursive_sequence_lengths.end(),
                       std::back_inserter(new_lod));
             LoD new_offset_lod = ConvertToOffsetBasedLoD(new_lod);
-            PADDLE_ENFORCE(
-                CheckLoD(new_offset_lod, -1),
+            PADDLE_ENFORCE_EQ(
+                CheckLoD(new_offset_lod, -1), true,
                 "the provided recursive_sequence_lengths info is invalid");
             new (&instance) LoDTensor(new_offset_lod);
           })
@@ -355,7 +355,8 @@ PYBIND11_MODULE(core_noavx, m) {
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
-            PADDLE_ENFORCE(CheckLoD(new_lod, vectorize(self.dims()).front()),
+            PADDLE_ENFORCE_EQ(
+                CheckLoD(new_lod, vectorize(self.dims()).front()), true,
                 "the provided lod info is invalid");
             self.set_lod(new_lod);
           },
@@ -386,8 +387,8 @@ PYBIND11_MODULE(core_noavx, m) {
                       recursive_sequence_lengths.end(),
                       std::back_inserter(new_lod));
             LoD new_offset_lod = ConvertToOffsetBasedLoD(new_lod);
-            PADDLE_ENFORCE(
-                CheckLoD(new_offset_lod, vectorize(self.dims()).front()),
+            PADDLE_ENFORCE_EQ(
+                CheckLoD(new_offset_lod, vectorize(self.dims()).front()), true,
                 "the provided recursive_sequence_lengths info is invalid");
             self.set_lod(new_offset_lod);
           },
@@ -588,7 +589,7 @@ All parameter, weight, gradient are variables in Paddle.
 #endif
       .def("get_reader",
            [](Variable &self) -> framework::ReaderHolder * {
-             PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>());
+             PADDLE_ENFORCE_EQ(self.IsType<framework::ReaderHolder>(), true);
              return self.GetMutable<framework::ReaderHolder>();
            },
            py::return_value_policy::reference);
@@ -713,8 +714,8 @@ All parameter, weight, gradient are variables in Paddle.
       auto &info = iter.second;
       if (info.HasOpProtoAndChecker()) {
         std::string str;
-        PADDLE_ENFORCE(
-            info.Proto().SerializeToString(&str),
+        PADDLE_ENFORCE_EQ(
+            info.Proto().SerializeToString(&str), true,
             "Serialize OpProto Error. This could be a bug of Paddle.");
         ret_values.emplace_back(str);
       }
@@ -942,12 +943,13 @@ All parameter, weight, gradient are variables in Paddle.
       });
 
   py::class_<OperatorBase>(m, "Operator")
-      .def_static("create",
+      .def_static(
+          "create",
           [](py::bytes protobin) {
             proto::OpDesc desc;
-            PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+            PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), true,
                               "Cannot parse user input to OpDesc");
-            PADDLE_ENFORCE(desc.IsInitialized(),
+            PADDLE_ENFORCE_EQ(desc.IsInitialized(), true,
                               "User OpDesc is not initialized, reason %s",
                               desc.InitializationErrorString());
             return OpRegistry::CreateOp(desc);
@@ -1323,7 +1325,8 @@ All parameter, weight, gradient are variables in Paddle.
           "reduce_strategy",
           [](const BuildStrategy &self) { return self.reduce_; },
           [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.reduce_ = strategy;
           },
           R"DOC(The type is fluid.BuildStrategy.ReduceStrategy, there are two reduce
@@ -1346,7 +1349,8 @@ All parameter, weight, gradient are variables in Paddle.
           [](const BuildStrategy &self) { return self.gradient_scale_; },
           [](BuildStrategy &self,
              BuildStrategy::GradientScaleStrategy strategy) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finalized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finalized.");
             self.gradient_scale_ = strategy;
           },
           R"DOC(The type is fluid.BuildStrategy.GradientScaleStrategy, there are three
@@ -1407,7 +1411,8 @@ All parameter, weight, gradient are variables in Paddle.
           "debug_graphviz_path",
           [](const BuildStrategy &self) { return self.debug_graphviz_path_; },
           [](BuildStrategy &self, const std::string &path) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.debug_graphviz_path_ = path;
           },
           R"DOC(The type is STR, debug_graphviz_path indicates the path that
@@ -1428,7 +1433,8 @@ All parameter, weight, gradient are variables in Paddle.
             return self.enable_sequential_execution_;
           },
           [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.enable_sequential_execution_ = b;
           },
           R"DOC(The type is BOOL. If set True, the execution order of ops would
@@ -1447,7 +1453,8 @@ All parameter, weight, gradient are variables in Paddle.
             return self.remove_unnecessary_lock_;
           },
           [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.remove_unnecessary_lock_ = b;
           },
           R"DOC(The type is BOOL. If set True, some locks in GPU ops would be
@@ -1508,7 +1515,8 @@ All parameter, weight, gradient are variables in Paddle.
             return self.fuse_elewise_add_act_ops_;
           },
           [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.fuse_elewise_add_act_ops_ = b;
           },
           R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
@@ -1528,7 +1536,8 @@ All parameter, weight, gradient are variables in Paddle.
             return self.fuse_relu_depthwise_conv_;
           },
           [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.fuse_relu_depthwise_conv_ = b;
           },
           R"DOC(The type is BOOL, fuse_relu_depthwise_conv indicate whether
@@ -1544,11 +1553,14 @@ All parameter, weight, gradient are variables in Paddle.
                 build_strategy = fluid.BuildStrategy()
                 build_strategy.fuse_relu_depthwise_conv = True
         )DOC")
-      .def_property(
-          "fuse_broadcast_ops",
-          [](const BuildStrategy &self) { return self.fuse_broadcast_ops_; },
+      .def_property("fuse_broadcast_ops",
+                    [](const BuildStrategy &self) {
+                      return self.fuse_broadcast_ops_ == true ||
+                             self.fuse_broadcast_ops_ == boost::none;
+                    },
           [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.fuse_broadcast_ops_ = b;
           },
           R"DOC(The type is BOOL, fuse_broadcast_op indicates whether
@@ -1559,10 +1571,11 @@ All parameter, weight, gradient are variables in Paddle.
                     for NCCLReduce operations for a period of time. Default False.)DOC")
       .def_property("fuse_all_optimizer_ops",
                     [](const BuildStrategy &self) {
-                      return self.fuse_all_optimizer_ops_;
+                      return self.fuse_all_optimizer_ops_ == true ||
+                             self.fuse_all_optimizer_ops_ == boost::none;
                     },
                     [](BuildStrategy &self, bool b) {
-                      PADDLE_ENFORCE(!self.IsFinalized(),
+                      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                                      "BuildStrategy is finlaized.");
                       self.fuse_all_optimizer_ops_ = b;
                     })
@@ -1570,7 +1583,8 @@ All parameter, weight, gradient are variables in Paddle.
           "sync_batch_norm",
           [](const BuildStrategy &self) { return self.sync_batch_norm_; },
           [](BuildStrategy &self, bool b) {
-            PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
+            PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
+                              "BuildStrategy is finlaized.");
             self.sync_batch_norm_ = b;
           },
           R"DOC(The type is BOOL, sync_batch_norm indicates whether to use
@@ -1637,7 +1651,10 @@ All parameter, weight, gradient are variables in Paddle.
           [](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
       .def_property(
           "fuse_all_reduce_ops",
-          [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
+          [](const BuildStrategy &self) {
+            return self.fuse_all_reduce_ops_ == true ||
+                   self.fuse_all_reduce_ops_ == boost::none;
+          },
           [](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; })
       .def_property("enable_backward_optimizer_op_deps",
                     [](const BuildStrategy &self) {
...
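
The fuse_broadcast_ops, fuse_all_optimizer_ops, and fuse_all_reduce_ops getters above all follow the same pattern: Python still sees a plain bool property, and an unset (boost::none) field reads as True because the pass builder will default it on. A standalone pybind11 sketch of that pattern, with made-up module and class names rather than Paddle's:

#include <pybind11/pybind11.h>
#include <boost/optional.hpp>

namespace py = pybind11;

struct Strategy {
  // Tri-state on the C++ side; exposed to Python as a plain bool property.
  boost::optional<bool> fuse_all_reduce_ops{boost::none};
};

PYBIND11_MODULE(strategy_demo, m) {
  py::class_<Strategy>(m, "Strategy")
      .def(py::init<>())
      .def_property(
          "fuse_all_reduce_ops",
          // Getter: report the effective default, i.e. unset reads as True.
          [](const Strategy &self) {
            return self.fuse_all_reduce_ops == true ||
                   self.fuse_all_reduce_ops == boost::none;
          },
          // Setter: an explicit Python assignment pins the flag.
          [](Strategy &self, bool b) { self.fuse_all_reduce_ops = b; });
}
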