提交 ec2ba242 编写于 作者: H hedaoyuan

Fix GemmConvFunction.

上级 53b0e427
...@@ -4,6 +4,8 @@ file(GLOB cpp_files . *Op.cpp) ...@@ -4,6 +4,8 @@ file(GLOB cpp_files . *Op.cpp)
list(APPEND h_files Function.h) list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp) list(APPEND cpp_files Function.cpp)
list(APPEND cpp_files BufferArg.cpp) list(APPEND cpp_files BufferArg.cpp)
list(APPEND cpp_files GemmFunctor.cpp)
list(APPEND cpp_files EigenGemm.cpp)
if(WITH_GPU) if(WITH_GPU)
file(GLOB cu_files . *OpGpu.cu) file(GLOB cu_files . *OpGpu.cu)
......
...@@ -85,7 +85,6 @@ public: ...@@ -85,7 +85,6 @@ public:
} }
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -108,8 +107,8 @@ public: ...@@ -108,8 +107,8 @@ public:
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans, BlasGemm<Device, real>::compute(false,
CblasNoTrans, false,
M, M,
N, N,
K, K,
...@@ -188,8 +187,6 @@ public: ...@@ -188,8 +187,6 @@ public:
} }
Col2ImFunctor<kCFO, Device, real> col2im; Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -205,8 +202,8 @@ public: ...@@ -205,8 +202,8 @@ public:
colData = inputGrad + g * inputOffset; colData = inputGrad + g * inputOffset;
scale = 1.0f; scale = 1.0f;
} }
gemm(CblasTrans, BlasGemm<Device, real>::compute(true,
CblasNoTrans, false,
M, M,
N, N,
K, K,
...@@ -299,7 +296,6 @@ public: ...@@ -299,7 +296,6 @@ public:
} }
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -321,8 +317,8 @@ public: ...@@ -321,8 +317,8 @@ public:
int M = outputChannels / groups_; int M = outputChannels / groups_;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth; int N = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans, BlasGemm<Device, real>::compute(false,
CblasTrans, true,
M, M,
N, N,
K, K,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册