未验证 提交 b430f6a3 编写于 作者: Y Yulong Ao 提交者: GitHub

Add a check for multiplex op (#34972)

上级 628ff34b
......@@ -29,6 +29,14 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(), 0,
platform::errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
// copy index to cpu
......
......@@ -31,6 +31,14 @@ class MultiplexCPUKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(), 0,
platform::errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
auto index = ids->data<int32_t>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册