diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 65b9d1d53f9210b08cdc8bbd9d93b03305e582e4..bb4f48364b9b454af7d37fe4d3c340666e53285c 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -68,14 +68,12 @@ public: numOutputs_ = 1; } - virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - // input can be INPUT and INPUT_GRAD // filter can be FILTER and FILTER_GRAD // output can be OUTPUT and OUTPUT_GRAD - void check(const TensorShape& input, - const TensorShape& filter, - const TensorShape& output) { + void checkShape(const TensorShape& input, + const TensorShape& filter, + const TensorShape& output) { // inputs and outputs arguments should be 4-dimensional. CHECK_EQ(input.ndims(), (size_t)4); CHECK_EQ(output.ndims(), (size_t)4); diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index c7a57801ed6098260af5ba22be82ac4ea7c2e601..a40e5d9d2e76605525f0956445fc43c693933cf8 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -117,15 +117,23 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // TODO(hedaoyuan): Need to define some index macros, // to avoid useing 0 and 1. const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - check(input, filter, output); real beta; if (outputs[0].getArgType() == ADD_TO) { @@ -209,16 +217,24 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // Since the implementation of Col2ImFunctor is ADD_TO, // this function only supports ADD_TO mode. CHECK_EQ(outputs[0].getArgType(), ADD_TO); const TensorShape& output = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& input = outputs[0].shape(); - check(input, filter, output); size_t batchSize = input[0]; size_t inputChannels = input[1]; @@ -295,13 +311,21 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); const TensorShape& output = inputs[0].shape(); const TensorShape& input = inputs[1].shape(); const TensorShape& filter = outputs[0].shape(); - check(input, filter, output); real beta; if (outputs[0].getArgType() == ADD_TO) { diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp index 1d204f99e0e127688eeda28b46715a37c1100c4e..70bd196a674993a00788ad7a6662b1c4f8a00e07 100644 --- a/paddle/function/NaiveConvOp.cpp +++ b/paddle/function/NaiveConvOp.cpp @@ -90,14 +90,19 @@ public: ConvFunctionBase::init(config); } - void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ(numInputs_, inputs.size()); - CHECK_EQ(numOutputs_, outputs.size()); + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - check(input, filter, output); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + check(inputs, outputs); size_t batchSize = inputs[0].shape()[0]; size_t inputChannels = inputs[0].shape()[1];