From f6f0c562c4212b253c111625475bc85daf338e40 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Wed, 28 Dec 2022 11:14:27 +0800 Subject: [PATCH] fix bugs of paddle.multiplex API (#49368) --- paddle/phi/infermeta/multiary.cc | 8 ++++++++ paddle/phi/kernels/cpu/multiplex_kernel.cc | 2 +- paddle/phi/kernels/gpu/multiplex_kernel.cu | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 319e173adb3..cc25fb156a8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2206,6 +2206,14 @@ void MultiplexInferMeta(const std::vector& ins, phi::errors::PreconditionNotMet( "All the candidate tensors must have the same size.")); } + + PADDLE_ENFORCE_GE( + in_dim[0], + ids_dim[0], + phi::errors::InvalidArgument("The 2nd-dim of input cannot be smaller " + "than batchSize of the index tensor.")); + + in_dim[0] = ids_dim[0]; out->set_dims(in_dim); out->set_dtype(ins[0]->dtype()); } diff --git a/paddle/phi/kernels/cpu/multiplex_kernel.cc b/paddle/phi/kernels/cpu/multiplex_kernel.cc index 2d9f4c51a98..4e60448c6c5 100644 --- a/paddle/phi/kernels/cpu/multiplex_kernel.cc +++ b/paddle/phi/kernels/cpu/multiplex_kernel.cc @@ -37,7 +37,7 @@ void MultiplexKernel(const Context& ctx, auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; auto index = ids.data(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative.")); diff --git a/paddle/phi/kernels/gpu/multiplex_kernel.cu b/paddle/phi/kernels/gpu/multiplex_kernel.cu index 743448a4686..2a86827bcf4 100644 --- a/paddle/phi/kernels/gpu/multiplex_kernel.cu +++ b/paddle/phi/kernels/gpu/multiplex_kernel.cu @@ -41,7 +41,7 @@ void MultiplexKernel(const Context& ctx, paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu); auto* index = index_t_cpu.data(); auto stream = ctx.stream(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative.")); -- GitLab