未验证 提交 d68c38ef 编写于 作者: S seemingwang 提交者: GitHub

add embedding range check (#46991)

* add embedding range check

* change head file

* change head file

* fix
上级 e5e3d5cf
...@@ -13,13 +13,12 @@ ...@@ -13,13 +13,12 @@
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h" #include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/embedding_util.h" #include "paddle/phi/kernels/funcs/embedding_util.h"
namespace phi { namespace phi {
template <typename T, typename IdT, bool PaddingFlag> template <typename T, typename IdT, bool PaddingFlag>
...@@ -35,6 +34,16 @@ __global__ void EmbeddingFW(T *output, ...@@ -35,6 +34,16 @@ __global__ void EmbeddingFW(T *output,
while (idy < K) { while (idy < K) {
auto id = static_cast<int64_t>(ids[idy]); auto id = static_cast<int64_t>(ids[idy]);
if (PaddingFlag == false || id != padding_idx) {
PADDLE_ENFORCE(id >= 0,
"Id should no less than 0 but received an id value: %lld.",
id);
PADDLE_ENFORCE(
id < N,
"Id should smaller than %lld but received an id value: %lld.",
N,
id);
}
T *out = output + idy * D; T *out = output + idy * D;
const T *tab = table + id * D; const T *tab = table + id * D;
for (int i = idx; i < D; i += blockDim.x) { for (int i = idx; i < D; i += blockDim.x) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册