From 9e6ed83cc4295414436ab784db10bf715637cddf Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 21 Jun 2017 11:26:40 +0800 Subject: [PATCH] Fix ImageExpandFunction. --- paddle/function/ImageExpandOp.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/paddle/function/ImageExpandOp.cpp b/paddle/function/ImageExpandOp.cpp index 625bf5b6e..ca1d117db 100644 --- a/paddle/function/ImageExpandOp.cpp +++ b/paddle/function/ImageExpandOp.cpp @@ -45,9 +45,7 @@ public: numOutputs_ = 1; } - virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - - void check(const TensorShape& image, const TensorShape& sequence) const { + void checkShape(const TensorShape& image, const TensorShape& sequence) const { // image shape should be 4-dimensional. CHECK_EQ(image.ndims(), (size_t)4); // sequence shape should be 3-dimensional. @@ -108,12 +106,18 @@ public: ImageExpandFunction::init(config); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& image = inputs[0].shape(); + const TensorShape& sequence = outputs[0].shape(); + checkShape(image, sequence); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); const TensorShape& image = inputs[0].shape(); const TensorShape& sequence = outputs[0].shape(); - check(image, sequence); TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape colShape = getColShape(image, sequence); @@ -149,15 +153,21 @@ public: ImageExpandFunction::init(config); } + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& image = outputs[0].shape(); + const TensorShape& sequence = inputs[0].shape(); + checkShape(image, sequence); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // Since the implementation of Col2ImFunctor is ADD_TO, // this function only supports ADD_TO mode. CHECK_EQ(outputs[0].getArgType(), ADD_TO); const TensorShape& image = outputs[0].shape(); const TensorShape& sequence = inputs[0].shape(); - check(image, sequence); TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape colShape = getColShape(image, sequence); -- GitLab