diff --git a/paddle/function/ImageExpandOp.cpp b/paddle/function/ImageExpandOp.cpp index 625bf5b6edf44148b85bbb09da43da6f210e34b7..ca1d117db8845c3dca814dcee575cd73e6cf5a5a 100644 --- a/paddle/function/ImageExpandOp.cpp +++ b/paddle/function/ImageExpandOp.cpp @@ -45,9 +45,7 @@ public: numOutputs_ = 1; } - virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - - void check(const TensorShape& image, const TensorShape& sequence) const { + void checkShape(const TensorShape& image, const TensorShape& sequence) const { // image shape should be 4-dimensional. CHECK_EQ(image.ndims(), (size_t)4); // sequence shape should be 3-dimensional. @@ -108,12 +106,18 @@ public: ImageExpandFunction::init(config); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& image = inputs[0].shape(); + const TensorShape& sequence = outputs[0].shape(); + checkShape(image, sequence); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); const TensorShape& image = inputs[0].shape(); const TensorShape& sequence = outputs[0].shape(); - check(image, sequence); TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape colShape = getColShape(image, sequence); @@ -149,15 +153,21 @@ public: ImageExpandFunction::init(config); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& image = outputs[0].shape(); + const TensorShape& sequence = inputs[0].shape(); + checkShape(image, sequence); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // Since the implementation of Col2ImFunctor is ADD_TO, // this function only supports ADD_TO mode. CHECK_EQ(outputs[0].getArgType(), ADD_TO); const TensorShape& image = outputs[0].shape(); const TensorShape& sequence = inputs[0].shape(); - check(image, sequence); TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape colShape = getColShape(image, sequence);