From 9f09d68678c66e4759ce0bffc338cae87d5ec9d5 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 21:22:14 +0800 Subject: [PATCH] add enforce --- paddle/fluid/operators/sampling_id_op.cc | 4 ++++ paddle/fluid/operators/sampling_id_op.cu | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index e88310745f..ca7b246901 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index b104710374..114df044af 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); @@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel { } T min = static_cast(context.Attr("min")); T max = static_cast(context.Attr("max")); + UniformGenerator gen = UniformGenerator(min, max, seed); std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { - T r = UniformGenerator(min, max, seed); + T r = gen(0); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { -- GitLab