diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index ebd1d644bcce0554904ee0931827ef43a6d52b11..4f3889a71e865ce6f9799d976c6e6ec5c365429b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -680,7 +680,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, } if (node->Op()->Type() == "split_byref" || - node->Op()->Type() == "split_selected_rows") { + node->Op()->Type() == "split_selected_rows" || + node->Op()->Type() == "split_ids") { // TODO(paddle-dev): getting the first var is not safe. op_dev_id = GetVarDeviceID(*result, input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {