未验证 提交 0e28c8bb 编写于 作者: Z zhulei 提交者: GitHub

Fix safety-bug of functional.linear (#34696)

* Fix safety-bug of functional.linear

* Fix safety-bug of functional.linear

* Fix safety-bug of functional.linear

* Fix safety-bug of functional.linear
上级 589d13c5
...@@ -1041,6 +1041,12 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM( ...@@ -1041,6 +1041,12 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta, T *C, int batchCount, T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB) const { int64_t strideA, int64_t strideB) const {
PADDLE_ENFORCE_NOT_NULL(
A, platform::errors::InvalidArgument("Pointer A should not be null."));
PADDLE_ENFORCE_NOT_NULL(
B, platform::errors::InvalidArgument("Pointer B should not be null."));
PADDLE_ENFORCE_NOT_NULL(
C, platform::errors::InvalidArgument("Pointer C should not be null."));
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
int lda = (transA == CblasNoTrans) ? K : M; int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K; int ldb = (transB == CblasNoTrans) ? N : K;
......
...@@ -73,6 +73,15 @@ class LinearTestCase(unittest.TestCase): ...@@ -73,6 +73,15 @@ class LinearTestCase(unittest.TestCase):
np.testing.assert_array_almost_equal(res_f, res_nn) np.testing.assert_array_almost_equal(res_f, res_nn)
np.testing.assert_array_almost_equal(res_nn, res_np) np.testing.assert_array_almost_equal(res_nn, res_np)
def test_error_dummy_input(self, place=paddle.CPUPlace()):
with self.assertRaises(ValueError):
x_arr = np.array([], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32')
weight = paddle.zeros([4, 4, 4], dtype='float32')
bias = paddle.to_tensor([], dtype='float32')
paddle.nn.functional.linear(x, weight, bias=bias)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册