未验证 提交 6d2d8e50 编写于 作者: H Haohongxiang 提交者: GitHub

fix bugs of paddle.multiplex API (#49368) (#49642)

上级 1d25c663
...@@ -2190,6 +2190,14 @@ void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins, ...@@ -2190,6 +2190,14 @@ void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"All the candidate tensors must have the same size.")); "All the candidate tensors must have the same size."));
} }
PADDLE_ENFORCE_GE(
in_dim[0],
ids_dim[0],
phi::errors::InvalidArgument("The 2nd-dim of input cannot be smaller "
"than batchSize of the index tensor."));
in_dim[0] = ids_dim[0];
out->set_dims(in_dim); out->set_dims(in_dim);
out->set_dtype(ins[0]->dtype()); out->set_dtype(ins[0]->dtype());
} }
......
...@@ -37,7 +37,7 @@ void MultiplexKernel(const Context& ctx, ...@@ -37,7 +37,7 @@ void MultiplexKernel(const Context& ctx,
auto rows = ins[0]->dims()[0]; auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows; auto cols = ins[0]->numel() / rows;
auto index = ids.data<int32_t>(); auto index = ids.data<int32_t>();
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < ids.dims()[0]; i++) {
int32_t k = index[i]; int32_t k = index[i];
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
k, 0, errors::PreconditionNotMet("index must be nonnegative.")); k, 0, errors::PreconditionNotMet("index must be nonnegative."));
......
...@@ -41,7 +41,7 @@ void MultiplexKernel(const Context& ctx, ...@@ -41,7 +41,7 @@ void MultiplexKernel(const Context& ctx,
paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu); paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>(); auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.stream(); auto stream = ctx.stream();
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < ids.dims()[0]; i++) {
int32_t k = index[i]; int32_t k = index[i];
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
k, 0, errors::PreconditionNotMet("index must be nonnegative.")); k, 0, errors::PreconditionNotMet("index must be nonnegative."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册