提交 3ce974b9 编写于 作者: H hedaoyuan

Add group argument in ConvFunctionBase

上级 048b14a9
......@@ -38,6 +38,7 @@ public:
// function arguments
strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings");
groups_ = config.get<size_t>("groups");
// number of inputs and outputs
numInputs_ = 2;
......@@ -62,6 +63,11 @@ public:
protected:
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
/// Number of groups (not the size of one group); see grouped convolution in
/// Alex Krizhevsky's paper: when groups=2, the first half of the
/// filters are only connected to the first half of the input channels,
/// and the second half are only connected to the second half.
size_t groups_;
/// Returns strides_[0] converted to int
/// (presumably the stride along the height axis, going by the name).
inline int strideH() const {
  return static_cast<int>(strides_[0]);
}
/// Returns strides_[1] converted to int
/// (presumably the stride along the width axis, going by the name).
inline int strideW() const {
  return static_cast<int>(strides_[1]);
}
......
......@@ -55,6 +55,7 @@ public:
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
......
......@@ -101,49 +101,57 @@ public:
size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
size_t size =
inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
size_t size = inputChannels / groups_ * filterHeight * filterWidth *
outputHeight * outputWidth;
resizeBuffer(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<real> im2col;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = inputs[1].shape().getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
im2col(inputData,
inputChannels,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
int M = outputChannels;
int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth;
gemm<real>(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData,
K,
colData,
N,
0.0f,
outputData,
N);
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
for (int g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
inputChannels / groups_,
inputHeight,
inputWidth,
filterHeight,
filterWidth,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
int M = outputChannels;
int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth;
gemm<real>(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
colData,
N,
0.0f,
outputData + g * outputOffset,
N);
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册