Unverified commit b8da70c3, authored by Wu Yi, committed by GitHub

Resolve multi-GPU async deps (#12828)

* dist transpiler: add a control-dependency var between send and recv

* fix async deps

* follow comments and refine

* fix dep connections for RPC ops
Parent: 4a4567fc
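Why a control-dependency variable is needed: in async mode nothing in the graph forces a recv to wait for the matching send, since recv only writes the parameter and send only reads the gradient. This commit makes each send emit a dummy control variable that the corresponding recv consumes. A minimal sketch of the idea (plain Python; Op and depends_on are illustrative stand-ins, not framework APIs):

import random

CONTROL_DEP_VAR_PREFIX = "__control_var"  # mirrors ir::Node::kControlDepVarName


def generate_control_dep_var_name():
    # Same scheme as the transpiler: prefix + "@" + a random suffix.
    return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())


class Op:
    """Toy stand-in for a framework op: named inputs and outputs only."""

    def __init__(self, name, inputs, outputs):
        self.name, self.inputs, self.outputs = name, set(inputs), set(outputs)


def depends_on(later, earlier):
    # later must run after earlier if it consumes anything earlier produces.
    return bool(later.inputs & earlier.outputs)


# Without the control var, recv has no edge back to send, so a scheduler
# may fetch the parameter before the gradient has been pushed.
send = Op("send", inputs={"w@GRAD"}, outputs=set())
recv = Op("recv", inputs=set(), outputs={"w"})
assert not depends_on(recv, send)

# With it, send produces the dummy var and recv consumes it: an explicit
# send -> recv edge that carries no real data.
ctrl = generate_control_dep_var_name()
send.outputs.add(ctrl)
recv.inputs.add(ctrl)
assert depends_on(recv, send)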
@@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 // Create RPC related op handles that connects its in ops and out ops.
 void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                                           ir::Node *node) const {
+  // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode
+  // put them into transpiler.
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
     // split_byref op
-    // so that we can balance the variable blocks to all the pserver
-    // instances.
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
         node->inputs[0]->Name().find(".block") == std::string::npos) {
       std::vector<std::string> input_var_names;
       for (ir::Node *n : node->inputs) {
         input_var_names.push_back(n->Name());
       }
-      op_dev_id = GetAppropriateDeviceID(input_var_names);
+      auto send_param_grad = boost::get<std::vector<std::string>>(
+          node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
+      op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
+      VLOG(10) << "send grad " << input_var_names[0] << " origin "
+               << send_param_grad[1] << " place: " << op_dev_id;
       for (auto &varname : input_var_names) {
         result->Get<ShardedVarDevice>(kShardedVarDevice)
             .emplace(varname, op_dev_id);
       }
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
+          .emplace(send_param_grad[1], op_dev_id);
     }
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
     for (ir::Node *n : node->outputs) {
       output_var_names.push_back(n->Name());
     }
-    op_dev_id = GetAppropriateDeviceID(output_var_names);
+    auto recv_param_grad = boost::get<std::vector<std::string>>(
+        node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+    // FIXME(typhoonzero): assume each recv op output one param
+    // Use the same place as send.
+    if (recv_param_grad.size() == 2U) {
+      op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
+      VLOG(10) << "recv param " << recv_param_grad[0]
+               << " get grad place: " << recv_param_grad[1]
+               << " place: " << op_dev_id;
+    } else {
+      op_dev_id = GetAppropriateDeviceID(output_var_names);
+    }
     for (auto &varname : output_var_names) {
       result->Get<ShardedVarDevice>(kShardedVarDevice)
           .emplace(varname, op_dev_id);
...
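In the hunk above, a send op is now placed by the device chosen for its origin gradient (the second entry of the op's OpRoleVar attribute), and the matching recv looks that device back up via GetVarDeviceID, so the received parameter lands where its gradient was sent from. A rough Python model of the rule; balance_pick is a hypothetical stand-in for GetAppropriateDeviceID (the real heuristic balances by accumulated variable size, not a hash):

NUM_DEVICES = 4
sharded_var_device = {}  # mirrors the ShardedVarDevice map in the graph


def balance_pick(var_names):
    # Assumption for illustration only; not the framework's heuristic.
    return hash(tuple(sorted(var_names))) % NUM_DEVICES


def place_send(op_role_var, input_var_names):
    param, grad = op_role_var  # OP_ROLE_VAR_ATTR_NAME carries [param, grad]
    dev = balance_pick([grad])  # place by the gradient's origin name
    for name in input_var_names:
        sharded_var_device[name] = dev
    sharded_var_device[grad] = dev  # also record the origin grad's device
    return dev


def place_recv(op_role_var, output_var_names):
    if len(op_role_var) == 2:
        # Reuse the device recorded when the matching send was placed.
        dev = sharded_var_device[op_role_var[1]]
    else:
        dev = balance_pick(output_var_names)
    for name in output_var_names:
        sharded_var_device[name] = dev
    return dev


send_dev = place_send(["w", "w@GRAD"], ["w@GRAD.block0", "w@GRAD.block1"])
recv_dev = place_recv(["w", "w@GRAD"], ["w"])
assert send_dev == recv_dev  # recv lands on the same GPU as its send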
@@ -17,7 +17,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
-const char Node::kControlDepVarName[] = "__control_var";
+constexpr char Node::kControlDepVarName[];
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -27,7 +27,7 @@ namespace ir {
 class Node {
  public:
   enum class Type { kOperation, kVariable };
-  static const char kControlDepVarName[];
+  static constexpr char kControlDepVarName[] = "__control_var";

   explicit Node(const std::string& name, Type type)
       : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/pybind/const_value.h"
-#include <paddle/fluid/framework/op_proto_maker.h>
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"

 namespace paddle {
@@ -24,6 +25,8 @@ void BindConstValue(pybind11::module* m) {
   m->def("kTempVarName", [] { return framework::kTempVarName; });
   m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
   m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
+  m->def("kControlDepVarName",
+         [] { return framework::ir::Node::kControlDepVarName; });

   auto op_proto_and_checker_maker =
       m->def_submodule("op_proto_and_checker_maker");
...
@@ -50,6 +50,12 @@ EMPTY_VAR_NAME = core.kEmptyVarName()
 TEMP_VAR_NAME = core.kTempVarName()
 GRAD_VAR_SUFFIX = core.kGradVarSuffix()
 ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
+CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
+
+
+def generate_control_dev_var_name():
+    import random
+    return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())


 def grad_var_name(var_name):
...
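With kControlDepVarName exported through pybind, Python can mint fresh control-variable names on demand. A quick usage sketch (assumes a Paddle build that contains this commit; the random suffix keeps names unique per send/recv pair):

from paddle.fluid import framework

# The prefix round-trips from C++ ir::Node::kControlDepVarName.
assert framework.CONTROL_DEP_VAR_PREFIX == "__control_var"

# Each call yields a name such as "__control_var@0.5488135...", so every
# send/recv pair can own a distinct dummy dependency variable.
name = framework.generate_control_dev_var_name()
assert name.startswith(framework.CONTROL_DEP_VAR_PREFIX + "@")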
@@ -212,8 +212,10 @@ class DistributeTranspiler(object):
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
         self.param_name_to_grad_name = dict()
+        self.grad_name_to_param_name = dict()
         for param_var, grad_var in self.params_grads:
             self.param_name_to_grad_name[param_var.name] = grad_var.name
+            self.grad_name_to_param_name[grad_var.name] = param_var.name

         # add distributed attrs to program
         self.origin_program._is_distributed = True
@@ -262,8 +264,10 @@ class DistributeTranspiler(object):
                 AssertionError("Can not insert the send op by original "
                                "variable name :", splited_grad_varname)

-            dummy_output = program.global_block().create_var()
+            dummy_output = program.global_block().create_var(
+                name=framework.generate_control_dev_var_name())
             grad_name_to_send_dummy_out[grad_varname] = dummy_output
+
             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
@@ -272,6 +276,8 @@ class DistributeTranspiler(object):
                 attrs={
                     "epmap": eplist,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME:
+                    [self.grad_name_to_param_name[grad_varname], grad_varname],
                     "sync_mode": not self.sync_mode,
                 })
         for _, var in enumerate(splited_vars):
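The dummy output created above is the thread that ties recv back to send: the transpiler stores it in grad_name_to_send_dummy_out keyed by gradient name, and the recv emitted later for the matching parameter takes the same variable as an input. A schematic of that wiring with plain dicts and lists (insert_op is a hypothetical recorder, not Block._insert_op):

import random

CONTROL_DEP_VAR_PREFIX = "__control_var"

ops = []  # records emitted ops, standing in for block.ops
grad_name_to_send_dummy_out = {}
grad_name_to_param_name = {"w@GRAD": "w"}


def insert_op(**op):
    # Hypothetical recorder in place of program.global_block()._insert_op.
    ops.append(op)


for grad_name, param_name in grad_name_to_param_name.items():
    # One fresh control var per gradient, produced by its send op...
    dummy = CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())
    grad_name_to_send_dummy_out[grad_name] = dummy
    insert_op(type="send", inputs=[grad_name], outputs=[dummy],
              op_role_var=[param_name, grad_name])

for grad_name, dummy in grad_name_to_send_dummy_out.items():
    # ...and consumed by the recv of the matching parameter, so the
    # executor must schedule recv("w") after send("w@GRAD").
    param_name = grad_name_to_param_name[grad_name]
    insert_op(type="recv", inputs=[dummy], outputs=[param_name],
              op_role_var=[param_name, grad_name])

assert ops[1]["inputs"][0] == ops[0]["outputs"][0]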
@@ -313,6 +319,10 @@ class DistributeTranspiler(object):
                 attrs={
                     "epmap": eps,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME: [
+                        param_varname,
+                        self.param_name_to_grad_name[param_varname]
+                    ],
                     "sync_mode": not self.sync_mode
                 })
@@ -971,7 +981,11 @@ class DistributeTranspiler(object):
                 attrs={
                     "sync_mode": True,
                     "epmap": pserver_endpoints,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    OP_ROLE_VAR_ATTR_NAME: [
+                        self.grad_name_to_param_name[table_grad_name],
+                        table_grad_name
+                    ]
                 })
             break
...