提交 784e2184 编写于 作者: H hedaoyuan

Fix the error of group convolution.

上级 7aac38c7
...@@ -46,8 +46,13 @@ namespace paddle { ...@@ -46,8 +46,13 @@ namespace paddle {
* are all NCHW format. Where N is batch size, C is the number of channels, * are all NCHW format. Where N is batch size, C is the number of channels,
* H and W is the height and width of image or image gradient. * H and W is the height and width of image or image gradient.
* *
* 2. The format of the filter data is MCHW, where M is the number of * 2. The format of the filter data is MCHW, where M is the number of output
* output image channels, C is the number of input image channels, * image channels, C is the number of input image channels,
* H and W is height and width of filter.
*
* If groups is greater than 1, the filter's data format should be GMCHW,
* where G is the groups, and G * M is the number of output image channels,
* G * C is the number of input image channels,
* H and W is height and width of filter. * H and W is height and width of filter.
*/ */
class ConvFunctionBase : public FunctionBase { class ConvFunctionBase : public FunctionBase {
...@@ -73,20 +78,47 @@ public: ...@@ -73,20 +78,47 @@ public:
const TensorShape& output) { const TensorShape& output) {
// inputs and outputs arguments should be 4-dimensional. // inputs and outputs arguments should be 4-dimensional.
CHECK_EQ(input.ndims(), (size_t)4); CHECK_EQ(input.ndims(), (size_t)4);
CHECK_EQ(filter.ndims(), (size_t)4);
CHECK_EQ(output.ndims(), (size_t)4); CHECK_EQ(output.ndims(), (size_t)4);
// The batchSize of the input needs to be equal to // The batchSize of the input needs to be equal to
// the batchSize of the output. // the batchSize of the output.
CHECK_EQ(input[0], output[0]); CHECK_EQ(input[0], output[0]);
// The input and output channel dimensions are the second and first if (filter.ndims() == (size_t)4) {
// dimensions of the filter shape. // If the filter's dimension is 4, groups convolution is not supported.
CHECK_EQ(input[1] / groups_, filter[1]); CHECK_EQ(groups_, (size_t)1);
CHECK_EQ(output[1], filter[0]); // The input and output channel dimensions are the second and first
// dimensions of the filter shape.
CHECK_EQ(input[1], filter[1]);
CHECK_EQ(output[1], filter[0]);
} else {
// filter argument should be 5-dimensional.
CHECK_EQ(filter.ndims(), (size_t)5);
// The first dimension of the filter is the size of the group
CHECK_EQ(filter[0], groups_);
// The input and output channel dimensions are the third and second
// dimensions of the filter shape.
CHECK_EQ(input[1], filter[2] * groups_);
CHECK_EQ(output[1], filter[1] * groups_);
}
} }
protected: protected:
size_t getFilterHeight(const TensorShape& filter) const {
if (filter.ndims() == 5) {
return filter[3];
} else {
return filter[2];
}
}
size_t getFilterWidth(const TensorShape& filter) const {
if (filter.ndims() == 5) {
return filter[4];
} else {
return filter[3];
}
}
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
......
...@@ -80,7 +80,7 @@ public: ...@@ -80,7 +80,7 @@ public:
} else if (type == BACKWARD_INPUT_TEST) { } else if (type == BACKWARD_INPUT_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input)); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run(); test.run();
} else if (type == BACKWARD_FILTER_TEST) { } else if (type == BACKWARD_FILTER_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
......
...@@ -134,15 +134,15 @@ public: ...@@ -134,15 +134,15 @@ public:
beta = 0.0; beta = 0.0;
} }
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = input[0];
size_t inputChannels = inputs[0].shape()[1]; size_t inputChannels = input[1];
size_t inputHeight = inputs[0].shape()[2]; size_t inputHeight = input[2];
size_t inputWidth = inputs[0].shape()[3]; size_t inputWidth = input[3];
size_t filterHeight = inputs[1].shape()[2]; size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = inputs[1].shape()[3]; size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = outputs[0].shape()[1]; size_t outputChannels = output[1];
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = output[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = output[3];
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
...@@ -158,7 +158,8 @@ public: ...@@ -158,7 +158,8 @@ public:
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = inputs[1].shape().getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
...@@ -211,7 +212,9 @@ public: ...@@ -211,7 +212,9 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
// CHECK_EQ(outputs[0].getArgType(), ADD_TO); // Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& output = inputs[0].shape(); const TensorShape& output = inputs[0].shape();
const TensorShape& filter = inputs[1].shape(); const TensorShape& filter = inputs[1].shape();
const TensorShape& input = outputs[0].shape(); const TensorShape& input = outputs[0].shape();
...@@ -221,8 +224,8 @@ public: ...@@ -221,8 +224,8 @@ public:
size_t inputChannels = input[1]; size_t inputChannels = input[1];
size_t inputHeight = input[2]; size_t inputHeight = input[2];
size_t inputWidth = input[3]; size_t inputWidth = input[3];
size_t filterHeight = filter[2]; size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = filter[3]; size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1]; size_t outputChannels = output[1];
size_t outputHeight = output[2]; size_t outputHeight = output[2];
size_t outputWidth = output[3]; size_t outputWidth = output[3];
...@@ -311,8 +314,8 @@ public: ...@@ -311,8 +314,8 @@ public:
size_t inputChannels = input[1]; size_t inputChannels = input[1];
size_t inputHeight = input[2]; size_t inputHeight = input[2];
size_t inputWidth = input[3]; size_t inputWidth = input[3];
size_t filterHeight = filter[2]; size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = filter[3]; size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1]; size_t outputChannels = output[1];
size_t outputHeight = output[2]; size_t outputHeight = output[2];
size_t outputWidth = output[3]; size_t outputWidth = output[3];
......
...@@ -80,8 +80,11 @@ void ExpandConvLayer::forward(PassType passType) { ...@@ -80,8 +80,11 @@ void ExpandConvLayer::forward(PassType passType) {
(size_t)imgSizeH_[i], (size_t)imgSizeH_[i],
(size_t)imgSizeW_[i]}); (size_t)imgSizeW_[i]});
filterShape_[i] = filterShape_[i] =
TensorShape({!isDeconv_ ? (size_t)numFilters_ : (size_t)channels_[i], TensorShape({(size_t)groups_[i],
!isDeconv_ ? (size_t)channels_[i] : (size_t)numFilters_, !isDeconv_ ? (size_t)numFilters_ / groups_[i]
: (size_t)channels_[i] / groups_[i],
!isDeconv_ ? (size_t)channels_[i] / groups_[i]
: (size_t)numFilters_ / groups_[i],
(size_t)filterSizeY_[i], (size_t)filterSizeY_[i],
(size_t)filterSize_[i]}); (size_t)filterSize_[i]});
outputShape_[i] = TensorShape({(size_t)batchSize, outputShape_[i] = TensorShape({(size_t)batchSize,
...@@ -96,8 +99,9 @@ void ExpandConvLayer::forward(PassType passType) { ...@@ -96,8 +99,9 @@ void ExpandConvLayer::forward(PassType passType) {
BufferArgs outputs; BufferArgs outputs;
inputs.addArg(*getInputValue(i), inputShape_[i]); inputs.addArg(*getInputValue(i), inputShape_[i]);
inputs.addArg(*weights_[i]->getW(), filterShape_[i]); inputs.addArg(*weights_[i]->getW(), filterShape_[i]);
outputs.addArg( outputs.addArg(*getOutputValue(),
*getOutputValue(), outputShape_[i], i == 0 ? ASSIGN_TO : ADD_TO); outputShape_[i],
!isDeconv_ && i == 0 ? ASSIGN_TO : ADD_TO);
forward_[i]->calc(inputs, outputs); forward_[i]->calc(inputs, outputs);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册