提交 b59002da 编写于 作者: D dongzhihong

"fix math gemm lda order error"

上级 514398c0
......@@ -25,8 +25,8 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
......@@ -40,8 +40,8 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
int lda = K;
int ldb = N;
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
......
......@@ -23,7 +23,9 @@ class MulGradOpTest(GradientChecker):
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
self.check_grad(op, inputs, set(["X", "Y"]), "Out")
# mul op will enlarge the relative error
self.check_grad(
op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册