未验证 提交 32baca93 编写于 作者: D Difer 提交者: GitHub

Case7:paddle.distribution.Beta:fix beta(true stack) (#51847)

上级 65c6d2ef
......@@ -25,6 +25,15 @@ void StackKernel(const Context& dev_ctx,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
auto x_dims = x[0]->dims();
for (int i = 0; i < x_dims.size(); i++) {
PADDLE_ENFORCE_GT(x_dims[i],
0,
phi::errors::InvalidArgument(
"The dims of Input(X) should be greater than 0"));
}
int n = static_cast<int>(x.size());
T* y_data = dev_ctx.template Alloc<T>(out);
std::vector<const T*> x_datas(n);
......
......@@ -77,11 +77,12 @@ void StackRawKernel(const Context& ctx,
// Split x dim from axis to matrix of shape [x_row, x_col], and the output
// tensor's shape is [x_row, out_col].
int64_t x_row = 1;
int64_t x_row = 1, x_row_bak = 1;
for (int i = 0; i < axis; ++i) {
x_row *= x[0]->dims()[i];
}
int64_t x_col = x[0]->numel() / x_row;
x_row_bak = x_row == 0 ? 1 : x_row;
int64_t x_col = x[0]->numel() / x_row_bak;
int64_t out_col = x_col * num;
if (out->numel() < std::numeric_limits<int32_t>::max()) {
......
......@@ -113,6 +113,12 @@ class TestBeta(unittest.TestCase):
== case.get('expect')
)
def test_errors(self):
with self.assertRaises(ValueError):
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32')
paddle.distribution.Beta(alpha=x, beta=x)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册