提交 dfb2b2ce 编写于 作者: M Megvii Engine Team

fix(dnn): change pooling window size smaller than padding constraint to log_error

GitOrigin-RevId: c3cda68f6d5a5b4f7a8083d0df7423e83c2b7aa6
上级 6919127f
......@@ -92,10 +92,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
size_t sw = this->param().stride_w;
size_t ph = this->param().pad_h;
size_t pw = this->param().pad_w;
megdnn_assert(ph < fh && pw < fw,
"pooling padding size (%zu %zu) should not be bigger than "
"window size (%zu %zu)",
pw, ph, fw, fh);
if (ph < fh && pw < fw) {
megdnn_log_error(
"pooling padding size (%zu %zu) should not be bigger than "
"window size (%zu %zu), it only can be used in CaffePooling",
pw, ph, fw, fh);
}
infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow);
if (param().format == Param::Format::NCHW) {
dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype);
......
......@@ -104,17 +104,6 @@ TEST(TestOprDNN, PoolingBackward)
}
}
TEST(TestOprDNN, PoolingForwardPadding) {
auto graph = ComputingGraph::make();
Param param(Param::Mode::MAX, 2, 2, 2, 2, 2, 2);
SymbolVarArray symbol_inputs;
HostTensorGenerator<> gen;
auto host_tensor = gen({2, 3, 23, 24});
symbol_inputs.push_back(
mgb::opr::Host2DeviceCopy::make(*graph, host_tensor));
ASSERT_THROW(opr::Pooling::make(symbol_inputs[0], param), MegDNNError);
}
} // anonymous namespace
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册