Commit 6a93f0f3 — Author: H hedaoyuan

Add the calculation implementation of GemmConvGradFilterFunction

Parent commit: afbe556e
...@@ -89,11 +89,13 @@ public: ...@@ -89,11 +89,13 @@ public:
protected: protected:
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
/// Group size, refer to grouped convolution in /// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the /// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels, /// filters are only connected to the first half of the input channels,
/// and the second half only connected to the second half. /// and the second half only connected to the second half.
size_t groups_; size_t groups_;
inline int strideH() const { return strides_[0]; } inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; } inline int strideW() const { return strides_[1]; }
...@@ -101,6 +103,20 @@ protected: ...@@ -101,6 +103,20 @@ protected:
inline int paddingH() const { return paddings_[0]; } inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; } inline int paddingW() const { return paddings_[1]; }
// Scratch buffer shared by the convolution calculations.
MemoryHandlePtr memory_;

/// Make sure memory_ can hold at least `newSize` elements of type `real`.
/// The buffer is (re)allocated on CPU or GPU according to the Device
/// template argument; an existing allocation that is already large enough
/// is kept as-is.
template <DeviceType Device>
void resizeBuffer(size_t newSize) {
  const size_t sizeInBytes = newSize * sizeof(real);
  const bool bigEnough =
      memory_ && sizeInBytes <= memory_->getAllocSize();
  if (bigEnough) {
    return;
  }
  if (Device == DEVICE_TYPE_CPU) {
    memory_ = std::make_shared<CpuMemoryHandle>(sizeInBytes);
  } else {
    memory_ = std::make_shared<GpuMemoryHandle>(sizeInBytes);
  }
}
}; };
} // namespace paddle } // namespace paddle
...@@ -110,7 +110,7 @@ public: ...@@ -110,7 +110,7 @@ public:
size_t size = inputChannels / groups_ * filterHeight * filterWidth * size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth; outputHeight * outputWidth;
resizeBuffer(size); resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col; Im2ColFunctor<Device, real> im2col;
...@@ -120,7 +120,7 @@ public: ...@@ -120,7 +120,7 @@ public:
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = inputs[1].shape().getElements() / groups_; size_t filterOffset = inputs[1].shape().getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (int g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
inputChannels / groups_, inputChannels / groups_,
inputHeight, inputHeight,
...@@ -138,7 +138,9 @@ public: ...@@ -138,7 +138,9 @@ public:
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(M, gemm(CblasNoTrans,
CblasNoTrans,
M,
N, N,
K, K,
1.0f, 1.0f,
...@@ -154,19 +156,6 @@ public: ...@@ -154,19 +156,6 @@ public:
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
} }
} }
// NOTE(review): deletion side of the diff — this per-function resizeBuffer
// and its private memory_ member are removed by this commit, replaced by the
// templated resizeBuffer<Device> hoisted into the common base class.
void resizeBuffer(size_t newSize) {
  // Only reallocate when there is no buffer yet or it is too small.
  if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
    if (Device == DEVICE_TYPE_CPU) {
      memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
    } else {
      memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
    }
  }
}
private:
// Temporary im2col buffer used during the convolution calculation.
MemoryHandlePtr memory_;
}; };
/* /*
...@@ -202,10 +191,73 @@ public: ...@@ -202,10 +191,73 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& outputGrad = inputs[0].shape(); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape(); const TensorShape& input = inputs[1].shape();
const TensorShape& filterGrad = outputs[0].shape(); const TensorShape& filter = outputs[0].shape();
check(input, filterGrad, outputGrad); check(input, filter, output);
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = filter[2];
size_t filterWidth = filter[3];
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_;
int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans,
CblasTrans,
M,
N,
K,
1.0f,
outputGrad + g * outputOffset,
K,
colData,
K,
1.0f,
filterGrad + g * filterOffset,
N);
}
}
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
} }
}; };
......
...@@ -26,7 +26,9 @@ namespace paddle { ...@@ -26,7 +26,9 @@ namespace paddle {
template <DeviceType Device, class T> template <DeviceType Device, class T>
class GemmFunctor { class GemmFunctor {
public: public:
void operator()(const int M, void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N, const int N,
const int K, const int K,
const T alpha, const T alpha,
...@@ -42,7 +44,9 @@ public: ...@@ -42,7 +44,9 @@ public:
template <class T> template <class T>
class GemmFunctor<DEVICE_TYPE_CPU, T> { class GemmFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(const int M, void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N, const int N,
const int K, const int K,
const T alpha, const T alpha,
...@@ -53,26 +57,16 @@ public: ...@@ -53,26 +57,16 @@ public:
const T beta, const T beta,
T* C, T* C,
const int ldc) { const int ldc) {
gemm<T>(CblasNoTrans, gemm<T>(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
CblasNoTrans,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
} }
}; };
template <class T> template <class T>
class GemmFunctor<DEVICE_TYPE_GPU, T> { class GemmFunctor<DEVICE_TYPE_GPU, T> {
public: public:
void operator()(const int M, void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N, const int N,
const int K, const int K,
const T alpha, const T alpha,
...@@ -84,9 +78,9 @@ public: ...@@ -84,9 +78,9 @@ public:
T* C, T* C,
const int ldc) { const int ldc) {
hl_matrix_mul((T*)A, hl_matrix_mul((T*)A,
HPPL_OP_N, transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
(T*)B, (T*)B,
HPPL_OP_N, TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
C, C,
M, M,
N, N,
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment.