提交 9e6ed83c 编写于 作者: H hedaoyuan

Fix ImageExpandFunction.

上级 07cde439
...@@ -45,9 +45,7 @@ public: ...@@ -45,9 +45,7 @@ public:
numOutputs_ = 1; numOutputs_ = 1;
} }
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} void checkShape(const TensorShape& image, const TensorShape& sequence) const {
void check(const TensorShape& image, const TensorShape& sequence) const {
// image shape should be 4-dimensional. // image shape should be 4-dimensional.
CHECK_EQ(image.ndims(), (size_t)4); CHECK_EQ(image.ndims(), (size_t)4);
// sequence shape should be 3-dimensional. // sequence shape should be 3-dimensional.
...@@ -108,12 +106,18 @@ public: ...@@ -108,12 +106,18 @@ public:
ImageExpandFunction::init(config); ImageExpandFunction::init(config);
} }
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape();
checkShape(image, sequence);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
const TensorShape& image = inputs[0].shape(); const TensorShape& image = inputs[0].shape();
const TensorShape& sequence = outputs[0].shape(); const TensorShape& sequence = outputs[0].shape();
check(image, sequence);
TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape imShape = TensorShape({image[1], image[2], image[3]});
TensorShape colShape = getColShape(image, sequence); TensorShape colShape = getColShape(image, sequence);
...@@ -149,15 +153,21 @@ public: ...@@ -149,15 +153,21 @@ public:
ImageExpandFunction::init(config); ImageExpandFunction::init(config);
} }
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape();
checkShape(image, sequence);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// Since the implementation of Col2ImFunctor is ADD_TO, // Since the implementation of Col2ImFunctor is ADD_TO,
// this function only supports ADD_TO mode. // this function only supports ADD_TO mode.
CHECK_EQ(outputs[0].getArgType(), ADD_TO); CHECK_EQ(outputs[0].getArgType(), ADD_TO);
const TensorShape& image = outputs[0].shape(); const TensorShape& image = outputs[0].shape();
const TensorShape& sequence = inputs[0].shape(); const TensorShape& sequence = inputs[0].shape();
check(image, sequence);
TensorShape imShape = TensorShape({image[1], image[2], image[3]}); TensorShape imShape = TensorShape({image[1], image[2], image[3]});
TensorShape colShape = getColShape(image, sequence); TensorShape colShape = getColShape(image, sequence);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册