提交 15c6da62 编写于 作者: M Megvii Engine Team

feat(imperative/amp): add nhwc support for adaptive pooling

GitOrigin-RevId: 7c5755308e4355f38fffd5634f0733836699a8c1
上级 c28a875f
......@@ -37,12 +37,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
}
const dt_int32* oshp2d = nullptr;
dst_layout.ndim = 4u;
if (nr_inp == 1) {
dst_layout[0] = src.layout[0];
dst_layout[1] = src.layout[1];
dst_layout[2] = pool.shape[0];
dst_layout[3] = pool.shape[1];
oshp2d = pool.shape.data();
} else {
auto&& tshp = inputs[1];
if (tshp.value.empty()) {
......@@ -52,11 +50,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
tshp.layout.ndim == 1,
"target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
oshp2d = tshp.value.ptr<dt_int32>();
}
auto param_format = pool.param().format;
if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
dst_layout[0] = src.layout[0];
dst_layout[1] = src.layout[1];
auto* ptr = tshp.value.ptr<dt_int32>();
dst_layout[2] = ptr[0];
dst_layout[3] = ptr[1];
dst_layout[2] = oshp2d[0];
dst_layout[3] = oshp2d[1];
} else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
dst_layout[0] = src.layout[0];
dst_layout[1] = oshp2d[0];
dst_layout[2] = oshp2d[1];
dst_layout[3] = src.layout[3];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
}
dst_layout.init_contiguous_stride();
return {{{dst_layout, src.comp_node}}, true};
......@@ -71,26 +79,47 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
using TensorND = megdnn::TensorND;
auto&& src_layout = inputs[0]->layout();
TensorLayout dst_layout = output_descs[0].layout;
auto param_format = pool.format;
if (!validated) {
TensorShape tshp;
dst_layout.ndim = src_layout.ndim;
dst_layout[0] = src_layout[0];
dst_layout[1] = src_layout[1];
const dt_int32* oshp2d = nullptr;
if (inputs.size() == 2) {
auto&& tshp_nd = inputs[1];
cg::copy_tensor_value_to_shape(
tshp, tshp_nd->get_value().proxy_to_default_cpu());
dst_layout[2] = tshp[0];
dst_layout[3] = tshp[1];
oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>();
} else {
dst_layout[2] = pool.shape[0];
dst_layout[3] = pool.shape[1];
oshp2d = pool.shape.data();
}
if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
dst_layout[0] = src_layout[0];
dst_layout[1] = src_layout[1];
dst_layout[2] = oshp2d[0];
dst_layout[3] = oshp2d[1];
} else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
dst_layout[0] = src_layout[0];
dst_layout[1] = oshp2d[0];
dst_layout[2] = oshp2d[1];
dst_layout[3] = src_layout[3];
} else {
mgb_throw(
MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
}
dst_layout.init_contiguous_stride();
}
size_t IH = src_layout[2], IW = src_layout[3], OH = dst_layout[2],
size_t IH, IW, OH, OW;
if (param_format == param::AdaptivePooling::Format::NCHW) {
IH = src_layout[2];
IW = src_layout[3];
OH = dst_layout[2];
OW = dst_layout[3];
} else if (param_format == param::AdaptivePooling::Format::NHWC) {
IH = src_layout[1];
IW = src_layout[2];
OH = dst_layout[1];
OW = dst_layout[2];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
}
DnnOprCaller<megdnn::Pooling> dnn_opr(cn);
auto&& param = dnn_opr.op->param();
param.mode = pool.mode;
......
......@@ -105,7 +105,7 @@ std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape)
} else {
mgb_throw(
MegBrainError,
"Unsupported shape ndim %u in convert NCHW shape to NHWC.",
"Unsupported shape ndim %lu in convert NCHW shape to NHWC.",
shape.size());
}
}
......@@ -184,7 +184,8 @@ ValueRefList reshape_rule(
// output is still NHWC format
auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
auto outputs = imperative::apply(
*Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
*Reshape::make(op.axis, nhwc_shape),
{t.unwrap_input(inputs[0])});
return t.wrap_outputs(outputs, FT::NHWC);
} else {
// will not maintain src's format
......@@ -395,12 +396,17 @@ ValueRefList batchnorm_rule(
return identity_rule_helper(op, inputs, t);
}
ValueRefList checknonfinite_rule(
const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
ValueRefList adaptive_pooling_rule(
const AdaptivePooling& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
auto&& inputs_ = t.unwrap_inputs(inputs);
auto&& outputs_ = imperative::apply(op, inputs_);
return t.wrap_outputs(outputs_);
auto&& inp_format = inputs[0].cast(t.value_type()).format();
if (inp_format == FT::NHWC) {
auto&& new_param = op.param();
new_param.format = AdaptivePooling::Format::NHWC;
auto new_op = AdaptivePooling::make(new_param, op.shape);
return identity_rule_helper(*new_op, inputs, t);
}
return identity_rule_helper(op, inputs, t);
}
// clang-format off
......@@ -417,7 +423,6 @@ ValueRefList checknonfinite_rule(
cb(Identity)
#define FOREACH_FORMAT_OP(cb) \
cb(AdaptivePooling) \
cb(WarpAffine) \
cb(Resize)
......@@ -494,7 +499,7 @@ struct FormatRuleRegistry {
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(concat_rule);
register_format_rule(batchnorm_rule);
register_format_rule(checknonfinite_rule);
register_format_rule(adaptive_pooling_rule);
FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
FOREACH_FORMAT_OP(REGISTER_OP_RULE)
......@@ -506,7 +511,7 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
//mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
// mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
if (auto* apply_op = op.as<ApplyOp>()) {
// all inputs should be FormattedTensorValue
auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册