diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 319e173adb3b55983f59773e2e5f598ef63eb6f2..cc25fb156a8e3ace643bf50c9c0d093718079f97 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2206,6 +2206,14 @@ void MultiplexInferMeta(const std::vector& ins, phi::errors::PreconditionNotMet( "All the candidate tensors must have the same size.")); } + + PADDLE_ENFORCE_GE( + in_dim[0], + ids_dim[0], + phi::errors::InvalidArgument("The 2nd-dim of input cannot be smaller " + "than batchSize of the index tensor.")); + + in_dim[0] = ids_dim[0]; out->set_dims(in_dim); out->set_dtype(ins[0]->dtype()); } diff --git a/paddle/phi/kernels/cpu/multiplex_kernel.cc b/paddle/phi/kernels/cpu/multiplex_kernel.cc index 2d9f4c51a981ed8701afe0aa4e7d6a8955f4348c..4e60448c6c5369d7e93268d54524e6bda5369b30 100644 --- a/paddle/phi/kernels/cpu/multiplex_kernel.cc +++ b/paddle/phi/kernels/cpu/multiplex_kernel.cc @@ -37,7 +37,7 @@ void MultiplexKernel(const Context& ctx, auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; auto index = ids.data(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative.")); diff --git a/paddle/phi/kernels/gpu/multiplex_kernel.cu b/paddle/phi/kernels/gpu/multiplex_kernel.cu index 743448a46866687cf2ac68be522a306281289252..2a86827bcf4752b6c391bac29315be770913ee5c 100644 --- a/paddle/phi/kernels/gpu/multiplex_kernel.cu +++ b/paddle/phi/kernels/gpu/multiplex_kernel.cu @@ -41,7 +41,7 @@ void MultiplexKernel(const Context& ctx, paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu); auto* index = index_t_cpu.data(); auto stream = ctx.stream(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative."));