提交 3ce974b9 编写于 作者: H hedaoyuan

Add group argument in ConvFunctionBase

上级 048b14a9
...@@ -38,6 +38,7 @@ public: ...@@ -38,6 +38,7 @@ public:
// function arguments // function arguments
strides_ = config.get<std::vector<size_t>>("strides"); strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings"); paddings_ = config.get<std::vector<size_t>>("paddings");
groups_ = config.get<size_t>("groups");
// number of inputs and outputs // number of inputs and outputs
numInputs_ = 2; numInputs_ = 2;
...@@ -62,6 +63,11 @@ public: ...@@ -62,6 +63,11 @@ public:
protected: protected:
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
/// Number of groups, as in the grouped convolution described in
/// Alex Krizhevsky's paper: when groups=2, the first half of the
/// filters is connected only to the first half of the input channels,
/// and the second half only to the second half.
size_t groups_;
inline int strideH() const { return strides_[0]; } inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; } inline int strideW() const { return strides_[1]; }
......
...@@ -55,6 +55,7 @@ public: ...@@ -55,6 +55,7 @@ public:
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo)); .set("algo", algo));
TensorShape shape0{ TensorShape shape0{
......
...@@ -101,49 +101,57 @@ public: ...@@ -101,49 +101,57 @@ public:
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
size_t size = size_t size = inputChannels / groups_ * filterHeight * filterWidth *
inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; outputHeight * outputWidth;
resizeBuffer(size); resizeBuffer(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<real> im2col; Im2ColFunctor<real> im2col;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = inputs[1].shape().getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
im2col(inputData, for (int g = 0; g < groups_; g++) {
inputChannels, im2col(inputData + g * inputOffset,
inputHeight, inputChannels / groups_,
inputWidth, inputHeight,
filterHeight, inputWidth,
filterWidth, filterHeight,
strideH(), filterWidth,
strideW(), strideH(),
paddingH(), strideW(),
paddingW(), paddingH(),
outputHeight, paddingW(),
outputWidth, outputHeight,
colData); outputWidth,
colData);
int M = outputChannels;
int N = outputHeight * outputWidth; int M = outputChannels;
int K = inputChannels * filterHeight * filterWidth; int N = outputHeight * outputWidth;
gemm<real>(CblasNoTrans, int K = inputChannels * filterHeight * filterWidth;
CblasNoTrans, gemm<real>(CblasNoTrans,
M, CblasNoTrans,
N, M,
K, N,
1.0f, K,
filterData, 1.0f,
K, filterData + g * filterOffset,
colData, K,
N, colData,
0.0f, N,
outputData, 0.0f,
N); outputData + g * outputOffset,
inputData += inputChannels * inputHeight * inputWidth; N);
outputData += outputChannels * outputHeight * outputWidth; inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册