提交 086ee045 编写于 作者: M Megvii Engine Team

fix(mge): fix error when transform model to nchw4/32/64 tensor format

GitOrigin-RevId: 34be9c7844499b306616398403a7b25d4142fb90
上级 841bfe92
......@@ -100,6 +100,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
......@@ -180,6 +184,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
......@@ -267,6 +275,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
......@@ -345,6 +357,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw32 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
......
此差异已折叠。
......@@ -235,6 +235,9 @@ public:
class EnableTensorCorePass final : public TensorReformatPass {
VarNode* on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const override;
static VarNode* trans_to_nchw32(VarNode* new_inp);
static VarNode* trans_from_nchw32(VarNode* new_inp, VarNode* orig_inp);
public:
const char* name() const override { return mgb_cstr_log("enable_tensorcore"); }
//! make enable tensorcore opt pass
......
......@@ -2356,8 +2356,18 @@ TEST(TestEnableTensorCore, SmallInputShape) {
dtype);
};
auto x = mkvar("x", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)),
w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
auto x0 = mkvar("x0", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)),
w0 = mkcvar("w0", {2, 32, 8, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
b0 = mkcvar("b0", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
z0 = mkcvar("z0", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param0;
param0.format = opr::ConvBias::Param::Format::NCHW4;
param0.sparse = opr::ConvBias::Param::Sparse::GROUP;
param0.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param0.stride_h = param0.stride_w = 1;
param0.pad_h = param0.pad_w = 1;
auto w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
z = mkcvar("b1", {32, 16, 2, 4, 4}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param;
......@@ -2367,7 +2377,9 @@ TEST(TestEnableTensorCore, SmallInputShape) {
param.pad_h = param.pad_w = 1;
auto y = opr::ConvBias::make(
x, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
x0, w0, b0, z0, param0, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());
......@@ -2431,10 +2443,47 @@ TEST(TestEnableTensorCore, Nchw4Nchw) {
}
};
auto mk_flt_shape = [](opr::ConvBias::Param::Format format, size_t OC, size_t IC,
size_t FH, size_t FW, size_t g = 1) -> TensorShape {
mgb_assert(OC % (g * 4) == 0 && IC % (g * 4) == 0);
if (g == 1) {
if (format == opr::ConvBias::Param::Format::NCHW4) {
return {OC, IC / 4, FH, FW, 4};
} else {
mgb_assert(format == opr::ConvBias::Param::Format::NCHW);
return {OC, IC, FH, FW};
}
} else {
if (format == opr::ConvBias::Param::Format::NCHW4) {
return {g, OC / g, IC / 4 / g, FH, FW, 4};
} else {
mgb_assert(format == opr::ConvBias::Param::Format::NCHW);
return {g, OC / g, IC / g, FH, FW};
}
}
};
for (auto format :
{opr::ConvBias::Param::Format::NCHW, opr::ConvBias::Param::Format::NCHW4}) {
auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)),
w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), dtype::QuantizedS8(2.5f)),
auto x0 = mkvar(
"x0", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)),
w0 =
mkcvar("w0", mk_flt_shape(format, 64, 64, 3, 3, 2),
dtype::QuantizedS8(2.5f)),
b0 = mkcvar(
"b0", mkshape(format, 1, 64, 1, 1), dtype::QuantizedS32(6.25f)),
z0 = mkcvar(
"z0", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param0;
param0.format = format;
param0.sparse = opr::ConvBias::Param::Sparse::GROUP;
param0.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param0.stride_h = param0.stride_w = 1;
param0.pad_h = param0.pad_w = 1;
auto w =
mkcvar("w1", mk_flt_shape(format, 64, 64, 3, 3),
dtype::QuantizedS8(2.5f)),
b = mkcvar("b", mkshape(format, 1, 64, 1, 1), dtype::QuantizedS32(6.25f)),
z = mkcvar("b1", mkshape(format, 32, 64, 8, 8), dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param;
......@@ -2444,7 +2493,10 @@ TEST(TestEnableTensorCore, Nchw4Nchw) {
param.pad_h = param.pad_w = 1;
auto y = opr::ConvBias::make(
x, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
x0, w0, b0, z0, param0, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());
......@@ -2470,7 +2522,7 @@ TEST(TestEnableTensorCore, Nchw4Nchw) {
ASSERT_EQ(2u, nr_dimshuffle);
#endif
} else {
ASSERT_EQ(2u, nr_dimshuffle);
ASSERT_EQ(3u, nr_dimshuffle);
}
std::string json_name;
if (format == opr::ConvBias::Param::Format::NCHW4) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册