提交 6a93f0f3 编写于 作者: H hedaoyuan

Add the calculation implementation of GemmConvGradFilterFunction

上级 afbe556e
......@@ -89,11 +89,13 @@ public:
protected:
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
/// and the second half only connected to the second half.
size_t groups_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
......@@ -101,6 +103,20 @@ protected:
inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; }
// A temporary memory in convolution calculation.
MemoryHandlePtr memory_;
template <DeviceType Device>
void resizeBuffer(size_t newSize) {
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
if (Device == DEVICE_TYPE_CPU) {
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
} else {
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
}
}
}
};
} // namespace paddle
......@@ -110,7 +110,7 @@ public:
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer(size);
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
......@@ -120,7 +120,7 @@ public:
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = inputs[1].shape().getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (int g = 0; g < groups_; g++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
......@@ -138,7 +138,9 @@ public:
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(M,
gemm(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
......@@ -154,19 +156,6 @@ public:
outputData += outputChannels * outputHeight * outputWidth;
}
}
void resizeBuffer(size_t newSize) {
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
if (Device == DEVICE_TYPE_CPU) {
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
} else {
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
}
}
}
private:
MemoryHandlePtr memory_;
};
/*
......@@ -202,10 +191,73 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& outputGrad = inputs[0].shape();
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape();
const TensorShape& filterGrad = outputs[0].shape();
check(input, filterGrad, outputGrad);
const TensorShape& filter = outputs[0].shape();
check(input, filter, output);
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = filter[2];
size_t filterWidth = filter[3];
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>();
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
int M = outputChannels / groups_;
int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans,
CblasTrans,
M,
N,
K,
1.0f,
outputGrad + g * outputOffset,
K,
colData,
K,
1.0f,
filterGrad + g * filterOffset,
N);
}
}
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
};
......
......@@ -26,7 +26,9 @@ namespace paddle {
template <DeviceType Device, class T>
class GemmFunctor {
public:
void operator()(const int M,
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
......@@ -42,7 +44,9 @@ public:
template <class T>
class GemmFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const int M,
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
......@@ -53,26 +57,16 @@ public:
const T beta,
T* C,
const int ldc) {
gemm<T>(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
gemm<T>(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
};
template <class T>
class GemmFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const int M,
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
......@@ -84,9 +78,9 @@ public:
T* C,
const int ldc) {
hl_matrix_mul((T*)A,
HPPL_OP_N,
transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
(T*)B,
HPPL_OP_N,
TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
C,
M,
N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册