From 6d2d8e50a5f2059fea5bcff972bab9de3018b877 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 9 Jan 2023 17:13:57 +0800 Subject: [PATCH] fix bugs of paddle.multiplex API (#49368) (#49642) --- paddle/phi/infermeta/multiary.cc | 8 ++++++++ paddle/phi/kernels/cpu/multiplex_kernel.cc | 2 +- paddle/phi/kernels/gpu/multiplex_kernel.cu | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1ab67ede69..375b88493a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2190,6 +2190,14 @@ void MultiplexInferMeta(const std::vector& ins, phi::errors::PreconditionNotMet( "All the candidate tensors must have the same size.")); } + + PADDLE_ENFORCE_GE( + in_dim[0], + ids_dim[0], + phi::errors::InvalidArgument("The 2nd-dim of input cannot be smaller " + "than batchSize of the index tensor.")); + + in_dim[0] = ids_dim[0]; out->set_dims(in_dim); out->set_dtype(ins[0]->dtype()); } diff --git a/paddle/phi/kernels/cpu/multiplex_kernel.cc b/paddle/phi/kernels/cpu/multiplex_kernel.cc index 2d9f4c51a9..4e60448c6c 100644 --- a/paddle/phi/kernels/cpu/multiplex_kernel.cc +++ b/paddle/phi/kernels/cpu/multiplex_kernel.cc @@ -37,7 +37,7 @@ void MultiplexKernel(const Context& ctx, auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; auto index = ids.data(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative.")); diff --git a/paddle/phi/kernels/gpu/multiplex_kernel.cu b/paddle/phi/kernels/gpu/multiplex_kernel.cu index 743448a468..2a86827bcf 100644 --- a/paddle/phi/kernels/gpu/multiplex_kernel.cu +++ b/paddle/phi/kernels/gpu/multiplex_kernel.cu @@ -41,7 +41,7 @@ void MultiplexKernel(const Context& ctx, paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu); auto* index = index_t_cpu.data(); auto stream = ctx.stream(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative.")); -- GitLab