diff --git a/CMakeLists.txt b/CMakeLists.txt
index 97af6192cb899590927b80de9bda217c5ac56221..7a7b5860a122a853fc9ce1da6494fc039b38bc10 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -309,5 +309,5 @@ if (ON_INFER)
   add_definitions(-DPADDLE_ON_INFERENCE)
 else()
   #TODO(luotao), combine this warning with `make inference_lib_dist` command.
-  message(WARNING "On inference mode, will take place some specific optimization. Only used in make inference_lib_dist.")
+  message(WARNING "In inference mode, specific optimizations will be applied. Turn on the ON_INFER flag only when building inference_lib.")
 endif()
diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py
index 9540900b112f54594bbfdbc8d7cd3b6e1f5269dd..ff616ddbb2cb1cb7f348d6d164815823b08b7629 100644
--- a/benchmark/fluid/args.py
+++ b/benchmark/fluid/args.py
@@ -142,5 +142,10 @@ def parse_args():
         choices=['reduce', 'all_reduce'],
         default='all_reduce',
         help='Specify the reduce strategy, can be reduce, all_reduce')
+    parser.add_argument(
+        '--fuse_broadcast_op',
+        action='store_true',
+        help='If set, multiple broadcast operators will be fused into a single fused_broadcast operator.'
+    )
     args = parser.parse_args()
     return args
diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py
index ddd9fe809853a830ca676cc98f1819f683866def..5f3ce300acc44ad8d2898c27296b866c403f3cc8 100644
--- a/benchmark/fluid/fluid_benchmark.py
+++ b/benchmark/fluid/fluid_benchmark.py
@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
     else:
         build_strategy.reduce_strategy = fluid.BuildStrategy(
         ).ReduceStrategy.AllReduce
+    build_strategy.fuse_broadcast_op = args.fuse_broadcast_op
 
     avg_loss = train_args[0]
 
@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
 
         if args.use_fake_data or args.use_reader_op:
             try:
-
                 fetch_ret = exe.run(fetch_list)
             except fluid.core.EOFException as eof:
                 break
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 0d90bf3cc12e285f5cafd80180c12ddeb4ad8b51..2b8b82e74fc49d454b5331460acbffd0e9404fb5 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -355,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind
 paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None))
+paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
 paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None,
defaults=(0.0,)) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index e0a3ef5a9c6c53c42ebea1a41cac0d18a77781b2..17188ac5f301102ae79c6ace676b84ee66e28801 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -16,12 +16,14 @@ if(WITH_GPU) dynload_cuda variable_visitor) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) + nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) + cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) endif() cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) @@ -34,7 +36,7 @@ if(WITH_GPU) endif() cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) if(WITH_GPU) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) @@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo cc_library(build_strategy SRCS build_strategy.cc DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass - fuse_elewise_add_act_pass) + fuse_elewise_add_act_pass multi_batch_merge_pass) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 4fdab5cd94358d08eac7f8b041bf16d09042f0bd..5b5a10e22776bee5c61a55c163c1732692551e36 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -48,16 +48,23 @@ void BroadcastOpHandle::RunImpl() { var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); } + BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes); +} + +void BroadcastOpHandle::BroadcastOneVar( + const VarHandle &in_var_handle, + const std::vector &out_var_handles, + const std::vector &var_scopes) { auto *in_var = - var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); + var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); PADDLE_ENFORCE_NOT_NULL(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); - InitOutputValue(*in_var_handle, out_var_handles); + InitOutputValue(in_var_handle, out_var_handles); if (platform::is_cpu_place(in_tensor.place())) { for (auto *out_var_handle : out_var_handles) { - if (out_var_handle->IsTheSameVar(*in_var_handle)) { + if (out_var_handle->IsTheSameVar(in_var_handle)) { continue; } auto &out_p = out_var_handle->place_; @@ -114,12 +121,12 @@ void BroadcastOpHandle::RunImpl() { } } - if 
(!out_handle->IsTheSameVar(*in_var_handle)) { - auto out_var = var_scopes.at(in_var_handle->scope_idx_) + if (!out_handle->IsTheSameVar(in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle.scope_idx_) ->FindVar(out_var_handles[0]->name_); paddle::framework::TensorCopy( - in_tensor, in_var_handle->place_, - *(dev_ctxes_.at(in_var_handle->place_)), + in_tensor, in_var_handle.place_, + *(dev_ctxes_.at(in_var_handle.place_)), &VariableVisitor::GetMutableTensor(out_var)); } }); diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index fe4e733e43417977df324fde808f52b228a27d19..020d351e891c7afab37c59c0ff8d8e5e7ba184f2 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - private: + void BroadcastOneVar(const VarHandle &in_var_handle, + const std::vector &out_var_handles, + const std::vector &var_scopes); + std::vector local_scopes_; std::vector places_; #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 6a6b497fa897e3882995688bf36704b1d77ea962..fefd27fc86fb8dce3311fa580d90f518906dd862 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -121,6 +121,7 @@ std::unique_ptr BuildStrategy::Apply( USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); +USE_PASS(multi_batch_merge_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 02c4bea16916d58a6d0fce8918f8fceb9ff9356e..f3ffaf6ecd7c4dd99c40fe58ba88c0cbdc14bde7 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -69,6 +69,8 @@ struct BuildStrategy { bool enable_data_balance_{false}; + bool fuse_broadcast_op_{false}; + // User normally doesn't need to call this API. // The PassBuilder allows for more customized insert, remove of passes // from python side. diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..51dfa2d0711f49aaefab0af3549283dbf77eee4a --- /dev/null +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
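The refactor above extracts the per-variable body of `BroadcastOpHandle::RunImpl` into a protected `BroadcastOneVar`, so the fused handle defined next can drive it in a loop instead of duplicating the copy logic. A minimal Python analogy of the pattern (names mirror the C++; the actual tensor copy is stubbed):

```python
class BroadcastOpHandle(object):
    def __init__(self, in_var_handles, out_var_handles):
        self.in_var_handles = in_var_handles
        self.out_var_handles = out_var_handles

    def broadcast_one_var(self, in_var, out_vars):
        # copy the source variable to every place that does not already hold it
        for out_var in out_vars:
            if out_var != in_var:
                print("copy %s -> %s" % (in_var, out_var))  # stand-in for TensorCopy

    def run_impl(self):
        # single-variable handle: one input, one output per place
        self.broadcast_one_var(self.in_var_handles[0], self.out_var_handles)


class FusedBroadcastOpHandle(BroadcastOpHandle):
    def __init__(self, in_var_handles, out_groups):
        # out_groups: one list of per-place outputs for every input variable
        BroadcastOpHandle.__init__(self, in_var_handles,
                                   [v for g in out_groups for v in g])
        self.out_groups = out_groups

    def run_impl(self):
        # fused handle: reuse the same per-variable routine in a loop
        for in_var, outs in zip(self.in_var_handles, self.out_groups):
            self.broadcast_one_var(in_var, outs)
```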
+ +#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace framework { +namespace details { + +void FusedBroadcastOpHandle::RunImpl() { + platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); + + if (places_.size() == 1UL) return; + + auto in_var_handles = DynamicCast(inputs_); + auto out_var_handles = DynamicCast(outputs_); + + WaitInputVarGenerated(); + + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + + size_t place_num = places_.size(); + PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size()); + + for (size_t i = 0; i < in_var_handles.size(); ++i) { + BroadcastOneVar( + *in_var_handles[i], + std::vector(out_var_handles.begin() + i * place_num, + out_var_handles.begin() + (i + 1) * place_num), + var_scopes); + } +} + +std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; } + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.h b/paddle/fluid/framework/details/fused_broadcast_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..e37259526a5f6f57d51a0ca8bca96a18211a4790 --- /dev/null +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.h @@ -0,0 +1,57 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/details/broadcast_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/device_context.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +struct FusedBroadcastOpHandle : public BroadcastOpHandle { + public: +#ifdef PADDLE_WITH_CUDA + FusedBroadcastOpHandle(ir::Node *node, + const std::vector local_scopes, + const std::vector &places, + const platform::NCCLContextMap *nccl_ctx) + : BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {} +#else + FusedBroadcastOpHandle(ir::Node* node, const std::vector local_scopes, + const std::vector& places) + : BroadcastOpHandle(node, local_scopes, places) {} +#endif + std::string Name() const override; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index ebd1d644bcce0554904ee0931827ef43a6d52b11..f2d5b182e577714d6138e99932af637a711cc9f2 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" +#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" @@ -347,7 +348,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( BuildStrategy::GradientScaleStrategy::kCustomized) { // TODO(paddle-dev): Why is there no input for this op_handle? auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; - CreateScaleLossGradOp(&result, loss_grad_name); + CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]); } // This assumes the backward generating code will ensure IsScaleLossOp // is true only for the op that scale the final scalar loss. 
@@ -436,10 +437,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( if ((use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) || is_dist_train) { - for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set[dev_id]; - for (auto &bcast_name : to_bcast_set) { - CreateBroadcastOp(&result, bcast_name, dev_id); + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(&result, bcast_var_name_set); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(&result, bcast_name, dev_id); + } } } } @@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, } } +void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( + ir::Graph *result, + const std::vector> &bcast_varnames) const { +#ifdef PADDLE_WITH_CUDA + auto *op_handle = new FusedBroadcastOpHandle( + result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_); +#else + auto *op_handle = new FusedBroadcastOpHandle( + result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), + local_scopes_, places_); +#endif + result->Get(kGraphOps).emplace_back(op_handle); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + SetCommunicationContext(op_handle, p); + } + + for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) { + for (auto &p_name : bcast_varnames[dev_id]) { + auto *in = + result->Get(kGraphVars).at(dev_id).at(p_name).back().get(); + op_handle->AddInput(in); + for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) { + auto &p = places_[out_dev_id]; + auto &vars = + result->Get(kGraphVars).at(out_dev_id).at(p_name); + auto *out_var = new VarHandle( + result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), + vars.size(), out_dev_id, p_name, p); + vars.emplace_back(out_var); + op_handle->AddOutput(out_var); + } + } + } +} + void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const { @@ -602,7 +645,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( - ir::Graph *result, const std::string &loss_grad_name) const { + ir::Graph *result, const std::string &loss_grad_name, + ir::Node *out_var_node) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); @@ -617,10 +661,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - CreateOpOutput( - result, op_handle, - result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable), - places_[i], i); + CreateOpOutput(result, op_handle, + result->CreateVarNode(out_var_node->Var()), places_[i], i); } } diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index cdf9f13cde608b546d17a1e53e0f6acea9e12566..03b2de2f04da4bac8d342a76c80fd12beaeba4b7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -61,7 +61,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t num_places) const; void CreateScaleLossGradOp(ir::Graph *result, - const std::string &loss_grad_name) const; + const 
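The branch above is where the flag takes effect: when `strategy_.fuse_broadcast_op_` is set, one `FusedBroadcastOpHandle` replaces the per-device, per-variable broadcast ops. On the Python side the benchmark sets `build_strategy.fuse_broadcast_op`, which suggests usage like the following sketch (this assumes the field is exported through pybind, which this chunk does not show; `avg_loss` is a placeholder for your loss variable):

```python
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy().ReduceStrategy.Reduce
build_strategy.fuse_broadcast_op = True  # one fused_broadcast instead of many broadcasts

exe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name=avg_loss.name,
    build_strategy=build_strategy)
```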
std::string &loss_grad_name, + ir::Node *out_var_node) const; VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const; @@ -78,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass { void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const; + void CreateFusedBroadcastOp( + ir::Graph *result, + const std::vector> &bcast_varnames) const; + bool IsSparseGradient(const std::string &og) const; size_t GetAppropriateDeviceID( diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a145b2fafe64f8c80ac7808583d6670ca0218c06..ce006b7a3fbc16f3c9149933390969b14a46b484 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -36,6 +36,7 @@ pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) +pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) if(WITH_MKLDNN) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 398f7095968e62f92d610f560d7574b27706d13e..265a128e95e6205159de67d87d5cb8ca1ad7d475 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -24,79 +24,23 @@ namespace paddle { namespace framework { namespace ir { -std::vector FindDistTrainSendVars( - const std::vector &nodes) { - std::vector send_vars; - // since parameters are all in block 0, - // it's enough to only scan send ops in block 0 - for (auto &node : nodes) { - auto op_vars = node->Op()->InputArgumentNames(); - send_vars.reserve(send_vars.size() + - std::distance(op_vars.begin(), op_vars.end())); - send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); - } - return send_vars; -} - -std::vector FindDistTrainRecvVars( - const std::vector &nodes) { - std::vector recv_vars; - for (auto &node : nodes) { - auto op_vars = node->Op()->OutputArgumentNames(); - recv_vars.reserve(recv_vars.size() + - std::distance(op_vars.begin(), op_vars.end())); - recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); - } - return recv_vars; -} - -bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, - const std::vector &recv_vars) { - if (send_vars.size() == 0 || recv_vars.size() == 0) { - return false; - } - - /** - * Check any of opvars contains `.block` and in sendvars - */ - auto checker = [](const std::vector &opvars, - const std::vector &rpc_vars) -> bool { - for (auto &var : opvars) { - // a variable name with the suffix `.block` means it's a splited - // variable by (DistributeTranspiler) - // [python/paddle/fluid/transpiler/distribute_transpiler.py] - if (var.find(".block") != std::string::npos && - std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { - return true; - } - } - return false; - }; - - std::vector input_var_names; - std::vector output_var_names; - for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Name()); - } - for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Name()); - } - - return checker(output_var_names, send_vars) || - checker(input_var_names, recv_vars); -} - Graph::Graph(const ProgramDesc &program) : program_(program) { // Make the nodes id start from 0. 
Node::ResetId(); + auto var_nodes = InitFromProgram(program_); + ResolveHazard(var_nodes); +} +std::map> Graph::InitFromProgram( + const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); std::unordered_map all_vars; + // var nodes for each var name, will have multiple versions in SSA + std::map> var_nodes; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } - std::map> var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); // For input args, reuse the same var name if it was created before. @@ -134,7 +78,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var->inputs.push_back(node); } } + return std::move(var_nodes); +} +void Graph::ResolveHazard( + const std::map> &var_nodes) { /** * We should handle write after read(WAR) and write after write(WAW) here. * Because some of the operators of the program can be executed parallelly. @@ -153,6 +101,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { auto it_old = versions.rbegin(); ++it_old; for (; it_old != versions.rend(); it_new = it_old, ++it_old) { + VLOG(3) << "deal with var: " << (*it_new)->Name(); ir::Node *write_op = (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0]; const auto &read_ops = (*it_old)->outputs; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ab687e760a761d4e445726bd5149966adc2403d0..9d7aa5d32deb274fbf29481b0d4754c05d1e21b5 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -160,6 +160,12 @@ class Graph { return nullptr; } + std::map> InitFromProgram( + const ProgramDesc &program); + + void ResolveHazard( + const std::map> &var_nodes); + private: // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd5b76426eb55cebdabfccd700439a4c418a10f0 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -0,0 +1,315 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
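`ResolveHazard`, factored out of the old constructor body above, inserts dependencies so that write-after-read (WAR) and write-after-write (WAW) pairs keep their order when ops run in parallel. A minimal sketch of the rule over per-variable version lists (the data layout here is invented for illustration; it is not the Graph API):

```python
def resolve_hazard(versions):
    """versions: oldest-to-newest list of dicts {"writer": op, "readers": [ops]}."""
    deps = []
    for old, new in zip(versions, versions[1:]):
        writer = new["writer"]
        if writer is None:
            continue
        # WAR: every reader of the old version must finish before the new write
        for reader in old["readers"]:
            deps.append((reader, writer))
        # WAW: the old writer must finish before the new writer
        if old["writer"] is not None:
            deps.append((old["writer"], writer))
    return deps

versions = [
    {"writer": "init_w", "readers": ["fc_fwd"]},
    {"writer": "sgd_update_w", "readers": []},
]
assert ("fc_fwd", "sgd_update_w") in resolve_hazard(versions)
```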
+ +#include "paddle/fluid/framework/ir/multi_batch_merge_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +static const char kNumRepeats[] = "num_repeats"; +typedef std::unordered_map> SSAVarList; + +ir::Node* SameNameVar(std::unordered_set all, ir::Node* target) { + for (auto n : all) { + if (target->IsVar() && target->Name() == n->Name()) { + return n; + } + } + return nullptr; +} + +VarDesc CopyVarDesc(VarDesc* var_desc) { + VarDesc repeated_var(var_desc->Name()); + // copy other variable attributes + if (var_desc->GetType() != proto::VarType::READER) { + repeated_var.SetType(var_desc->GetType()); + repeated_var.SetShape(var_desc->GetShape()); + repeated_var.SetDataType(var_desc->GetDataType()); + repeated_var.SetLoDLevel(var_desc->GetLoDLevel()); + repeated_var.SetPersistable(var_desc->Persistable()); + } else { + // TODO(typhoonzero): copy reader var + } + return repeated_var; +} + +VarDesc UpdateGradVarDesc( + VarDesc* var_desc, int repeat, + const std::unordered_set& grad_names, + const std::unordered_set& bn_vars_need_rename) { + if (grad_names.find(var_desc->Name()) != grad_names.end() || + bn_vars_need_rename.find(var_desc->Name()) != bn_vars_need_rename.end()) { + std::string new_gname = + string::Sprintf("%s.repeat.%d", var_desc->Name(), repeat); + VarDesc repeated_var = CopyVarDesc(var_desc); + repeated_var.SetName(new_gname); + VLOG(3) << "update " << var_desc->Name() << " to repeat " << repeat; + return repeated_var; + } + return *var_desc; +} + +std::unique_ptr BatchMergePass::ApplyImpl( + std::unique_ptr graph) const { + int num_repeats = Get(kNumRepeats); + std::vector forward_backward_ops; + std::vector optimize_ops; + std::vector lr_ops; // ops other than forward/backward/optimize + std::unordered_set grad_names; + + std::vector nodes = TopologySortOperations(*graph); + auto origin_nodes = graph->ReleaseNodes(); + VLOG(3) << "origin nodes count: " << origin_nodes.size(); + ir::Graph& result = *graph; + + // 1. record op nodes of different roles + for (auto node : nodes) { + if (node->IsVar()) continue; + int op_role = boost::get(node->Op()->GetAttr( + framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + if ((op_role == static_cast(framework::OpRole::kForward)) || + (op_role & static_cast(framework::OpRole::kBackward)) || + (op_role & static_cast(framework::OpRole::kLoss))) { + forward_backward_ops.push_back(node); + } else if ((op_role & static_cast(framework::OpRole::kOptimize)) || + (op_role & static_cast(framework::OpRole::kDist)) || + (op_role & static_cast(framework::OpRole::kRPC))) { + optimize_ops.push_back(node); + auto op_role_var = node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName()); + auto op_role_vars = boost::get>(op_role_var); + for (size_t i = 0; i < op_role_vars.size(); i += 2) { + grad_names.insert(op_role_vars[i + 1]); + } + } else if (op_role & static_cast(framework::OpRole::kLRSched)) { + lr_ops.push_back(node); + } else { // NOLINT + PADDLE_THROW("Invalid op_role: %d", static_cast(op_role)); + } + } + + // 2. copy forward backward + ir::Node* prev_repeat_last_op_node = nullptr; + // record origin_grad -> repeated grad list map. 
+ std::map> grad_repeated_map; + std::map> created; + std::unordered_set bn_vars_need_rename; + for (int i = 0; i < num_repeats; ++i) { + std::unordered_set copied; + for (size_t node_idx = 0; node_idx < forward_backward_ops.size(); + ++node_idx) { + auto node = forward_backward_ops[node_idx]; + OpDesc repeated_op(*(node->Op()), node->Op()->Block()); + // 3. rename grad outputs to current repeat. + for (auto outname : repeated_op.OutputArgumentNames()) { + if (grad_names.find(outname) != grad_names.end()) { + std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i); + repeated_op.RenameOutput(outname, new_gname); + } + } + // 3.5 let batch_norm ops use independent vars, note batch_norm_grad do + // not need this update + if (node->Name() == "batch_norm") { + // NOTE: assume bn op created by layers use save var as output mean and + // variance + std::string new_mean_name = + string::Sprintf("%s.repeat.%d", repeated_op.Input("Mean")[0], i); + std::string new_var_name = string::Sprintf( + "%s.repeat.%d", repeated_op.Input("Variance")[0], i); + bn_vars_need_rename.insert(repeated_op.Input("Mean")[0]); + bn_vars_need_rename.insert(repeated_op.Input("Variance")[0]); + VLOG(3) << "renaming " << repeated_op.Input("Mean")[0] << " to " + << new_mean_name; + repeated_op.RenameInput(repeated_op.Input("Mean")[0], new_mean_name); + repeated_op.RenameInput(repeated_op.Input("Variance")[0], new_var_name); + repeated_op.RenameOutput(repeated_op.Output("MeanOut")[0], + new_mean_name); + repeated_op.RenameOutput(repeated_op.Output("VarianceOut")[0], + new_var_name); + } + + // 3.9 do copy + auto repeated_node = result.CreateOpNode(&repeated_op); + copied.insert(node); + + // 4. add deps between repeats + if (node_idx == forward_backward_ops.size() - 1) { + prev_repeat_last_op_node = repeated_node; + } + if (node_idx == 0 && prev_repeat_last_op_node) { + auto* depvar = result.CreateControlDepVar(); + prev_repeat_last_op_node->outputs.push_back(depvar); + depvar->inputs.push_back(prev_repeat_last_op_node); + repeated_node->inputs.push_back(depvar); + depvar->outputs.push_back(repeated_node); + } + + for (auto in_node : node->inputs) { + if (in_node->IsCtrlVar()) { + continue; + } + ir::Node* var = nullptr; + auto updated_var = UpdateGradVarDesc(in_node->Var(), i, grad_names, + bn_vars_need_rename); + // should be initialized by startup, how to initilize tensor in the + // scope? 
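Within each repeat, gradients (and batch-norm mean/variance, which are updated in place) must not collide across repeats, so the pass renames them with a `.repeat.<i>` suffix while all other variables stay shared. The scheme reduces to:

```python
def repeat_name(var_name, repeat):
    # e.g. "fc_0.w_0@GRAD" -> "fc_0.w_0@GRAD.repeat.2"
    return "%s.repeat.%d" % (var_name, repeat)

assert repeat_name("batch_norm_0.mean", 1) == "batch_norm_0.mean.repeat.1"
```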
+ if (node->Name() == "batch_norm" && + bn_vars_need_rename.find(in_node->Name()) != + bn_vars_need_rename.end()) { + // Create bn mean/variance for each repeat + var = result.CreateVarNode(&updated_var); + created[updated_var.Name()].push_back(var); + copied.insert(in_node); + repeated_node->inputs.push_back(var); + var->outputs.push_back(repeated_node); + continue; + } + + // for other ops + if (in_node->inputs.empty() && i > 0) { + // do not copy head vars (inputs, params) in repeats > 0 + var = created.at(in_node->Name()).back(); + } else { + if (copied.find(in_node) == copied.end()) { + var = result.CreateVarNode(&updated_var); + if (grad_names.find(in_node->Var()->Name()) != grad_names.end()) { + grad_repeated_map[in_node].push_back(var); + } + copied.insert(in_node); + created[updated_var.Name()].push_back(var); + } else { + var = created.at(updated_var.Name()).back(); + } + } + repeated_node->inputs.push_back(var); + var->outputs.push_back(repeated_node); + } + for (auto out_node : node->outputs) { + if (out_node->IsCtrlVar()) { + continue; + } + ir::Node* var = nullptr; + auto updated_var = UpdateGradVarDesc(out_node->Var(), i, grad_names, + bn_vars_need_rename); + if (copied.find(out_node) == copied.end()) { + var = result.CreateVarNode(&updated_var); + if (grad_names.find(out_node->Var()->Name()) != grad_names.end()) { + grad_repeated_map[out_node].push_back(var); + } + copied.insert(out_node); + created[updated_var.Name()].push_back(var); + } else { + var = created.at(updated_var.Name()).back(); + } + repeated_node->outputs.push_back(var); + var->inputs.push_back(repeated_node); + } + } + } + + // 5. create GRAD merge op node + for (auto kv : grad_repeated_map) { + OpDesc sum_op; + sum_op.SetType("sum"); + std::vector repeated_grad_names; + for (auto r : kv.second) { + repeated_grad_names.push_back(r->Var()->Name()); + } + sum_op.SetInput("X", repeated_grad_names); + sum_op.SetOutput("Out", {kv.first->Var()->Name()}); + sum_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kBackward)); + auto sum_op_node = result.CreateOpNode(&sum_op); + for (auto r : kv.second) { + sum_op_node->inputs.push_back(r); + r->outputs.push_back(sum_op_node); + } + auto sum_out_var_node = result.CreateVarNode(kv.first->Var()); + sum_op_node->outputs.push_back(sum_out_var_node); + sum_out_var_node->inputs.push_back(sum_op_node); + created[sum_out_var_node->Name()].push_back(sum_out_var_node); + + OpDesc scale_op; + scale_op.SetType("scale"); + scale_op.SetInput("X", {sum_out_var_node->Var()->Name()}); + // NOTE: inplace scale. + scale_op.SetOutput("Out", {sum_out_var_node->Var()->Name()}); + scale_op.SetAttr("scale", static_cast(1.0f / num_repeats)); + scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kBackward)); + auto scale_op_node = result.CreateOpNode(&scale_op); + scale_op_node->inputs.push_back(sum_out_var_node); + sum_out_var_node->outputs.push_back(scale_op_node); + auto scale_out_var_node = result.CreateVarNode(sum_out_var_node->Var()); + scale_op_node->outputs.push_back(scale_out_var_node); + scale_out_var_node->inputs.push_back(scale_op_node); + created[scale_out_var_node->Name()].push_back(scale_out_var_node); + } + // 6. add optimize ops + { + auto copy_node = [&result, &created](ir::Node* node) { + auto op_node = result.CreateOpNode(node->Op()); + // copy op ins/outs + // NOTE: for send/recv ops, the OpDesc uses ctrldepvar to describe + // dependencies, so create those depvars if OpDesc have in/outs. 
+ for (auto in_node : node->inputs) { + if (in_node->IsCtrlVar() && !in_node->Var()) { + continue; + } + ir::Node* var = nullptr; + if (created.find(in_node->Name()) == created.end()) { + var = result.CreateVarNode(in_node->Var()); + created[in_node->Name()].push_back(var); + } else { + var = created.at(in_node->Name()).back(); + } + op_node->inputs.push_back(var); + var->outputs.push_back(op_node); + } + for (auto out_node : node->outputs) { + if (out_node->IsCtrlVar() && !out_node->Var()) { + continue; + } + auto var = result.CreateVarNode(out_node->Var()); + created[out_node->Name()].push_back(var); + op_node->outputs.push_back(var); + var->inputs.push_back(op_node); + } + }; + for (auto node : lr_ops) { + copy_node(node); + } + for (auto node : optimize_ops) { + copy_node(node); + } + } + + result.ResolveHazard(created); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(multi_batch_merge_pass, paddle::framework::ir::BatchMergePass) + .RequirePassAttr(paddle::framework::ir::kNumRepeats); diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.h b/paddle/fluid/framework/ir/multi_batch_merge_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..c1e5aef20dbc60c18ed03038818bfd8ab217bf28 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +// BatchMergePass is used to copy forward and backward ops for several +// times to run several batches to simulate large batch size training +// as if we have more than 1 GPUs. +// User can define how many batches to run, gradients will be merged +// through those repeats, and then do optimization using merged gradients. +// This pass is extremely useful when doing large batch-size distributed +// sync training, we can simulate even large batch size as if we have more +// GPUs. + +class BatchMergePass : public Pass { + public: + virtual ~BatchMergePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 5d6da9f1d76a3c0fc64b7ff35264e385cf19a14b..d6d42f5e92080aa57445e2d6ce59aa3faa89d22d 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -44,6 +44,7 @@ class Node { return op_desc_.get(); } + // Please don't use this API! int id() const { return id_; } bool IsOp() const { return type_ == Type::kOperation; } @@ -92,6 +93,7 @@ class Node { Node() = delete; static int count_; + // Please don't use this API or make this public. 
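Since the pass is registered with `RequirePassAttr(kNumRepeats)`, a caller has to provide that integer before the pass runs; the `Pass::set_int` binding added later in this diff exists for exactly that. A hedged sketch of the Python wiring (the pass-builder accessor and `insert_pass` name are assumptions based on the `PassBuilder` comment in build_strategy.h, not APIs shown in this chunk):

```python
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# assumption: a PassBuilder handle obtained through the strategy
pass_builder = build_strategy._create_passes_from_strategy()
batch_merge = pass_builder.insert_pass(0, "multi_batch_merge_pass")
batch_merge.set_int("num_repeats", 4)  # repeat forward/backward 4 times per step
```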
static void ResetId() { count_ = 0; } DISABLE_COPY_AND_ASSIGN(Node); }; diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 440e0509be727ec2b84abc76fca44edda11f8a0a..30c8a26c3d2f0068674aa70b4ff875a2f73c1dca 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -121,10 +121,6 @@ class OpDesc { BlockDesc *Block() { return this->block_; } - const BlockDesc &BlockRef() const { return *this->block_; } - - void SetBlock(BlockDesc *block) { this->block_ = block; } - private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3368ae2ee4cf65b85abf1fcd89dee14f43522e1f..cffb96bedf7638ee52856f21662437085035eca6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -109,18 +109,9 @@ ParallelExecutor::ParallelExecutor( if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { BCastParamsToDevices(bcast_vars); } - // Startup Program has been run. All local scopes has correct parameters. +// Startup Program has been run. All local scopes has correct parameters. - // Step 2. Create vars in each scope; - std::vector var_infos; - for (auto *var : main_program.Block(0).AllVars()) { - var_infos.emplace_back(); - var_infos.back().name_ = var->Name(); - var_infos.back().type_ = var->GetType(); - var_infos.back().persistable_ = var->Persistable(); - } - -// Step 3. Convert main_program to SSA form and dependency graph. Also, insert +// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp #ifdef PADDLE_WITH_CUDA std::unique_ptr graph = build_strategy.Apply( @@ -156,6 +147,17 @@ ParallelExecutor::ParallelExecutor( params, member_->local_scopes_, member_->use_cuda_); #endif + // Step 3. Create vars in each scope. Passes may also create new vars. + // skip control vars and empty vars + std::vector var_infos; + for (auto &node : graph->Nodes()) { + if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { + var_infos.emplace_back(); + var_infos.back().name_ = node->Var()->Name(); + var_infos.back().type_ = node->Var()->GetType(); + var_infos.back().persistable_ = node->Var()->Persistable(); + } + } // If the loss_var_name is given, the number of graph should be only one. 
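The ParallelExecutor change above moves variable collection to after the passes run, so variables created by passes (for example the `.repeat.N` copies from the batch-merge pass) also get scope entries, while control-dep vars and desc-less var nodes are skipped. A Python-flavored restatement of the new collection loop (the node accessors are illustrative, not the real pybind API):

```python
def collect_var_infos(graph_nodes):
    """Collect (name, type, persistable) for real variables only."""
    infos = []
    for node in graph_nodes:
        if node.is_var() and not node.is_ctrl_var() and node.var() is not None:
            infos.append((node.var().name(), node.var().type(),
                          node.var().persistable()))
    return infos
```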
if (loss_var_name.size()) { PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 24f59cf43a9700ff1732e1ef6ad82e1a6294eede..e46dc1326951f68fd030f2208b9bea1647d0026d 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid, double latency, int epoch = 1) { LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat << ", threads: " << num_threads << ", thread id: " << tid - << ", latency: " << latency << "ms ======"; + << ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f) + << " ======"; if (epoch > 1) { int samples = batch_size * epoch; LOG(INFO) << "====== sample number: " << samples diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 5589b58b0618523b22f169f9ba9930c3ff3e3c48..19c3f532d5dcb7588793fa21fa179f6b48649103 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -139,6 +139,9 @@ void TestMultiThreadPrediction( } for (int tid = 0; tid < num_threads; ++tid) { threads.emplace_back([&, tid]() { +#ifdef PADDLE_WITH_MKLDNN + platform::set_cur_thread_id(static_cast(tid) + 1); +#endif // Each thread should have local inputs and outputs. // The inputs of each thread are all the same. std::vector> inputs_tid = inputs; diff --git a/paddle/fluid/operators/lars_momentum_op.cc b/paddle/fluid/operators/lars_momentum_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8dda93902448fa1bd21b719ffd9c9b500caf755 --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/lars_momentum_op.h" +#include "paddle/fluid/operators/momentum_op.h" + +namespace paddle { +namespace operators { + +class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", + "(LoDTensor, default LoDTensor) " + "Input parameter that has to be updated"); + AddInput("Grad", + "(LoDTensor, default LoDTensor) " + "Input gradient of the parameter"); + AddInput("Velocity", + "(LoDTensor, default LoDTensor) " + "Input velocity (corresponding to the parameter) " + "that has to be updated"); + AddInput("LearningRate", + "(LoDTensor, default LoDTensor) " + "Input learning rate"); + + AddOutput("ParamOut", + "(LoDTensor) This output is updated parameter. " + "It shared memory with Input(Param)."); + AddOutput("VelocityOut", + "(LoDTensor) This output is updated velocity. 
" + "It shared memory with Input(Velocity)."); + + AddAttr("mu", "(float) Momentum coefficient"); + AddAttr("lars_coeff", "(float, default 0.001) LARS coefficient.") + .SetDefault(0.001); + AddAttr("lars_weight_decay", + "(float, default 0.0005) LARS weight decay") + .SetDefault(0.0005); + + AddComment(R"DOC( +Lars Momentum Optimizer. + +This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each +weight using a local learning rate: + +$$ +local\_lr = \eta * + \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\ +velocity = mu * velocity + + local\_lr * (grad + \beta * param) \\ +param = param - velocity. \\ +$$ + +Note that we use lars_weight_decay here to decay weights, you may need not to +use L2 regularizers in case of using LARS. + +)DOC"); + } +}; + +class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::LarsMomentumOpVarTypeInference); +REGISTER_OP_CPU_KERNEL(lars_momentum, ops::LarsMomentumOpKernel, + ops::LarsMomentumOpKernel); diff --git a/paddle/fluid/operators/lars_momentum_op.cu b/paddle/fluid/operators/lars_momentum_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..eb346851a2f690fa05422c84ddcb08307539048f --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.cu @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/lars_momentum_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, + const int64_t num, const T lars_coeff, + const T lars_weight_decay, const T* p_norm, + const T* g_norm, T* p_out, T* v_out) { + T lr = learning_rate[0]; + T local_lr = learning_rate[0]; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; + i += blockDim.x * gridDim.x) { + if (p_norm[0] > 0 && g_norm[0] > 0) { + local_lr = lr * lars_coeff * p_norm[0] / + (g_norm[0] + lars_weight_decay * p_norm[0]); + } + T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]); + v_out[i] = v_new; + p_out[i] = p[i] - v_new; + } +} + +template +class LarsMomentumOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto velocity_out = ctx.Output("VelocityOut"); + auto param = ctx.Input("Param"); + auto velocity = ctx.Input("Velocity"); + auto grad = ctx.Input("Grad"); + auto learning_rate = ctx.Input("LearningRate"); + + T* p_out = param_out->mutable_data(ctx.GetPlace()); + T* v_out = velocity_out->mutable_data(ctx.GetPlace()); + + T mu = static_cast(ctx.Attr("mu")); + T lars_coeff = ctx.Attr("lars_coeff"); + T lars_weight_decay = ctx.Attr("lars_weight_decay"); + + auto* p = param->data(); + auto* v = velocity->data(); + auto* g = grad->data(); + auto* lr = learning_rate->data(); + + int block = 512; + int grid = (param->numel() + block - 1) / block; + + auto eigen_p = framework::EigenVector::Flatten(*param); + auto eigen_g = framework::EigenVector::Flatten(*grad); + // calculate norms using eigein and launch the kernel. + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + auto* p_norm_data = p_norm_t.mutable_data(ctx.GetPlace()); + auto* g_norm_data = g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + + auto* place = ctx.template device_context().eigen_device(); + ep_norm.device(*place) = eigen_p.square().sum().sqrt(); + eg_norm.device(*place) = eigen_g.square().sum().sqrt(); + MomentumLarsKernel<<>>( + p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay, + p_norm_data, g_norm_data, p_out, v_out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + lars_momentum, + ops::LarsMomentumOpCUDAKernel, + ops::LarsMomentumOpCUDAKernel); diff --git a/paddle/fluid/operators/lars_momentum_op.h b/paddle/fluid/operators/lars_momentum_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e85be99fc42522e461a7915847d82144d8195a96 --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LarsMomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto velocity_out = ctx.Output("VelocityOut"); + auto param = ctx.Input("Param"); + auto velocity = ctx.Input("Velocity"); + auto learning_rate = ctx.Input("LearningRate"); + auto* grad_var = ctx.InputVar("Grad"); + // only support dense for now. + PADDLE_ENFORCE(grad_var->IsType()); + auto grad = ctx.Input("Grad"); + + param_out->mutable_data(ctx.GetPlace()); + velocity_out->mutable_data(ctx.GetPlace()); + + T mu = static_cast(ctx.Attr("mu")); + T lars_coeff = ctx.Attr("lars_coeff"); + T lars_weight_decay = ctx.Attr("lars_weight_decay"); + + auto p_out = framework::EigenVector::Flatten(*param_out); + auto v_out = framework::EigenVector::Flatten(*velocity_out); + + auto p = framework::EigenVector::Flatten(*param); + auto v = framework::EigenVector::Flatten(*velocity); + auto g = framework::EigenVector::Flatten(*grad); + auto* lr = learning_rate->data(); + + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + p_norm_t.mutable_data(ctx.GetPlace()); + g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + + ep_norm = p.square().sum().sqrt(); + eg_norm = g.square().sum().sqrt(); + T local_lr = lr[0]; + if (ep_norm(0) > 0 && eg_norm(0) > 0) { + local_lr = lr[0] * lars_coeff * ep_norm(0) / + (eg_norm(0) + lars_weight_decay * ep_norm(0)); + } + v_out = v * mu + local_lr * (g + lars_weight_decay * p); + p_out = p - v_out; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index 12b916fcebd425bd4a03d920f947829098a924a1..7f0b51580aa2591ac7338ad7c29ee4756d909925 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -19,54 +19,6 @@ namespace operators { using Tensor = framework::Tensor; -class MomentumOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(param) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(grad) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Velocity"), - "Input(velocity) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of Momentum should not be null."); - PADDLE_ENFORCE( - ctx->GetInputsVarType("Param").front() == - framework::proto::VarType::LOD_TENSOR, - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); - - PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), - "Output(ParamOut) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"), - "Output(VelocityOut) of Momentum should not be null."); - - auto param_dim = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - 
param_dim, ctx->GetInputDim("Grad"), - "Param and Grad input of MomentumOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Velocity"), - "Param and Velocity of MomentumOp should have the same dimension."); - } - PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, - "Learning_rate should be a scalar"); - - ctx->SetOutputDim("ParamOut", param_dim); - ctx->SetOutputDim("VelocityOut", param_dim); - } - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); - } -}; - class MomentumOpInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDesc& op_desc, diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 6b4d00f56ca06c402c07ecf770a390e88ae3edf1..71f079e4d97f5259359ee6572f584894551452ca 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -28,6 +28,54 @@ using framework::SelectedRows; struct NoNesterov; struct UseNesterov; +class MomentumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(param) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(grad) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Velocity"), + "Input(velocity) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of Momentum should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"), + "Output(VelocityOut) of Momentum should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad input of MomentumOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Velocity"), + "Param and Velocity of MomentumOp should have the same dimension."); + } + PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, + "Learning_rate should be a scalar"); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("VelocityOut", param_dim); + } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + template class CPUDenseMomentumFunctor { private: diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7d1cf57253819b34fedfb292ad1635650f53f20f..b0de636de46451c8b05546fdbff142f984c2bb43 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return 
place_; } #ifdef PADDLE_WITH_MKLDNN MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) - : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() { - p_blobs_.reset(new std::unordered_map>()); + : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() { + p_blobmap_.reset(new BlobMap()); + p_mutex_.reset(new std::mutex()); } +namespace { +// Current thread's id. +thread_local int cur_thread_id = 0; +} + +void set_cur_thread_id(int tid) { cur_thread_id = tid; } +int get_cur_thread_id(void) { return cur_thread_id; } + void MKLDNNDeviceContext::SetBlob(const std::string& name, std::shared_ptr data) const { - std::unordered_map>* p; - p = p_blobs_.get(); + BlobMap* pMap = p_blobmap_.get(); + std::shared_ptr pBlob = nullptr; + + int tid = platform::get_cur_thread_id(); - auto it = p->find(name); + std::lock_guard lock(*p_mutex_.get()); - if (it == p->end()) { - (*p)[name] = data; // create new blob + // Find KeyBlob for current thread + auto map_it = pMap->find(tid); + + if (map_it == pMap->end()) { + // 1st time to set blob in current thread + pBlob = std::shared_ptr(new KeyBlob()); + (*pMap)[tid] = pBlob; } else { - it->second = data; // set data to existing blob + pBlob = map_it->second; } + // Find Key in found (or newly created) KeyBlob + auto key_it = pBlob->find(name); + + if (key_it == pBlob->end()) { + (*pBlob)[name] = data; // create new blob + } else { + key_it->second = data; // set data to existing blob + } + + // lock will be automatically released when out of scope return; } std::shared_ptr MKLDNNDeviceContext::GetBlob( const std::string& name) const { - std::unordered_map>* p; - p = p_blobs_.get(); + BlobMap* pMap = p_blobmap_.get(); + std::shared_ptr pBlob = nullptr; - auto it = p->find(name); + int tid = platform::get_cur_thread_id(); - if (it != p->end()) { - return it->second; - } + std::lock_guard lock(*p_mutex_.get()); + + // Find KeyBlob for current thread firstly + auto map_it = pMap->find(tid); + if (map_it == pMap->end()) return nullptr; + pBlob = map_it->second; + + // Find Blob via name + auto key_it = pBlob->find(name); + + if (key_it == pBlob->end()) return nullptr; - return nullptr; + // lock will be automatically released when out of scope + return key_it->second; } #endif diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 999bbe00f1659881050cb0dc89570b74b201aca7..942e13a724339dc85ed1fc72c11e208ddce36dbb 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -176,6 +176,12 @@ struct DefaultDeviceContextType { #endif #ifdef PADDLE_WITH_MKLDNN +using KeyBlob = std::unordered_map>; +using BlobMap = std::unordered_map>; + +void set_cur_thread_id(int); +int get_cur_thread_id(void); + class MKLDNNDeviceContext : public CPUDeviceContext { public: explicit MKLDNNDeviceContext(CPUPlace place); @@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext { private: mkldnn::engine engine_; - std::shared_ptr>> - p_blobs_; + std::shared_ptr p_blobmap_; + std::shared_ptr p_mutex_; }; #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 339a7c98c6a2bba2cd46790cecc169ef447c63ce..5f15a29f4c3e9b1412912fe4723642d1ede60346 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -645,9 +645,13 @@ All parameter, weight, gradient are variables in Paddle. 
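The MKLDNN context now keys its blob cache by thread id first and blob name second, with a mutex around both lookups; worker threads opt in by calling `set_cur_thread_id`, as the `tester_helper.h` change above does. A small Python model of the same two-level, lock-guarded map (illustrative only):

```python
import threading

_lock = threading.Lock()
_blob_map = {}          # tid -> {name: blob}, mirrors BlobMap -> KeyBlob
_tls = threading.local()

def set_cur_thread_id(tid):
    _tls.tid = tid

def set_blob(name, data):
    tid = getattr(_tls, "tid", 0)
    with _lock:
        _blob_map.setdefault(tid, {})[name] = data

def get_blob(name):
    tid = getattr(_tls, "tid", 0)
    with _lock:
        return _blob_map.get(tid, {}).get(name)
```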
py::class_> pass(m, "Pass"); pass.def(py::init()) - .def("set_str", [](ir::Pass &self, const std::string &name, - const std::string &attr) { - self.Set(name, new std::string(attr)); + .def( + "set_str", + [](ir::Pass &self, const std::string &name, const std::string &attr) { + self.Set(name, new std::string(attr)); + }) + .def("set_int", [](ir::Pass &self, const std::string &name, int val) { + self.Set(name, new int(val)); }); py::class_> pb( diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index dfd801a098d6451dbdb20d9ba44187d1e3f8a91a..149224bb68ac869dec14ac9f953f0072bd24c7e2 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -27,7 +27,7 @@ from . import nn from . import ops from . import tensor from ..initializer import init_on_cpu -from ..framework import default_main_program, Parameter, unique_name +from ..framework import default_main_program, Parameter, unique_name, name_scope __all__ = [ 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', @@ -332,14 +332,16 @@ def append_LARS(params_grads, learning_rate, weight_decay): return grad_norm + weight_decay * param_norm for param, grad in params_grads: - param_lr = param.optimize_attr['learning_rate'] - param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param))) - grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad))) - if type(param_lr) == float and param_lr == 1.0: - decayed_lr = learning_rate * param_norm \ - / _balanced_weight(param_norm, grad_norm) - else: - decayed_lr = learning_rate * param_lr * param_norm \ - / _balanced_weight(param_norm, grad_norm) - # set back param local learning rate - param.optimize_attr['learning_rate'] = decayed_lr + with param.block.program.optimized_guard( + [param, grad]), name_scope("optimizer"): + param_lr = param.optimize_attr['learning_rate'] + param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param))) + grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad))) + if type(param_lr) == float and param_lr == 1.0: + decayed_lr = learning_rate * param_norm \ + / _balanced_weight(param_norm, grad_norm) + else: + decayed_lr = learning_rate * param_lr * param_norm \ + / _balanced_weight(param_norm, grad_norm) + # set back param local learning rate + param.optimize_attr['learning_rate'] = decayed_lr diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 25d43be3b7520b53cb2ff2af8583cea2330acfbf..a4503e75671d7d12ff84bb538776f8e6c832b9d1 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -478,7 +478,7 @@ class EditDistance(MetricBase): "There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance." ) avg_distance = self.total_distance / self.seq_num - avg_instance_error = self.instance_error / self.seq_num + avg_instance_error = self.instance_error / float(self.seq_num) return avg_distance, avg_instance_error diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 6ea280c733996436a6218cc6f759f2cf7c652ac9..7e2364a5a872cdd8cf590438cc081ab070db767d 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -14,6 +14,7 @@ from __future__ import print_function import re +import sys from collections import defaultdict from paddle.fluid.framework import Program, Variable, name_scope, default_main_program from . 
import framework @@ -32,7 +33,8 @@ __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', - 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer' + 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', + 'LarsMomentumOptimizer' ] @@ -105,7 +107,6 @@ class Optimizer(object): param = param_and_grad[0] param_lr = param.optimize_attr['learning_rate'] if type(param_lr) == Variable: - print("returns updated param lr ", param_lr) return param_lr else: if param_lr == 1.0: @@ -400,6 +401,91 @@ class MomentumOptimizer(Optimizer): return momentum_op +class LarsMomentumOptimizer(Optimizer): + """ + Momentum optimizer with LARS support + + The update equations are as follows: + + .. math:: + + & local\_learning\_rate = learning\_rate * lars\_coeff * \\ + \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||} + + & velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param) + + & param = param - velocity + + Args: + learning_rate (float|Variable): the learning rate used to update parameters. \ + Can be a float value or a Variable with one float value as data element. + momentum (float): momentum factor + lars_coeff (float): defines how much we trust the layer to change its weights. + lars_weight_decay (float): weight decay coefficient for decaying using LARS. + regularization: A Regularizer, such as + fluid.regularizer.L2DecayRegularizer. + name: A optional name prefix. + + + Examples: + .. code-block:: python + + optimizer = fluid.optimizer.LarsMomentum(learning_rate=0.2, momentum=0.1, lars_weight_decay=0.001) + optimizer.minimize(cost) + """ + _velocity_acc_str = "velocity" + + def __init__(self, + learning_rate, + momentum, + lars_coeff=0.001, + lars_weight_decay=0.0005, + regularization=None, + name=None): + assert learning_rate is not None + assert momentum is not None + super(LarsMomentumOptimizer, self).__init__( + learning_rate=learning_rate, + regularization=regularization, + name=name) + self.type = "lars_momentum" + self._momentum = momentum + self._lars_coeff = float(lars_coeff) + self._lars_weight_decay = float(lars_weight_decay) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._velocity_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + velocity_acc = self._get_accumulator(self._velocity_acc_str, + param_and_grad[0]) + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Velocity": velocity_acc, + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={ + "ParamOut": param_and_grad[0], + "VelocityOut": velocity_acc + }, + attrs={ + "mu": self._momentum, + "lars_coeff": self._lars_coeff, + "lars_weight_decay": self._lars_weight_decay + }) + + return momentum_op + + class AdagradOptimizer(Optimizer): """ **Adaptive Gradient Algorithm (Adagrad)** @@ -1221,6 +1307,7 @@ DecayedAdagrad = DecayedAdagradOptimizer Adadelta = AdadeltaOptimizer RMSProp = RMSPropOptimizer Ftrl = FtrlOptimizer +LarsMomentum = LarsMomentumOptimizer class ModelAverage(Optimizer): diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 
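To make the update rule in the new `LarsMomentumOptimizer` docstring concrete, here is a NumPy sketch of one dense LARS step that implements the equations as written; it mirrors the math only, not the C++ `lars_momentum` kernel, and `lars_momentum_step` is a name made up for this example:

```python
import numpy as np

def lars_momentum_step(param, grad, velocity, lr=0.001, mu=0.9,
                       lars_coeff=0.001, lars_weight_decay=0.0005):
    """One LARS momentum update, following the docstring equations."""
    param_norm = np.sqrt(np.square(param).sum())
    grad_norm = np.sqrt(np.square(grad).sum())
    # Layer-wise adaptive rate: the global lr is rescaled per layer by the
    # ratio of the parameter norm to the (decay-augmented) gradient norm.
    local_lr = lr * lars_coeff * param_norm / (
        grad_norm + lars_weight_decay * param_norm)
    velocity = mu * velocity + local_lr * (grad + lars_weight_decay * param)
    param = param - velocity
    return param, velocity

np.random.seed(0)
w = np.random.random((4, 3)).astype("float32")
g = np.random.random((4, 3)).astype("float32")
v = np.zeros_like(w)
w, v = lars_momentum_step(w, g, v)   # one optimizer step
```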
877d21ae882ab4efb49beb6a846ab71a22c2aab7..01e9795d8b1beb67270f45fe7ba2819bf8c3be3e 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -95,7 +95,7 @@ class TestDistMnist2x2(TestDistRunnerBase): # Reader train_reader = paddle.batch( - paddle.dataset.mnist.train(), batch_size=batch_size) + paddle.dataset.mnist.test(), batch_size=batch_size) test_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size) opt.minimize(avg_cost) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..d386e75fd887a898f5a13e48e378e08ff6c99ea0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + +DTYPE = "float32" + + +def test_merge_reader(repeat_batch_size=8): + orig_reader = paddle.dataset.mnist.test() + record_batch = [] + b = 0 + for d in orig_reader(): + if b >= repeat_batch_size: + break + record_batch.append(d) + b += 1 + while True: + for d in record_batch: + yield d + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) + + # Reader + train_reader = paddle.batch(test_merge_reader, batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_lars.py b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py new file mode 100644 index 0000000000000000000000000000000000000000..977e17c37f7676ae81d9ab29b6b36089ccbeeacf --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py @@ -0,0 +1,73 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + +DTYPE = "float32" +paddle.dataset.mnist.fetch() + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.LarsMomentumOptimizer( + learning_rate=0.001, momentum=0.9) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 04924bec057e301bfb342a62bb4c1e0b3c3aff4c..87fd03ca61d33a53b9323edb2ec7e1c71655816b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -26,10 +26,11 @@ import argparse import paddle.fluid as fluid RUN_STEP = 10 +DEFAULT_BATCH_SIZE = 2 class TestDistRunnerBase(object): - def get_model(self, batch_size=2): + def get_model(self, batch_size=DEFAULT_BATCH_SIZE): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -48,8 +49,7 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): - - self.get_model(batch_size=2) + self.get_model(batch_size=args.batch_size) # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, fluid.default_main_program(), args.endpoints, @@ -65,7 +65,7 @@ class TestDistRunnerBase(object): def run_trainer(self, args): test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ - self.get_model(batch_size=2) + 
self.get_model(batch_size=args.batch_size) if args.mem_opt: fluid.memory_optimize(fluid.default_main_program(), skip_grads=True) @@ -92,6 +92,11 @@ class TestDistRunnerBase(object): strategy.allow_op_delay = False build_stra = fluid.BuildStrategy() + if args.batch_merge_repeat > 1: + pass_builder = build_stra._create_passes_from_strategy() + mypass = pass_builder.insert_pass( + len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") + mypass.set_int("num_repeats", args.batch_merge_repeat) if args.use_reduce: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce @@ -145,6 +150,9 @@ def runtime_main(test_class): parser.add_argument('--use_reduce', action='store_true') parser.add_argument( '--use_reader_alloc', action='store_true', required=False, default=True) + parser.add_argument('--batch_size', required=False, type=int, default=2) + parser.add_argument( + '--batch_merge_repeat', required=False, type=int, default=1) args = parser.parse_args() @@ -244,9 +252,18 @@ class TestDistBase(unittest.TestCase): (e, retry_times)) retry_times -= 1 - def _run_local(self, model, envs, check_error_log): + def _run_local(self, + model, + envs, + check_error_log=False, + batch_size=DEFAULT_BATCH_SIZE, + batch_merge_repeat=1): cmd = "%s %s --role trainer" % (self._python_interp, model) + if batch_size != DEFAULT_BATCH_SIZE: + cmd += " --batch_size %d" % batch_size + if batch_merge_repeat > 1: + cmd += " --batch_merge_repeat %d" % batch_merge_repeat if self.__use_cuda: cmd += " --use_cuda" diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index f65dd7e2a28c4ace3988c0cc1267ebe981fbd9cb..922dd838f8996adfc15afffcd44c1acca2bc14a9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -26,6 +26,15 @@ class TestDistMnist2x2(TestDistBase): self.check_with_place("dist_mnist.py", delta=1e-5) +class TestDistMnist2x2Lars(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + + def test_se_resnext(self): + self.check_with_place("dist_mnist_lars.py", delta=1e-5) + + class TestDistMnist2x2WithMemopt(TestDistBase): def _setup_config(self): self._sync_mode = True diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..22d4b7929033529c5cea60064e6d9de57eddeb8e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py @@ -0,0 +1,67 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
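The `batch_merge_repeat` plumbing in `test_dist_base.py` above is what exercises the new `Pass.set_int` binding: the trainer materializes the default pass pipeline, splices `multi_batch_merge_pass` in near the end, and hands it the repeat count. A condensed sketch of that flow, reusing only the calls that appear in this diff (`_create_passes_from_strategy`, `insert_pass`, `set_int`):

```python
import paddle.fluid as fluid

build_stra = fluid.BuildStrategy()
batch_merge_repeat = 2   # e.g. batch_size=2 repeated twice ~ one batch of 4

if batch_merge_repeat > 1:
    # Materialize the pass pipeline implied by the build strategy ...
    pass_builder = build_stra._create_passes_from_strategy()
    # ... and insert the batch-merge pass just before the final two passes,
    # mirroring test_dist_base.py.
    mypass = pass_builder.insert_pass(
        len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
    # The new set_int binding carries the repeat count into the C++ pass.
    mypass.set_int("num_repeats", batch_merge_repeat)
```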
+ +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase +import os + + +class TestDistMnist2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + + def test_dist_train(self): + self.check_with_place("dist_mnist_batch_merge.py", delta=1e-5) + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + # TODO(typhoonzero): should auto adapt GPU count on the machine. + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_cudnn_deterministic": "1", + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "7" + required_envs["GLOG_logtostderr"] = "1" + + no_merge_losses = self._run_local( + model_file, + required_envs, + check_error_log=check_error_log, + batch_size=4) + + batch_merge_losses = self._run_local( + model_file, + required_envs, + check_error_log=check_error_log, + batch_size=2, + batch_merge_repeat=2) + # Ensure both result have values. + self.assertGreater(len(no_merge_losses), 1) + self.assertEqual(len(no_merge_losses), len(batch_merge_losses)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index a3d89610b40ff9bd5002e843f8667ada87e67981..cf4346cf2e7a099334ec273546901a91d0ad925d 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -90,6 +90,45 @@ class TestMomentumOp2(OpTest): self.check_output() +class TestLarsMomentumOp(OpTest): + def setUp(self): + self.op_type = "lars_momentum" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + mu = 0.0001 + lars_coeff = 0.001 + lars_weight_decay = 0.0005 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Velocity': velocity, + 'LearningRate': learning_rate + } + + self.attrs = { + 'mu': mu, + 'lars_coeff': lars_coeff, + 'lars_weight_decay': lars_weight_decay + } + + pnorm = np.sqrt(np.square(param).sum()) + gnorm = np.sqrt(np.square(grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * param) + velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * + param) + param_out = param - velocity_out + + self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + + def test_check_output(self): + self.check_output() + + class TestSparseMomentumOp(unittest.TestCase): def setUp(self): self.use_nesterov = False diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 28d7df8e45d4f3209fb6205b76f8d1be8e73392b..28ad8443673492b31a6228bc85939549749541e9 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1431,7 +1431,7 @@ to transpile() call.") elif op_type == "adamax": if varkey in ["Moment", "InfNorm"]: return param_shape - elif op_type == "momentum": + elif op_type in ["momentum", "lars_momentum"]: if varkey == "Velocity": return param_shape elif op_type == "rmsprop": @@ -1442,6 +1442,10 @@ to transpile() call.") return param_shape elif 
op_type == "sgd": pass + else: + raise ValueError( + "Not supported optimizer for distributed training: %s" % + op_type) return orig_shape def _get_varname_parts(self, varname):