From 7872f37650c1524d6f57d975731a7557352b7b9f Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 20 Jun 2017 11:16:19 +0800 Subject: [PATCH] Fix some compile error. --- paddle/function/ConvOp.h | 8 +++----- paddle/function/GemmConvOp.cpp | 30 +++++++++++++++++++++++++++--- paddle/function/NaiveConvOp.cpp | 13 +++++++++---- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 65b9d1d53..bb4f48364 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -68,14 +68,12 @@ public: numOutputs_ = 1; } - virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - // input can be INPUT and INPUT_GRAD // filter can be FILTER and FILTER_GRAD // output can be OUTPUT and OUTPUT_GRAD - void check(const TensorShape& input, - const TensorShape& filter, - const TensorShape& output) { + void checkShape(const TensorShape& input, + const TensorShape& filter, + const TensorShape& output) { // inputs and outputs arguments should be 4-dimensional. CHECK_EQ(input.ndims(), (size_t)4); CHECK_EQ(output.ndims(), (size_t)4); diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index c7a57801e..a40e5d9d2 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -117,15 +117,23 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // TODO(hedaoyuan): Need to define some index macros, // to avoid useing 0 and 1. const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - check(input, filter, output); real beta; if (outputs[0].getArgType() == ADD_TO) { @@ -209,16 +217,24 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // Since the implementation of Col2ImFunctor is ADD_TO, // this function only supports ADD_TO mode. CHECK_EQ(outputs[0].getArgType(), ADD_TO); const TensorShape& output = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& input = outputs[0].shape(); - check(input, filter, output); size_t batchSize = input[0]; size_t inputChannels = input[1]; @@ -295,13 +311,21 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); const TensorShape& output = inputs[0].shape(); const TensorShape& input = inputs[1].shape(); const TensorShape& filter = outputs[0].shape(); - check(input, filter, output); real beta; if (outputs[0].getArgType() == ADD_TO) { diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp index 1d204f99e..70bd196a6 100644 --- a/paddle/function/NaiveConvOp.cpp +++ b/paddle/function/NaiveConvOp.cpp @@ -90,14 +90,19 @@ public: ConvFunctionBase::init(config); } - void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ(numInputs_, inputs.size()); - CHECK_EQ(numOutputs_, outputs.size()); + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - check(input, filter, output); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + check(inputs, outputs); size_t batchSize = inputs[0].shape()[0]; size_t inputChannels = inputs[0].shape()[1]; -- GitLab