diff --git a/paddle/fluid/operators/multiplex_op.cu b/paddle/fluid/operators/multiplex_op.cu index 64331b88bfc04a1761ebbc381f55228d3d183759..505e322310caf5f6c146b123bd9ff8b3474d806a 100644 --- a/paddle/fluid/operators/multiplex_op.cu +++ b/paddle/fluid/operators/multiplex_op.cu @@ -29,6 +29,14 @@ class MultiplexGPUKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < ins.size(); ++i) { + PADDLE_ENFORCE_GT( + ins[i]->numel(), 0, + platform::errors::OutOfRange( + "indexing will be out of bounds with size 0 for the %d-th input.", + i)); + } + auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; // copy index to cpu diff --git a/paddle/fluid/operators/multiplex_op.h b/paddle/fluid/operators/multiplex_op.h index cb8d5eb2f761da512dcc27ce7f832306eaafa244..c0f24a2034a150970361fcbdf1d0de0892dc5754 100644 --- a/paddle/fluid/operators/multiplex_op.h +++ b/paddle/fluid/operators/multiplex_op.h @@ -31,6 +31,14 @@ class MultiplexCPUKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < ins.size(); ++i) { + PADDLE_ENFORCE_GT( + ins[i]->numel(), 0, + platform::errors::OutOfRange( + "indexing will be out of bounds with size 0 for the %d-th input.", + i)); + } + auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; auto index = ids->data();