提交 9f09d686 编写于 作者: T tangwei12

add enforce

上级 baa6273c
......@@ -33,6 +33,10 @@ class SamplingIdKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]);
PADDLE_ENFORCE_GE(batch_size, 0,
"batch_size(dims[0]) must be nonnegative.");
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
......
......@@ -46,6 +46,10 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]);
PADDLE_ENFORCE_GE(batch_size, 0,
"batch_size(dims[0]) must be nonnegative.");
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
......@@ -56,10 +60,11 @@ class SamplingIdGPUKernel : public framework::OpKernel<T> {
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
UniformGenerator<T> gen = UniformGenerator<T>(min, max, seed);
std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
T r = UniformGenerator<T>(min, max, seed);
T r = gen(0);
int idx = width - 1;
for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册