提交 15c6da62 编写于 作者: M Megvii Engine Team

feat(imperative/amp): add nhwc support for adaptive pooling

GitOrigin-RevId: 7c5755308e4355f38fffd5634f0733836699a8c1
上级 c28a875f
...@@ -37,12 +37,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -37,12 +37,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false}; return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
} }
const dt_int32* oshp2d = nullptr;
dst_layout.ndim = 4u; dst_layout.ndim = 4u;
if (nr_inp == 1) { if (nr_inp == 1) {
dst_layout[0] = src.layout[0]; oshp2d = pool.shape.data();
dst_layout[1] = src.layout[1];
dst_layout[2] = pool.shape[0];
dst_layout[3] = pool.shape[1];
} else { } else {
auto&& tshp = inputs[1]; auto&& tshp = inputs[1];
if (tshp.value.empty()) { if (tshp.value.empty()) {
...@@ -52,11 +50,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -52,11 +50,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
tshp.layout.ndim == 1, tshp.layout.ndim == 1,
"target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually", "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim); tshp.layout.ndim);
oshp2d = tshp.value.ptr<dt_int32>();
}
auto param_format = pool.param().format;
if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
dst_layout[0] = src.layout[0]; dst_layout[0] = src.layout[0];
dst_layout[1] = src.layout[1]; dst_layout[1] = src.layout[1];
auto* ptr = tshp.value.ptr<dt_int32>(); dst_layout[2] = oshp2d[0];
dst_layout[2] = ptr[0]; dst_layout[3] = oshp2d[1];
dst_layout[3] = ptr[1]; } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
dst_layout[0] = src.layout[0];
dst_layout[1] = oshp2d[0];
dst_layout[2] = oshp2d[1];
dst_layout[3] = src.layout[3];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
} }
dst_layout.init_contiguous_stride(); dst_layout.init_contiguous_stride();
return {{{dst_layout, src.comp_node}}, true}; return {{{dst_layout, src.comp_node}}, true};
...@@ -71,26 +79,47 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -71,26 +79,47 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
using TensorND = megdnn::TensorND; using TensorND = megdnn::TensorND;
auto&& src_layout = inputs[0]->layout(); auto&& src_layout = inputs[0]->layout();
TensorLayout dst_layout = output_descs[0].layout; TensorLayout dst_layout = output_descs[0].layout;
auto param_format = pool.format;
if (!validated) { if (!validated) {
TensorShape tshp;
dst_layout.ndim = src_layout.ndim; dst_layout.ndim = src_layout.ndim;
dst_layout[0] = src_layout[0]; const dt_int32* oshp2d = nullptr;
dst_layout[1] = src_layout[1];
if (inputs.size() == 2) { if (inputs.size() == 2) {
auto&& tshp_nd = inputs[1]; auto&& tshp_nd = inputs[1];
cg::copy_tensor_value_to_shape( oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>();
tshp, tshp_nd->get_value().proxy_to_default_cpu());
dst_layout[2] = tshp[0];
dst_layout[3] = tshp[1];
} else { } else {
dst_layout[2] = pool.shape[0]; oshp2d = pool.shape.data();
dst_layout[3] = pool.shape[1]; }
if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
dst_layout[0] = src_layout[0];
dst_layout[1] = src_layout[1];
dst_layout[2] = oshp2d[0];
dst_layout[3] = oshp2d[1];
} else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
dst_layout[0] = src_layout[0];
dst_layout[1] = oshp2d[0];
dst_layout[2] = oshp2d[1];
dst_layout[3] = src_layout[3];
} else {
mgb_throw(
MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
} }
dst_layout.init_contiguous_stride(); dst_layout.init_contiguous_stride();
} }
size_t IH = src_layout[2], IW = src_layout[3], OH = dst_layout[2], size_t IH, IW, OH, OW;
OW = dst_layout[3]; if (param_format == param::AdaptivePooling::Format::NCHW) {
IH = src_layout[2];
IW = src_layout[3];
OH = dst_layout[2];
OW = dst_layout[3];
} else if (param_format == param::AdaptivePooling::Format::NHWC) {
IH = src_layout[1];
IW = src_layout[2];
OH = dst_layout[1];
OW = dst_layout[2];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
}
DnnOprCaller<megdnn::Pooling> dnn_opr(cn); DnnOprCaller<megdnn::Pooling> dnn_opr(cn);
auto&& param = dnn_opr.op->param(); auto&& param = dnn_opr.op->param();
param.mode = pool.mode; param.mode = pool.mode;
......
...@@ -105,7 +105,7 @@ std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) ...@@ -105,7 +105,7 @@ std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape)
} else { } else {
mgb_throw( mgb_throw(
MegBrainError, MegBrainError,
"Unsupported shape ndim %u in convert NCHW shape to NHWC.", "Unsupported shape ndim %lu in convert NCHW shape to NHWC.",
shape.size()); shape.size());
} }
} }
...@@ -184,7 +184,8 @@ ValueRefList reshape_rule( ...@@ -184,7 +184,8 @@ ValueRefList reshape_rule(
// output is still NHWC format // output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_vector(op.shape); auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
auto outputs = imperative::apply( auto outputs = imperative::apply(
*Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])}); *Reshape::make(op.axis, nhwc_shape),
{t.unwrap_input(inputs[0])});
return t.wrap_outputs(outputs, FT::NHWC); return t.wrap_outputs(outputs, FT::NHWC);
} else { } else {
// will not maintain src's format // will not maintain src's format
...@@ -395,12 +396,17 @@ ValueRefList batchnorm_rule( ...@@ -395,12 +396,17 @@ ValueRefList batchnorm_rule(
return identity_rule_helper(op, inputs, t); return identity_rule_helper(op, inputs, t);
} }
ValueRefList checknonfinite_rule( ValueRefList adaptive_pooling_rule(
const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert, const AdaptivePooling& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) { const FormatTransformation& t) {
auto&& inputs_ = t.unwrap_inputs(inputs); auto&& inp_format = inputs[0].cast(t.value_type()).format();
auto&& outputs_ = imperative::apply(op, inputs_); if (inp_format == FT::NHWC) {
return t.wrap_outputs(outputs_); auto&& new_param = op.param();
new_param.format = AdaptivePooling::Format::NHWC;
auto new_op = AdaptivePooling::make(new_param, op.shape);
return identity_rule_helper(*new_op, inputs, t);
}
return identity_rule_helper(op, inputs, t);
} }
// clang-format off // clang-format off
...@@ -417,7 +423,6 @@ ValueRefList checknonfinite_rule( ...@@ -417,7 +423,6 @@ ValueRefList checknonfinite_rule(
cb(Identity) cb(Identity)
#define FOREACH_FORMAT_OP(cb) \ #define FOREACH_FORMAT_OP(cb) \
cb(AdaptivePooling) \
cb(WarpAffine) \ cb(WarpAffine) \
cb(Resize) cb(Resize)
...@@ -494,7 +499,7 @@ struct FormatRuleRegistry { ...@@ -494,7 +499,7 @@ struct FormatRuleRegistry {
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>); register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(concat_rule); register_format_rule(concat_rule);
register_format_rule(batchnorm_rule); register_format_rule(batchnorm_rule);
register_format_rule(checknonfinite_rule); register_format_rule(adaptive_pooling_rule);
FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE) FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
FOREACH_IDENTITY_OP(REGISTER_OP_RULE) FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_OP(REGISTER_OP_RULE) FOREACH_FORMAT_OP(REGISTER_OP_RULE)
...@@ -506,7 +511,7 @@ struct FormatRuleRegistry { ...@@ -506,7 +511,7 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation( ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) { const Operator& op, Span<ValueRef> inputs) {
//mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str()); // mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if (auto* apply_op = op.as<ApplyOp>()) { if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue // all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册