diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h index 3e750ed2d171876ce2d3c232f5d34234217b3c3e..d36ed398ebce661a62ca92696b0089b5289d5b1c 100644 --- a/paddle/fluid/operators/split_ids_op.h +++ b/paddle/fluid/operators/split_ids_op.h @@ -30,19 +30,16 @@ class SplitIdsOpKernel : public framework::OpKernel { PADDLE_THROW("SplitIds do not support GPU kernel"); } - const auto* ids_t = ctx.Input("Ids"); - auto& ids_dims = ids_t->dims(); + auto& ids_dims = ctx.Input("Ids")->dims(); + const T* ids = ctx.Input("Ids")->data(); auto outs = ctx.MultiOutput("Out"); - - const T* ids = ids_t->data(); - const size_t shard_num = outs.size(); std::vector> out_ids; out_ids.resize(outs.size()); // split id by their shard_num. - for (size_t i = 0; i < ids_dims[0]; ++i) { + for (int i = 0; i < ids_dims[0]; ++i) { T id = ids[i]; size_t shard_id = static_cast(id) % shard_num; out_ids[shard_id].push_back(id);