diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 14c20b74f2007314534b20d36cf588cef84d34cd..8f2c0c4cb8ab4a5e380055b831852325bbdbe3b5 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -19,20 +19,36 @@ limitations under the License. */ namespace paddle { /* - * Function Arguments: + * \brief Based on the ConvFunctionBase class, the forward calculation, + * backward input calculation and backward filter calculation + * of convolution operations can be implemented. * - * \param inputs[0] Input image data, is NCHW format, where N is batch size, - * C is the number of channels, H and W is the height and - * width of input image. - * \param inputs[1] Filter data, is MCHW, where M is the number of output - * channels, C is the number of input channels, H and W - * is height and width of filter. - * \param outputs[0] Output image data, is NCHW format, where N is batch size, - * C is the number of channels, H and W is the height and - * width of output image. + * Arguments of forward and backward calculation: + * 1. Forward calculation of convolution. + * inputs = {INPUT, FILTER}, outputs = {OUTPUT} + * The first and second input arguments are input image and filter data. + * The output argument is output image. * - * \note Implemented based on the ConvFunctionBase class only supports - * input data in the NCHW format. + * 2. Backward input calculation of convolution. + * inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD} + * The first and second input arguments are output grad image + * and filter data. + * The output argument is input grad image. + * + * 3. Backward filter calculation of convolution. + * inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD} + * The first and second input arguments are output grad image + * and input image. + * The output argument is filter grad. + * + * Arguments format of input, filter and output: + * 1. 
Input image, output image, input image gradient, output image gradient + * are all in NCHW format, where N is the batch size, C is the number of channels, + * and H and W are the height and width of the image or image gradient. + * + * 2. The format of the filter data is MCHW, where M is the number of + * output image channels, C is the number of input image channels, + * and H and W are the height and width of the filter. */ class ConvFunctionBase : public FunctionBase { public: @@ -49,17 +65,25 @@ public: virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - void check(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ(numInputs_, inputs.size()); - CHECK_EQ(numOutputs_, outputs.size()); - - CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); - CHECK_EQ(inputs[1].shape().ndims(), (size_t)4); - CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); - - CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); - CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]); - CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]); + // input can be INPUT or INPUT_GRAD + // filter can be FILTER or FILTER_GRAD + // output can be OUTPUT or OUTPUT_GRAD + void check(const TensorShape& input, + const TensorShape& filter, + const TensorShape& output) { + // The input, filter and output arguments should be 4-dimensional. + CHECK_EQ(input.ndims(), (size_t)4); + CHECK_EQ(filter.ndims(), (size_t)4); + CHECK_EQ(output.ndims(), (size_t)4); + + // The batchSize of the input needs to be equal to + // the batchSize of the output. + CHECK_EQ(input[0], output[0]); + + // The input and output channel dimensions are the second and first + // dimensions of the filter shape. 
+ CHECK_EQ(input[1] / groups_, filter[1]); + CHECK_EQ(output[1], filter[0]); } protected: diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 78aa8f14f34457e131dd4126fc6b8b2b6b07485c..109ed20ab0666815f922fda1433c68af5540e5f0 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -68,17 +68,7 @@ public: }; /* - * Function Arguments: - * - * \param inputs[0] Input image data, is NCHW format, where N is batch size, - * C is the number of channels, H and W is the height and - * width of input image. - * \param inputs[1] Filter data, is MCHW, where M is the number of output - * channels, C is the number of input channels, H and W - * is height and width of filter. - * \param outputs[0] Output image data, is NCHW format, where N is batch size, - * C is the number of channels, H and W is the height and - * width of output image. + * \brief Forward calculation of convolution. */ template class GemmConvFunction : public ConvFunctionBase { @@ -88,8 +78,21 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - check(inputs, outputs); - CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + // TODO(hedaoyuan): Need to define some index macros, + // to avoid using 0 and 1. + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + check(input, filter, output); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; + } size_t batchSize = inputs[0].shape()[0]; size_t inputChannels = inputs[0].shape()[1]; @@ -143,7 +146,7 @@ public: K, colData, N, - 0.0f, + beta, outputData + g * outputOffset, N); } @@ -166,9 +169,53 @@ private: MemoryHandlePtr memory_; }; +/* + * \brief Backward input calculation of convolution. 
+ */ +template +class GemmConvGradInputFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + const TensorShape& outputGrad = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& inputGrad = outputs[0].shape(); + check(inputGrad, filter, outputGrad); + } +}; + +/* + * \brief Backward filter calculation of convolution. + */ +template +class GemmConvGradFilterFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + const TensorShape& outputGrad = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filterGrad = outputs[0].shape(); + check(input, filterGrad, outputGrad); + } +}; + REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); +REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); #ifndef PADDLE_ONLY_CPU REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); +REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction); +REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction); #endif } // namespace paddle diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp index f5d2aa16ab9b8fdedf6320df52bdeae24ca73eea..94aba253e3e56d8a26429d64d539f29ed97ff30a 100644 --- a/paddle/function/NaiveConvOp.cpp +++ b/paddle/function/NaiveConvOp.cpp @@ -91,7 +91,12 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - check(inputs, outputs); + CHECK_EQ(numInputs_, 
inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + check(input, filter, output); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); size_t batchSize = inputs[0].shape()[0];