提交 afbe556e 编写于 作者: H hedaoyuan

Modify the arguments description of ConvFunctionBase. And add the definition...

Modify the arguments description of ConvFunctionBase. And add the definition of backward input and backward filter function.
上级 3408b4b2
...@@ -19,20 +19,36 @@ limitations under the License. */ ...@@ -19,20 +19,36 @@ limitations under the License. */
namespace paddle { namespace paddle {
/* /*
* Function Arguments: * \brief Based on the ConvFunctionBase class, the forward calculation,
* backward input calculation and backward filter calculation
* of convolution operations can be implemented.
* *
* \param inputs[0] Input image data, is NCHW format, where N is batch size, * Arguments of forward and backward calculation:
* C is the number of channels, H and W is the height and * 1. Forward calculation of convolution.
* width of input image. * inputs = {INPUT, FILTER}, outputs = {OUTPUT}
* \param inputs[1] Filter data, is MCHW, where M is the number of output * The first and second input arguments are input image and filter data.
* channels, C is the number of input channels, H and W * The output argument is output image.
* is height and width of filter.
* \param outputs[0] Output image data, is NCHW format, where N is batch size,
* C is the number of channels, H and W is the height and
* width of output image.
* *
* \note Implemented based on the ConvFunctionBase class only supports * 2. Backward input calculation of convolution.
* input data in the NCHW format. * inputs = {OUTPUT_GRAD, FILTER}, outputs = {INPUT_GRAD}
* The first and second input arguments are output grad image
* and filter data.
* The output argument is input grad image.
*
* 3. Backward filter calculation of convolution.
* inputs = {OUTPUT_GRAD, INPUT}, outputs = {FILTER_GRAD}
* The first and second input arguments are output grad image
* and input image.
* The output argument is filter grad.
*
* Arguments format of input, filter and output:
* 1. Input image, output image, input image gradient, output image gradient
* are all NCHW format. Where N is batch size, C is the number of channels,
* H and W is the height and width of image or image gradient.
*
* 2. The format of the filter data is MCHW, where M is the number of
* output image channels, C is the number of input image channels,
* H and W is height and width of filter.
*/ */
class ConvFunctionBase : public FunctionBase { class ConvFunctionBase : public FunctionBase {
public: public:
...@@ -49,17 +65,25 @@ public: ...@@ -49,17 +65,25 @@ public:
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override { // input can be INPUT and INPUT_GRAD
CHECK_EQ(numInputs_, inputs.size()); // filter can be FILTER and FILTER_GRAD
CHECK_EQ(numOutputs_, outputs.size()); // output can be OUTPUT and OUTPUT_GRAD
void check(const TensorShape& input,
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); const TensorShape& filter,
CHECK_EQ(inputs[1].shape().ndims(), (size_t)4); const TensorShape& output) {
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); // inputs and outputs arguments should be 4-dimensional.
CHECK_EQ(input.ndims(), (size_t)4);
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); CHECK_EQ(filter.ndims(), (size_t)4);
CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]); CHECK_EQ(output.ndims(), (size_t)4);
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
// The batchSize of the input needs to be equal to
// the batchSize of the output.
CHECK_EQ(input[0], output[0]);
// The input and output channel dimensions are the second and first
// dimensions of the filter shape.
CHECK_EQ(input[1] / groups_, filter[1]);
CHECK_EQ(output[1], filter[0]);
} }
protected: protected:
......
...@@ -68,17 +68,7 @@ public: ...@@ -68,17 +68,7 @@ public:
}; };
/* /*
* Function Arguments: * \brief Forward calculation of convolution.
*
* \param inputs[0] Input image data, is NCHW format, where N is batch size,
* C is the number of channels, H and W is the height and
* width of input image.
* \param inputs[1] Filter data, is MCHW, where M is the number of output
* channels, C is the number of input channels, H and W
* is height and width of filter.
* \param outputs[0] Output image data, is NCHW format, where N is batch size,
* C is the number of channels, H and W is the height and
* width of output image.
*/ */
template <DeviceType Device> template <DeviceType Device>
class GemmConvFunction : public ConvFunctionBase { class GemmConvFunction : public ConvFunctionBase {
...@@ -88,8 +78,21 @@ public: ...@@ -88,8 +78,21 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
check(inputs, outputs); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(numOutputs_, outputs.size());
// TODO(hedaoyuan): Need to define some index macros,
// to avoid useing 0 and 1.
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
check(input, filter, output);
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = inputs[0].shape()[0];
size_t inputChannels = inputs[0].shape()[1]; size_t inputChannels = inputs[0].shape()[1];
...@@ -143,7 +146,7 @@ public: ...@@ -143,7 +146,7 @@ public:
K, K,
colData, colData,
N, N,
0.0f, beta,
outputData + g * outputOffset, outputData + g * outputOffset,
N); N);
} }
...@@ -166,9 +169,53 @@ private: ...@@ -166,9 +169,53 @@ private:
MemoryHandlePtr memory_; MemoryHandlePtr memory_;
}; };
/*
* \brief Backward input calculation of convolution.
*/
template <DeviceType Device>
class GemmConvGradInputFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& outputGrad = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& inputGrad = outputs[0].shape();
check(inputGrad, filter, outputGrad);
}
};
/*
* \brief Backward filter calculation of convolution.
*/
template <DeviceType Device>
class GemmConvGradFilterFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& outputGrad = inputs[0].shape();
const TensorShape& input = inputs[1].shape();
const TensorShape& filterGrad = outputs[0].shape();
check(input, filterGrad, outputGrad);
}
};
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
#endif #endif
} // namespace paddle } // namespace paddle
...@@ -91,7 +91,12 @@ public: ...@@ -91,7 +91,12 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
check(inputs, outputs); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
check(input, filter, output);
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
size_t batchSize = inputs[0].shape()[0]; size_t batchSize = inputs[0].shape()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册