提交 dc3d2dc8 编写于 作者: Q qiaolongfei

rename grad_map to grad_to_id

上级 260bf5ac
...@@ -89,7 +89,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -89,7 +89,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
std::unordered_map<std::string, int32_t> grad_to_id; std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_map_str = Attr<std::vector<std::string>>("grad_map"); auto grad_map_str = Attr<std::vector<std::string>>("grad_to_id");
for (auto &grad_and_id : grad_map_str) { for (auto &grad_and_id : grad_map_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(grad_and_id, ' ', &pieces); split(grad_and_id, ' ', &pieces);
...@@ -193,7 +193,7 @@ from send_op and send back variables to recv_op. ...@@ -193,7 +193,7 @@ from send_op and send back variables to recv_op.
.SetDefault("127.0.0.1:6164") .SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])", "grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
......
...@@ -207,8 +207,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -207,8 +207,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
std::unordered_map<std::string, int32_t> grad_to_id; std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_map_str = Attr<std::vector<std::string>>("grad_map"); auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
for (auto &grad_and_id : grad_map_str) { for (auto &grad_and_id : grad_to_id_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(grad_and_id, ' ', &pieces); split(grad_and_id, ' ', &pieces);
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(pieces.size(), 2);
...@@ -227,7 +227,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -227,7 +227,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
block_list.push_back(blkid); block_list.push_back(blkid);
} }
} }
PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(), PADDLE_ENFORCE_EQ(grad_to_id_str.size(), block_list.size(),
"grad num should be equal to optimize block num"); "grad num should be equal to optimize block num");
auto optimize_prepared = executor->Prepare(*program, block_list); auto optimize_prepared = executor->Prepare(*program, block_list);
...@@ -328,7 +328,7 @@ from send_op and send back variables to recv_op. ...@@ -328,7 +328,7 @@ from send_op and send back variables to recv_op.
.SetDefault("127.0.0.1:6164") .SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])", "grad_to_id(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not") AddAttr<bool>("sync_mode", "if works at sync_mode or not")
......
...@@ -137,7 +137,7 @@ void StartServerNet(bool is_sparse) { ...@@ -137,7 +137,7 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"OptimizeBlock", optimize_block});
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_map", {}}); attrs.insert({"grad_to_id", {}});
attrs.insert({"sync_mode", true}); attrs.insert({"sync_mode", true});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册