Commit 784e2184 authored by H hedaoyuan

Fix the error of group convolution.

Parent 7aac38c7
......@@ -46,8 +46,13 @@ namespace paddle {
* are all in NCHW format, where N is the batch size, C is the number of channels,
* and H and W are the height and width of the image or image gradient.
*
- * 2. The format of the filter data is MCHW, where M is the number of
- * output image channels, C is the number of input image channels,
+ * 2. The format of the filter data is MCHW, where M is the number of output
+ * image channels, C is the number of input image channels,
* H and W are the height and width of the filter.
+ *
+ * If groups is greater than 1, the format of the filter data should be GMCHW,
+ * where G is the number of groups, G * M is the number of output image
+ * channels, G * C is the number of input image channels, and H and W are the
+ * height and width of the filter.
*/
class ConvFunctionBase : public FunctionBase {
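The GMCHW layout described in the comment above can be illustrated with a small standalone sketch. It uses a plain std::vector<size_t> in place of Paddle's TensorShape and arbitrary, made-up channel counts; only the dimension ordering reflects the comment.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sizes: 64 input channels, 128 output channels, 4 groups, 3x3 filter.
  const size_t inputChannels = 64, outputChannels = 128, groups = 4;
  const size_t filterH = 3, filterW = 3;

  // GMCHW: G blocks, each holding (outputChannels / G) filters over
  // (inputChannels / G) input channels.
  const std::vector<size_t> filterShape = {
      groups, outputChannels / groups, inputChannels / groups, filterH, filterW};

  size_t elements = 1;
  for (size_t d : filterShape) elements *= d;
  // Each of the G groups owns an equal slice of the filter data.
  std::printf("total elements = %zu, per group = %zu\n", elements, elements / groups);
  return 0;
}
```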
......@@ -73,20 +78,47 @@ public:
const TensorShape& output) {
// inputs and outputs arguments should be 4-dimensional.
CHECK_EQ(input.ndims(), (size_t)4);
- CHECK_EQ(filter.ndims(), (size_t)4);
CHECK_EQ(output.ndims(), (size_t)4);
// The batchSize of the input needs to be equal to
// the batchSize of the output.
CHECK_EQ(input[0], output[0]);
- // The input and output channel dimensions are the second and first
- // dimensions of the filter shape.
- CHECK_EQ(input[1] / groups_, filter[1]);
- CHECK_EQ(output[1], filter[0]);
+ if (filter.ndims() == (size_t)4) {
+   // If the filter is 4-dimensional, grouped convolution is not supported.
+   CHECK_EQ(groups_, (size_t)1);
+   // The input and output channel dimensions are the second and first
+   // dimensions of the filter shape.
+   CHECK_EQ(input[1], filter[1]);
+   CHECK_EQ(output[1], filter[0]);
+ } else {
+   // Otherwise the filter argument should be 5-dimensional (GMCHW).
+   CHECK_EQ(filter.ndims(), (size_t)5);
+   // The first dimension of the filter is the number of groups.
+   CHECK_EQ(filter[0], groups_);
+   // The input and output channel dimensions are the third and second
+   // dimensions of the filter shape.
+   CHECK_EQ(input[1], filter[2] * groups_);
+   CHECK_EQ(output[1], filter[1] * groups_);
+ }
}
protected:
+ size_t getFilterHeight(const TensorShape& filter) const {
+   if (filter.ndims() == 5) {
+     return filter[3];
+   } else {
+     return filter[2];
+   }
+ }
+ size_t getFilterWidth(const TensorShape& filter) const {
+   if (filter.ndims() == 5) {
+     return filter[4];
+   } else {
+     return filter[3];
+   }
+ }
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
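The two helpers added above simply dispatch on the number of filter dimensions. A minimal standalone sketch of the same dispatch, again with std::vector<size_t> standing in for TensorShape and example shapes chosen arbitrarily:

```cpp
#include <cassert>
#include <vector>

// Stand-ins for getFilterHeight/getFilterWidth: a 5-D (GMCHW) shape keeps
// H and W in slots 3 and 4, a 4-D (MCHW) shape keeps them in slots 2 and 3.
size_t filterHeight(const std::vector<size_t>& shape) {
  return shape.size() == 5 ? shape[3] : shape[2];
}
size_t filterWidth(const std::vector<size_t>& shape) {
  return shape.size() == 5 ? shape[4] : shape[3];
}

int main() {
  const std::vector<size_t> mchw = {128, 64, 3, 3};     // groups == 1
  const std::vector<size_t> gmchw = {4, 32, 16, 5, 5};  // groups == 4
  assert(filterHeight(mchw) == 3 && filterWidth(mchw) == 3);
  assert(filterHeight(gmchw) == 5 && filterWidth(gmchw) == 5);
  return 0;
}
```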
......
......@@ -80,7 +80,7 @@ public:
} else if (type == BACKWARD_INPUT_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
- test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input));
+ test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == BACKWARD_FILTER_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
......
......@@ -134,15 +134,15 @@ public:
beta = 0.0;
}
- size_t batchSize = inputs[0].shape()[0];
- size_t inputChannels = inputs[0].shape()[1];
- size_t inputHeight = inputs[0].shape()[2];
- size_t inputWidth = inputs[0].shape()[3];
- size_t filterHeight = inputs[1].shape()[2];
- size_t filterWidth = inputs[1].shape()[3];
- size_t outputChannels = outputs[0].shape()[1];
- size_t outputHeight = outputs[0].shape()[2];
- size_t outputWidth = outputs[0].shape()[3];
+ size_t batchSize = input[0];
+ size_t inputChannels = input[1];
+ size_t inputHeight = input[2];
+ size_t inputWidth = input[3];
+ size_t filterHeight = getFilterHeight(filter);
+ size_t filterWidth = getFilterWidth(filter);
+ size_t outputChannels = output[1];
+ size_t outputHeight = output[2];
+ size_t outputWidth = output[3];
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
......@@ -158,7 +158,8 @@ public:
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
- size_t filterOffset = inputs[1].shape().getElements() / groups_;
+ size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
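The per-group pointer arithmetic above can be checked with a short sketch. The sizes are hypothetical; the point is that each group advances the input, output, and filter pointers by one group's slice, and that for a GMCHW filter getElements() / groups_ is exactly one M x C x H x W block.

```cpp
#include <cstdio>

int main() {
  // Hypothetical configuration.
  const size_t inputChannels = 64, outputChannels = 128, groups = 4;
  const size_t inputH = 32, inputW = 32, outputH = 30, outputW = 30;
  const size_t filterH = 3, filterW = 3;

  // Each group consumes/produces an equal slice of channels.
  const size_t inputOffset = (inputChannels / groups) * inputH * inputW;
  const size_t outputOffset = (outputChannels / groups) * outputH * outputW;

  // For a GMCHW filter, getElements() == G * M * C * H * W, so dividing by
  // the number of groups yields one group's block of M * C * H * W weights.
  const size_t filterElements = groups * (outputChannels / groups) *
                                (inputChannels / groups) * filterH * filterW;
  const size_t filterOffset = filterElements / groups;

  std::printf("inputOffset=%zu outputOffset=%zu filterOffset=%zu\n",
              inputOffset, outputOffset, filterOffset);
  return 0;
}
```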
......@@ -211,7 +212,9 @@ public:
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
- // CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+ // Since the implementation of Col2ImFunctor accumulates into its output
+ // (ADD_TO), this function only supports the ADD_TO argument type.
+ CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& output = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& input = outputs[0].shape();
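The comment above is also why the BACKWARD_INPUT_TEST case earlier now registers the input gradient with ADD_TO. A toy illustration of the difference between assign and accumulate semantics (names are illustrative, not the Paddle BufferArg API):

```cpp
#include <cassert>
#include <vector>

// ADD_TO semantics: the result is added to whatever the destination already
// holds, so earlier contributions (e.g. from other groups) are preserved.
void addTo(std::vector<float>& dst, const std::vector<float>& src) {
  for (size_t i = 0; i < dst.size(); ++i) dst[i] += src[i];
}

int main() {
  std::vector<float> inputGrad(4, 1.0f);           // existing gradient contents
  const std::vector<float> col2imResult(4, 0.5f);  // one accumulation step
  addTo(inputGrad, col2imResult);
  assert(inputGrad[0] == 1.5f);  // previous value survives; plain assignment would drop it
  return 0;
}
```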
......@@ -221,8 +224,8 @@ public:
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
- size_t filterHeight = filter[2];
- size_t filterWidth = filter[3];
+ size_t filterHeight = getFilterHeight(filter);
+ size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
......@@ -311,8 +314,8 @@ public:
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
- size_t filterHeight = filter[2];
- size_t filterWidth = filter[3];
+ size_t filterHeight = getFilterHeight(filter);
+ size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
......
......@@ -80,8 +80,11 @@ void ExpandConvLayer::forward(PassType passType) {
(size_t)imgSizeH_[i],
(size_t)imgSizeW_[i]});
filterShape_[i] =
- TensorShape({!isDeconv_ ? (size_t)numFilters_ : (size_t)channels_[i],
-              !isDeconv_ ? (size_t)channels_[i] : (size_t)numFilters_,
+ TensorShape({(size_t)groups_[i],
+              !isDeconv_ ? (size_t)numFilters_ / groups_[i]
+                         : (size_t)channels_[i] / groups_[i],
+              !isDeconv_ ? (size_t)channels_[i] / groups_[i]
+                         : (size_t)numFilters_ / groups_[i],
(size_t)filterSizeY_[i],
(size_t)filterSize_[i]});
outputShape_[i] = TensorShape({(size_t)batchSize,
......@@ -96,8 +99,9 @@ void ExpandConvLayer::forward(PassType passType) {
BufferArgs outputs;
inputs.addArg(*getInputValue(i), inputShape_[i]);
inputs.addArg(*weights_[i]->getW(), filterShape_[i]);
- outputs.addArg(
-     *getOutputValue(), outputShape_[i], i == 0 ? ASSIGN_TO : ADD_TO);
+ outputs.addArg(*getOutputValue(),
+                outputShape_[i],
+                !isDeconv_ && i == 0 ? ASSIGN_TO : ADD_TO);
forward_[i]->calc(inputs, outputs);
}
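With concrete (made-up) numbers, the 5-D filterShape_ built above works out as follows; the snippet mirrors the conditional expressions but is not the layer code itself.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical layer configuration.
  const size_t numFilters = 128, channels = 64, groups = 4;
  const size_t filterSizeY = 3, filterSize = 3;
  const bool isDeconv = false;

  // Convolution puts numFilters / groups in the M slot and channels / groups
  // in the C slot; deconvolution swaps the two.
  const std::vector<size_t> filterShape = {
      groups,
      !isDeconv ? numFilters / groups : channels / groups,
      !isDeconv ? channels / groups : numFilters / groups,
      filterSizeY,
      filterSize};

  for (size_t d : filterShape) std::printf("%zu ", d);  // prints: 4 32 16 3 3
  std::printf("\n");
  return 0;
}
```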
......