未验证 提交 ad92fa61 编写于 作者: Y Yulong Ao 提交者: GitHub

[Cherry-pick] Add a check for multiplex op (#38757)

* [Cherry-pick] Add the forward QR operator

* Add a check for multiplex op

* Improve multiplex based on reviews
上级 0d081cbc
...@@ -29,6 +29,14 @@ class MultiplexGPUKernel : public framework::OpKernel<T> { ...@@ -29,6 +29,14 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(), 0,
platform::errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0]; auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
// copy index to cpu // copy index to cpu
......
...@@ -31,6 +31,14 @@ class MultiplexCPUKernel : public framework::OpKernel<T> { ...@@ -31,6 +31,14 @@ class MultiplexCPUKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE_GT(
ins[i]->numel(), 0,
platform::errors::OutOfRange(
"indexing will be out of bounds with size 0 for the %d-th input.",
i));
}
auto rows = ins[0]->dims()[0]; auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
auto index = ids->data<int32_t>(); auto index = ids->data<int32_t>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册