提交 1192a9a6 编写于 作者: M Megvii Engine Team

fix(imperative): fix adaptive pool2d

GitOrigin-RevId: 624b47a51a9905d1744cd372354b3cd8dae48f6e
上级 7b4b94fd
...@@ -1066,7 +1066,31 @@ py::object _adaptive_pool2d_cpp( ...@@ -1066,7 +1066,31 @@ py::object _adaptive_pool2d_cpp(
py::handle inp_hdl, py::handle shape_val_hdl, py::handle pool_mode_hdl) { py::handle inp_hdl, py::handle shape_val_hdl, py::handle pool_mode_hdl) {
py::object shape_hdl = py::reinterpret_borrow<py::object>(shape_val_hdl); py::object shape_hdl = py::reinterpret_borrow<py::object>(shape_val_hdl);
py::list shps(0); py::list shps(0);
if (!PyTuple_Check(shape_val_hdl.ptr())) { auto mode_string = pool_mode_hdl.cast<std::string>();
::megdnn::param::AdaptivePooling::Mode pool_mode =
::megdnn::param::AdaptivePooling::Mode::MAX;
if (mode_string.compare(std::string("AVERAGE")) == 0) {
pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE;
}
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
auto pool_format = ::megdnn::param::AdaptivePooling::Format::NCHW;
auto inp_format = getattr(inp_hdl, "format").cast<std::string>();
if (inp_format == "nhwc") {
pool_format = ::megdnn::param::AdaptivePooling::Format::NHWC;
}
if (TensorWrapper::try_cast(shape_val_hdl.ptr())) {
std::vector<int32_t> shp;
op = AdaptivePooling::make(pool_mode, pool_format, shp);
py::object Op = py::cast(op);
p.resize(3);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
p[2] = shape_val_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
} else if (!PyTuple_Check(shape_val_hdl.ptr())) {
shps.append(PyLong_AsLong(shape_val_hdl.ptr())); shps.append(PyLong_AsLong(shape_val_hdl.ptr()));
shps.append(PyLong_AsLong(shape_val_hdl.ptr())); shps.append(PyLong_AsLong(shape_val_hdl.ptr()));
...@@ -1078,19 +1102,11 @@ py::object _adaptive_pool2d_cpp( ...@@ -1078,19 +1102,11 @@ py::object _adaptive_pool2d_cpp(
} catch (py::error_already_set& err) { } catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl); shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
} }
auto mode_string = pool_mode_hdl.cast<std::string>();
::megdnn::param::AdaptivePooling::Mode pool_mode =
::megdnn::param::AdaptivePooling::Mode::MAX;
if (mode_string.compare(std::string("AVERAGE")) == 0) {
pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE;
}
auto [shape, fastpath] = tuple2vector(shape_tuple); auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl); fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor; py::object shape_tensor;
op = AdaptivePooling::make( op = AdaptivePooling::make(pool_mode, pool_format, shape);
pool_mode, ::megdnn::param::AdaptivePooling::Format::NCHW, shape);
if (fastpath) { if (fastpath) {
p.resize(2); p.resize(2);
} else { } else {
......
...@@ -39,6 +39,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -39,6 +39,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const dt_int32* oshp2d = nullptr; const dt_int32* oshp2d = nullptr;
dst_layout.ndim = 4u; dst_layout.ndim = 4u;
bool tshp1n = false;
if (nr_inp == 1) { if (nr_inp == 1) {
oshp2d = pool.shape.data(); oshp2d = pool.shape.data();
} else { } else {
...@@ -51,17 +52,18 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -51,17 +52,18 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
"target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually", "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim); tshp.layout.ndim);
oshp2d = tshp.value.ptr<dt_int32>(); oshp2d = tshp.value.ptr<dt_int32>();
tshp1n = tshp.layout.total_nr_elems() == 1;
} }
auto param_format = pool.param().format; auto param_format = pool.param().format;
if (param_format == opr::AdaptivePooling::Param::Format::NCHW) { if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
dst_layout[0] = src.layout[0]; dst_layout[0] = src.layout[0];
dst_layout[1] = src.layout[1]; dst_layout[1] = src.layout[1];
dst_layout[2] = oshp2d[0]; dst_layout[2] = oshp2d[0];
dst_layout[3] = oshp2d[1]; dst_layout[3] = tshp1n ? oshp2d[0] : oshp2d[1];
} else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
dst_layout[0] = src.layout[0]; dst_layout[0] = src.layout[0];
dst_layout[1] = oshp2d[0]; dst_layout[1] = oshp2d[0];
dst_layout[2] = oshp2d[1]; dst_layout[2] = tshp1n ? oshp2d[0] : oshp2d[1];
dst_layout[3] = src.layout[3]; dst_layout[3] = src.layout[3];
} else { } else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
...@@ -83,8 +85,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -83,8 +85,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
if (!validated) { if (!validated) {
dst_layout.ndim = src_layout.ndim; dst_layout.ndim = src_layout.ndim;
const dt_int32* oshp2d = nullptr; const dt_int32* oshp2d = nullptr;
bool tshp1n = false;
if (inputs.size() == 2) { if (inputs.size() == 2) {
auto&& tshp_nd = inputs[1]; auto&& tshp_nd = inputs[1];
tshp1n = inputs[1]->layout().total_nr_elems() == 1;
oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>(); oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>();
} else { } else {
oshp2d = pool.shape.data(); oshp2d = pool.shape.data();
...@@ -93,11 +97,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -93,11 +97,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
dst_layout[0] = src_layout[0]; dst_layout[0] = src_layout[0];
dst_layout[1] = src_layout[1]; dst_layout[1] = src_layout[1];
dst_layout[2] = oshp2d[0]; dst_layout[2] = oshp2d[0];
dst_layout[3] = oshp2d[1]; dst_layout[3] = tshp1n ? oshp2d[0] : oshp2d[1];
} else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
dst_layout[0] = src_layout[0]; dst_layout[0] = src_layout[0];
dst_layout[1] = oshp2d[0]; dst_layout[1] = oshp2d[0];
dst_layout[2] = oshp2d[1]; dst_layout[2] = tshp1n ? oshp2d[0] : oshp2d[1];
dst_layout[3] = src_layout[3]; dst_layout[3] = src_layout[3];
} else { } else {
mgb_throw( mgb_throw(
......
...@@ -39,22 +39,23 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( ...@@ -39,22 +39,23 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
auto src = shpinfo.shape_inp_shp.at(0); auto src = shpinfo.shape_inp_shp.at(0);
mgb_assert( mgb_assert(
src.ndim == 4 && oshp2d.ndim == 2, src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
"shape mismatch for AdaptivePooling: src=%s, out2d=%s", "shape mismatch for AdaptivePooling: src=%s, out2d=%s",
src.to_string().c_str(), oshp2d.to_string().c_str()); src.to_string().c_str(), oshp2d.to_string().c_str());
auto param_format = param().format; auto param_format = param().format;
bool tshp1n = oshp2d.ndim == 1;
if (param_format == Param::Format::NCHW) { if (param_format == Param::Format::NCHW) {
dest.ndim = 4; dest.ndim = 4;
dest.shape[0] = src.shape[0]; dest.shape[0] = src.shape[0];
dest.shape[1] = src.shape[1]; dest.shape[1] = src.shape[1];
dest.shape[2] = oshp2d.shape[0]; dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = oshp2d.shape[1]; dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
} else if (param_format == Param::Format::NHWC) { } else if (param_format == Param::Format::NHWC) {
dest.ndim = 4; dest.ndim = 4;
dest.shape[0] = src.shape[0]; dest.shape[0] = src.shape[0];
dest.shape[1] = oshp2d.shape[0]; dest.shape[1] = oshp2d.shape[0];
dest.shape[2] = oshp2d.shape[1]; dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
dest.shape[3] = src.shape[3]; dest.shape[3] = src.shape[3];
} else { } else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册