提交 254154a9 编写于 作者: Y yi.wu

fix sparse paraexe dist train

上级 f0cf70ec
...@@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, ...@@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
int op_dev_id = -1; int op_dev_id = -1;
if (op.Type() == "split_byref") { if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册