Commit 67cfce9f authored by Megvii Engine Team

fix(imperative/amp): add is_scalar check in elemwise and concat

GitOrigin-RevId: 61a612e92a716030d5d7ad6f6ee3258f03e35069
Parent d313f926
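Why the is_scalar check matters, as a hedged sketch: with NHWC auto-conversion active, every non-NHWC input to an elemwise/concat op used to be pushed through the same Dimshuffle-based conversion, but a 0-dim (scalar) value has no NCHW/NHWC layout to permute. The fix below tags scalars with the target format and lets broadcasting do the rest. A rough numpy analogue (numpy stands in for MegEngine tensors; all names here are illustrative):

```python
import numpy as np

nhwc = np.zeros((1, 8, 8, 3), dtype=np.float32)  # 4-D NHWC-layout operand
scalar = np.float32(2.0)                         # 0-dim scalar operand

# NCHW -> NHWC conversion is a transpose with pattern (0, 2, 3, 1); it is only
# meaningful for 4-D values. On a 0-dim value it would fail outright:
#   np.transpose(np.array(2.0), (0, 2, 3, 1))  # ValueError: axes don't match array
# So a scalar is used as-is and simply broadcast by the elemwise op:
out = nhwc + scalar
assert out.shape == (1, 8, 8, 3)
```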
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from .. import functional as F
......
@@ -592,7 +592,6 @@ def matmul(
     transpose_a=False,
     transpose_b=False,
     compute_mode="default",
-    format="default",
 ) -> Tensor:
     r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
@@ -625,7 +624,7 @@ def matmul(
     array([[10., 13.],
            [28., 40.]], dtype=float32)
     """
-    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
+    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode)
def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
......
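For reference, a minimal call after this change, reusing the values from the docstring above (the only point is that the removed `format` keyword is gone from the public signature):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

inp1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
inp2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
out = F.matmul(inp1, inp2)  # passing format="default" would now be a TypeError
print(out.numpy())          # [[10. 13.]
                            #  [28. 40.]] as in the docstring
```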
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
......
@@ -23,24 +23,42 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
     if (format == target)
         return as(tensor, target);
+    auto&& shape = tensor.value().shape().cast<ShapeValue>();
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
-        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+        if (shape.ndim == 5) {
             pattern = {0, 1, 4, 2, 3};
-        } else {
+        } else if (shape.ndim == 4) {
             pattern = {0, 3, 1, 2};
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "Unsupport format conversion for tensor %s(shape=%s) from %s to %s",
+                    tensor.to_string().c_str(), shape.to_string().c_str(),
+                    format.to_string().c_str(), Format(target).to_string().c_str());
         }
     } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
-        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+        if (shape.ndim == 5) {
             pattern = {0, 1, 3, 4, 2};
-        } else {
+        } else if (shape.ndim == 4) {
             pattern = {0, 2, 3, 1};
+        } else {
+            mgb_throw(
+                    MegBrainError,
+                    "Unsupport format conversion for tensor %s(shape=%s) from %s to %s",
+                    tensor.to_string().c_str(), shape.to_string().c_str(),
+                    format.to_string().c_str(), Format(target).to_string().c_str());
         }
     } else {
         mgb_throw(
-                MegBrainError, "Unsupport format conversion from %s to %s",
+                MegBrainError,
+                "Unsupport format conversion for tensor %s(shape=%s) from %s to %s",
+                tensor.to_string().c_str(), shape.to_string().c_str(),
                 format.to_string().c_str(), Format(target).to_string().c_str());
     }
+    mgb_log_debug(
+            "Change tensor %s from %s to %s", tensor.to_string().c_str(),
+            format.to_string().c_str(), Format(target).to_string().c_str());
     auto output =
             imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
     return m_value_type.make(output, target);
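To make the four Dimshuffle patterns above concrete, a small numpy sketch (numpy's transpose plays the role of Dimshuffle; the shapes are illustrative):

```python
import numpy as np

x4 = np.empty((1, 3, 8, 8))     # 4-D NCHW tensor
w5 = np.empty((2, 4, 3, 5, 5))  # 5-D group-conv weight (two leading group/output dims)

# NCHW -> NHWC uses pattern {0, 2, 3, 1}; {0, 3, 1, 2} is its inverse.
assert np.transpose(x4, (0, 2, 3, 1)).shape == (1, 8, 8, 3)
assert np.transpose(np.transpose(x4, (0, 2, 3, 1)), (0, 3, 1, 2)).shape == x4.shape

# The 5-D fast path keeps the two leading dims and permutes the rest:
# {0, 1, 3, 4, 2} to the NHWC-like layout, {0, 1, 4, 2, 3} back.
assert np.transpose(w5, (0, 1, 3, 4, 2)).shape == (2, 4, 5, 5, 3)
assert np.transpose(np.transpose(w5, (0, 1, 3, 4, 2)), (0, 1, 4, 2, 3)).shape == w5.shape
```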
@@ -380,9 +398,7 @@ inline ValueRefList unify_inputs_format(
     ValueRefList unified_inputs(inputs.size());
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
-        if (inp.format() != dst_fmt &&
-            (inp.value().shape().cast<ShapeValue>().ndim == 4 ||
-             inp.value().shape().cast<ShapeValue>().ndim == 5)) {
+        if (inp.format() != dst_fmt) {
             unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
@@ -396,7 +412,16 @@ ValueRefList elemwise_rule(
         const FormatTransformation& t) {
     FT format = get_inputs_format(inputs, t);
     if (format == FT::NHWC && auto_convert) {
-        auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
+        ValueRefList unified_inputs(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            auto&& inp = inputs[i].cast(t.value_type());
+            if (inp.format() != FT::NHWC && inp.value().is_scalar()) {
+                unified_inputs[i] = t.value_type().make(inp.value(), FT::NHWC);
+            } else {
+                unified_inputs[i] = inputs[i];
+            }
+        }
+        unified_inputs = unify_inputs_format(unified_inputs, FT::NHWC, op.scope(), t);
         return t.wrap_outputs(
                 imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
     }
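The new pre-pass in a loose Python paraphrase (the pair representation and helpers are stand-ins for the C++ FormattedTensorValue machinery, not a real API): scalars are only re-tagged as NHWC, everything else still goes through the layout conversion.

```python
import numpy as np

def unify_for_nhwc(inputs):
    """inputs: list of (format, value) pairs; mirrors the loop above."""
    unified = []
    for fmt, value in inputs:
        if fmt != "NHWC" and np.ndim(value) == 0:
            unified.append(("NHWC", value))  # tag only: no transpose on a 0-dim value
        elif fmt != "NHWC":
            unified.append(("NHWC", np.transpose(value, (0, 2, 3, 1))))  # real conversion
        else:
            unified.append((fmt, value))
    return unified

pairs = unify_for_nhwc([("NCHW", np.zeros((1, 3, 4, 4))), ("DEFAULT", np.float32(1.0))])
assert [(f, np.shape(v)) for f, v in pairs] == [("NHWC", (1, 4, 4, 3)), ("NHWC", ())]
```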
@@ -410,7 +435,16 @@ ValueRefList concat_rule(
     if (!(format == FT::NHWC && auto_convert)) {
         return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
-    auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
+    ValueRefList unified_inputs(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& inp = inputs[i].cast(t.value_type());
+        if (inp.format() != FT::NHWC && inp.value().is_scalar()) {
+            unified_inputs[i] = t.value_type().make(inp.value(), FT::NHWC);
+        } else {
+            unified_inputs[i] = inputs[i];
+        }
+    }
+    unified_inputs = unify_inputs_format(unified_inputs, FT::NHWC, op.scope(), t);
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
     if (axis == 2 || axis == 3) {
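The remap body after `if (axis == 2 || axis == 3)` is elided in this view; as an illustration of the axis bookkeeping it implies (an assumption derived from the (0, 2, 3, 1) data permutation, not from the elided code):

```python
def nchw_axis_in_nhwc(axis):
    # NCHW axes (N=0, C=1, H=2, W=3) land at (N=0, H=1, W=2, C=3) in NHWC,
    # i.e. the same {0, 2, 3, 1} permutation applied to the data itself.
    return {0: 0, 1: 3, 2: 1, 3: 2}[axis]

assert nchw_axis_in_nhwc(2) == 1 and nchw_axis_in_nhwc(3) == 2  # the two cases checked above
```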
@@ -441,7 +475,7 @@ ValueRefList batchnorm_rule(
         const FormatTransformation& t) {
     auto&& inp_format = inputs[0].cast(t.value_type()).format();
     if (inp_format == FT::NHWC) {
-        auto&& new_param = op.param();
+        auto new_param = op.param();
         new_param.param_dim = BatchNorm::ParamDim::DIM_111C;
         auto new_op = BatchNorm::make(new_param);
         return identity_rule_helper(*new_op, inputs, t);
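On the `auto&&` to `auto` change here and in the hunks below: `auto new_param = op.param();` makes the mutation on the next line act on an explicit copy, so it cannot touch (or, depending on what `param()` returns, fail to compile against) the op's own state. A loose Python analogue, with placeholder values:

```python
import copy

class Op:
    def __init__(self, param):
        self._param = param
    def param(self):
        return self._param  # Python always hands back a reference

op = Op({"param_dim": "original"})

alias = op.param()                   # reference-like binding: writes would leak into op
new_param = copy.copy(op.param())    # value binding, like `auto new_param = op.param();`
new_param["param_dim"] = "DIM_111C"  # value taken from the diff above
assert op.param()["param_dim"] == "original"
```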
@@ -454,7 +488,7 @@ ValueRefList adaptive_pooling_rule(
         const FormatTransformation& t) {
     auto&& inp_format = inputs[0].cast(t.value_type()).format();
     if (inp_format == FT::NHWC) {
-        auto&& new_param = op.param();
+        auto new_param = op.param();
         new_param.format = AdaptivePooling::Format::NHWC;
         auto new_op = AdaptivePooling::make(new_param, op.shape);
         return identity_rule_helper(*new_op, inputs, t);
@@ -518,7 +552,7 @@ FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
             const FormatTransformation& t) {                         \
         auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
         if (inp_format == FT::NHWC) {                                \
-            auto&& new_param = _op.param();                          \
+            auto new_param = _op.param();                            \
             new_param.format = Op::Format::NHWC;                     \
             auto new_op = Op::make(new_param);                       \
             return identity_rule_helper(*new_op, inputs, t);         \
@@ -535,7 +569,7 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
             const FormatTransformation& t) {                         \
         auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
         if (inp_format == FT::NHWC) {                                \
-            auto&& new_param = _op.param();                          \
+            auto new_param = _op.param();                            \
             new_param.format = Op::Format::NHWC;                     \
             auto new_op = Op::make(new_param, _op.policy());         \
             return identity_rule_helper(*new_op, inputs, t);         \
......