提交 3ce974b9 编写于 作者: H hedaoyuan

Add group argument in ConvFunctionBase

上级 048b14a9
...@@ -38,6 +38,7 @@ public: ...@@ -38,6 +38,7 @@ public:
// function arguments // function arguments
strides_ = config.get<std::vector<size_t>>("strides"); strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings"); paddings_ = config.get<std::vector<size_t>>("paddings");
groups_ = config.get<size_t>("groups");
// number of inputs and outputs // number of inputs and outputs
numInputs_ = 2; numInputs_ = 2;
...@@ -62,6 +63,11 @@ public: ...@@ -62,6 +63,11 @@ public:
protected: protected:
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
/// Number of groups; refer to grouped convolution in
/// Alex Krizhevsky's paper: when groups=2, the first half of the
/// filters is connected only to the first half of the input channels,
/// and the second half only to the second half of the input channels.
size_t groups_;
inline int strideH() const { return strides_[0]; } inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; } inline int strideW() const { return strides_[1]; }
......
...@@ -55,6 +55,7 @@ public: ...@@ -55,6 +55,7 @@ public:
FuncConfig() FuncConfig()
.set("paddings", paddings) .set("paddings", paddings)
.set("strides", strides) .set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo)); .set("algo", algo));
TensorShape shape0{ TensorShape shape0{
......
...@@ -101,19 +101,26 @@ public: ...@@ -101,19 +101,26 @@ public:
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
size_t size = size_t size = inputChannels / groups_ * filterHeight * filterWidth *
inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; outputHeight * outputWidth;
resizeBuffer(size); resizeBuffer(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<real> im2col; Im2ColFunctor<real> im2col;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = inputs[1].shape().getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
im2col(inputData, for (int g = 0; g < groups_; g++) {
inputChannels, im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterHeight, filterHeight,
...@@ -135,17 +142,18 @@ public: ...@@ -135,17 +142,18 @@ public:
N, N,
K, K,
1.0f, 1.0f,
filterData, filterData + g * filterOffset,
K, K,
colData, colData,
N, N,
0.0f, 0.0f,
outputData, outputData + g * outputOffset,
N); N);
inputData += inputChannels * inputHeight * inputWidth; inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
} }
} }
}
void resizeBuffer(size_t newSize) { void resizeBuffer(size_t newSize) {
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) { if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册