未验证 提交 9ffedcfd 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

support multi_dims for tril_triu, *test=kunlun (#40712)

* support multi_dims for tril_triu, *test=kunlun

* support multi_dims for tril_triu, *test=kunlun

* support multi_dims for tril_triu, *test=kunlun

* update xpu.cmake date, support multi_dims for tril_triu, *test=kunlun
上级 608a5f55
......@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220307")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220324")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
......
......@@ -150,7 +150,7 @@ void *Alloc<platform::XPUPlace>(const platform::XPUPlace &place, size_t size) {
platform::XPUDeviceGuard gurad(place.device);
int ret = xpu_malloc(reinterpret_cast<void **>(&p), size);
if (ret != XPU_SUCCESS) {
std::cout << "xpu memory malloc(" << size << ") failed, try again\n";
VLOG(10) << "xpu memory malloc(" << size << ") failed, try again";
xpu_wait();
ret = xpu_malloc(reinterpret_cast<void **>(&p), size);
}
......
......@@ -94,17 +94,27 @@ class XPUTestTrilTriuOp(XPUOpTestWrapper):
class TestTrilTriuOp3(TestTrilTriuOp):
def initTestCase(self):
self.diagonal = 10
self.Xshape = (25, 25)
self.Xshape = (2, 25, 25)
class TestTrilTriuOp4(TestTrilTriuOp):
def initTestCase(self):
self.diagonal = -10
self.Xshape = (33, 11)
self.Xshape = (1, 2, 33, 11)
class TestTrilTriuOp5(TestTrilTriuOp):
def initTestCase(self):
self.diagonal = 11
self.Xshape = (1, 99)
self.Xshape = (1, 1, 99)
class TestTrilTriuOp6(TestTrilTriuOp):
def initTestCase(self):
self.diagonal = 5
self.Xshape = (1, 2, 3, 5, 99)
class TestTrilTriuOp7(TestTrilTriuOp):
def initTestCase(self):
self.diagonal = -100
self.Xshape = (2, 2, 3, 4, 5)
class TestTrilTriuOpError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册