提交 e330cd03 编写于 作者: C chengduoZH

balance parameter update

上级 d9061062
......@@ -11,11 +11,15 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <algorithm>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
......@@ -26,9 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
#include <string>
#include <vector>
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot");
......@@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types;
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
var_types[var->Name()] = var->GetType();
all_vars[var->Name()] = var;
}
auto graph = new SSAGraph();
......@@ -167,12 +168,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);
size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size());
size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);
auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>(
......@@ -220,13 +237,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name);
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name);
bcast_var_name_set[cur_device_id].emplace(p_name);
cur_device_id = (cur_device_id + 1) % places_.size();
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(var_types, g_name)) {
if (IsSparseGradient(all_vars, g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
......@@ -269,10 +286,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const {
PADDLE_ENFORCE(var_types.count(og) != 0);
if (var_types.at(og) == proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE(all_vars.count(og) != 0);
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
}
return false;
......
......@@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
size_t src_dev_id) const;
bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const;
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册