未验证 提交 31c33122 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

fix multinomial paddle_enforce bug (#42302)

上级 e5a0365b
...@@ -133,11 +133,10 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -133,11 +133,10 @@ void MultinomialKernel(const Context& dev_ctx,
DenseTensor* out) { DenseTensor* out) {
auto* in_data = x.data<T>(); auto* in_data = x.data<T>();
int64_t* out_data = dev_ctx.template Alloc<int64_t>(out); int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
auto in_dims = x.dims(); auto in_dims = x.dims();
int64_t in_rank = in_dims.size(); int64_t dim_size = in_dims.size();
const int64_t num_categories = in_dims[in_rank - 1]; const int64_t num_categories = in_dims[dim_size - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; const int64_t num_distributions = dim_size > 1 ? in_dims[dim_size - 2] : 1;
// If replacement is False, it's not a replaceable sample. Every category // If replacement is False, it's not a replaceable sample. Every category
// can be used only once. // can be used only once.
...@@ -145,8 +144,8 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -145,8 +144,8 @@ void MultinomialKernel(const Context& dev_ctx,
int64_t in_data_numel = x.numel(); int64_t in_data_numel = x.numel();
int64_t out_data_numel = out->numel(); int64_t out_data_numel = out->numel();
// Just use to PADDLE_ENFORCE error message
T* cpu_in_data = new T[in_data_numel]; T* cpu_in_data = new T[in_data_numel];
int64_t* cpu_out_data = new int64_t[out_data_numel];
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hipMemcpy( hipMemcpy(
...@@ -160,7 +159,7 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -160,7 +159,7 @@ void MultinomialKernel(const Context& dev_ctx,
for (size_t i = 0; i < num_distributions; ++i) { for (size_t i = 0; i < num_distributions; ++i) {
int zero_num = 0; int zero_num = 0;
for (size_t j = 0; j < num_categories; ++j) { for (size_t j = 0; j < num_categories; ++j) {
T weight = cpu_in_data[i * num_distributions + j]; T weight = cpu_in_data[i * num_categories + j];
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
weight, weight,
0, 0,
......
...@@ -216,6 +216,14 @@ class TestMultinomialError(unittest.TestCase): ...@@ -216,6 +216,14 @@ class TestMultinomialError(unittest.TestCase):
self.assertRaises(ValueError, test_dim_less_than_1) self.assertRaises(ValueError, test_dim_less_than_1)
with self.assertRaises(ValueError):
y = paddle.multinomial(paddle.to_tensor([1., 2., -3.]))
with self.assertRaises(ValueError):
prob = paddle.rand([20, 1000])
prob[1:0] = 0
y = paddle.multinomial(prob)
class TestRandomValue(unittest.TestCase): class TestRandomValue(unittest.TestCase):
def test_fixed_random_number(self): def test_fixed_random_number(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册