From b430f6a3745bef9c89741733f46ceb4f99660939 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Sun, 26 Sep 2021 13:57:42 +0800 Subject: [PATCH] Add a check for multiplex op (#34972) --- paddle/fluid/operators/multiplex_op.cu | 8 ++++++++ paddle/fluid/operators/multiplex_op.h | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/paddle/fluid/operators/multiplex_op.cu b/paddle/fluid/operators/multiplex_op.cu index 64331b88bfc..505e322310c 100644 --- a/paddle/fluid/operators/multiplex_op.cu +++ b/paddle/fluid/operators/multiplex_op.cu @@ -29,6 +29,14 @@ class MultiplexGPUKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < ins.size(); ++i) { + PADDLE_ENFORCE_GT( + ins[i]->numel(), 0, + platform::errors::OutOfRange( + "indexing will be out of bounds with size 0 for the %d-th input.", + i)); + } + auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; // copy index to cpu diff --git a/paddle/fluid/operators/multiplex_op.h b/paddle/fluid/operators/multiplex_op.h index cb8d5eb2f76..c0f24a2034a 100644 --- a/paddle/fluid/operators/multiplex_op.h +++ b/paddle/fluid/operators/multiplex_op.h @@ -31,6 +31,14 @@ class MultiplexCPUKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < ins.size(); ++i) { + PADDLE_ENFORCE_GT( + ins[i]->numel(), 0, + platform::errors::OutOfRange( + "indexing will be out of bounds with size 0 for the %d-th input.", + i)); + } + auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; auto index = ids->data(); -- GitLab