提交 e330cd03 编写于 作者: C chengduoZH

balance parameter update

上级 d9061062
...@@ -11,11 +11,15 @@ ...@@ -11,11 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include <algorithm>
#include <fstream> #include <fstream>
#include <string>
#include <utility> #include <utility>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
...@@ -26,9 +30,6 @@ ...@@ -26,9 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif #endif
#include <string>
#include <vector>
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot", DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10," "the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"); "default /tmp/graph.dot");
...@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types; std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) { for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType(); all_vars[var->Name()] = var;
} }
auto graph = new SSAGraph(); auto graph = new SSAGraph();
...@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program); auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program); auto recv_vars = FindDistTrainRecvVars(program);
size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices; std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set; std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size()); var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size()); bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>( if (boost::get<int>(
...@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name);
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name); var_name_on_devices[cur_device_id].emplace(g_name);
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
cur_device_id = (cur_device_id + 1) % places_.size();
break; break;
case BuildStrategy::ReduceStrategy::kAllReduce: case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(var_types, g_name)) { if (IsSparseGradient(all_vars, g_name)) {
CreateReduceOp(&result, g_name, 0); CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0);
} else { } else {
...@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
bool MultiDevSSAGraphBuilder::IsSparseGradient( bool MultiDevSSAGraphBuilder::IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const { const std::string &og) const {
PADDLE_ENFORCE(var_types.count(og) != 0); PADDLE_ENFORCE(all_vars.count(og) != 0);
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) { if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true; return true;
} }
return false; return false;
......
...@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient( bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const; const std::string &og) const;
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册