diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index e88310745fb492b34252b5195722e8eae6bdb31a..ca7b2469010d57fc6bcb8b6cea8149fdbb091e58 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index b1047103746acfa5bcbe126295ea2f29ef337f10..114df044afcf2bef971ddd294aee7b2f4779aec4 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); @@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel { } T min = static_cast(context.Attr("min")); T max = static_cast(context.Attr("max")); + UniformGenerator gen = UniformGenerator(min, max, seed); std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - T r = UniformGenerator(min, max, seed); + T r = gen(0); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) {