Commit 7d3df995 authored by Megvii Engine Team, committed by huangxinda

feat(gopt/inference): allow Float32 output dtype in EnableNCHW4Pass

GitOrigin-RevId: 81100dbaf7d9e8f5aade6c630ead79ea5ba45e52
Parent 633016a9
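For reference, this change is exercised through the normal inference-optimization entry point; a minimal usage sketch, mirroring the ConvertFormatNCHW4FloatGPU test added at the end of this diff (it assumes the relevant mgb/gopt headers are included and `y` is a SymbolVar built on a CUDA comp node):

    // Illustration only, not part of the commit; same calls as in the new test.
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nchw4();  // turn on EnableNCHW4Pass
    SymbolVar y_opt;
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    // After this patch, a quantized ConvBias whose output dtype is Float32 is
    // rewritten to param::ConvBias::Format::NCHW4_NCHW (NCHW4 src/filter,
    // NCHW Float32 bias/z/output) instead of plain NCHW4.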
@@ -31,6 +31,15 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
}
}
// FIXME: cudnn cannot handle the case when the initial value of dst tensor
// contains nan and beta is zero, because the result of 0.f * nan is still
// nan
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
args.opr->param().format == param::ConvBias::Format::NCHW) {
return false;
}
auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
......
@@ -57,6 +57,15 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
}
#endif
// FIXME: cudnn cannot handle the case when the initial value of dst tensor
// contains nan and beta is zero, because the result of 0.f * nan is still
// nan
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
param.format == param::ConvBias::Format::NCHW) {
return false;
}
//! FIXME: conv kernel of cudnn for NCHW4_NCHW tensor format causes illegal
//! memory access errors, so we have to disable this kernel here.
if (param.format == param::ConvBias::Format::NCHW4_NCHW ||
......
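Both checks above guard the same cudnn limitation noted in the FIXME: even with beta == 0, cudnn effectively computes dst = alpha * conv(src, filter) + beta * dst, and 0.f multiplied by a NaN already sitting in the dst buffer is still NaN. A standalone illustration of that floating-point fact (not part of the commit):

    #include <cmath>
    #include <cstdio>

    int main() {
        float stale_dst = std::nanf("");  // dst buffer that happens to hold NaN
        float beta = 0.f;
        // The epilogue term beta * dst does not vanish when dst is NaN.
        std::printf("beta * dst = %f\n", beta * stale_dst);  // prints "nan"
        return 0;
    }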
@@ -1619,6 +1619,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
megdnn::param::Convolution::Format::NCHW4;
megdnn::param::ConvBias::Format conv_bias_format =
megdnn::param::ConvBias::Format::NCHW4;
megdnn::param::ConvBias::Format conv_bias_format_nchw4_nchw =
megdnn::param::ConvBias::Format::NCHW4_NCHW;
megdnn::param::BatchConvBias::Format batch_conv_bias_format =
megdnn::param::BatchConvBias::Format::NCHW4;
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
@@ -1821,6 +1823,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
return new_opr;
};
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
conv_bias_format_nchw4_nchw,
src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
@@ -1851,19 +1854,27 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
conv_bias_filter = new_filter.node();
// format: NCHW --> NCHW4
auto new_param = conv_bias_opr.param();
new_param.format = conv_bias_format;
if (conv_bias_opr.output().size() > 0 &&
conv_bias_opr.output(0)->dtype().enumv() == DTypeEnum::Float32) {
new_param.format = conv_bias_format_nchw4_nchw;
} else {
new_param.format = conv_bias_format;
}
if (new_inp.size() == 2) {
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
}
// bias: NCHW --> NCHW4
// bias: NCHW --> NCHW4 when bias_dtype is not Float32
VarNode* conv_bias_bias = new_inp[2];
if (new_inp[2]->shape().ndim == 4) {
if (new_inp[2]->dtype().enumv() != DTypeEnum::Float32 &&
new_inp[2]->shape().ndim == 4) {
auto new_bias =
RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
conv_bias_bias = new_bias.node();
@@ -1873,13 +1884,16 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
}
// z_inp: NCHW --> NCHW4
// z_inp: NCHW --> NCHW4 when bias_dtype is not Float32
VarNode* z_inp = new_inp[3];
if (new_inp[3]->shape().ndim == 4) {
if (new_inp[3]->dtype().enumv() != DTypeEnum::Float32 &&
new_inp[3]->shape().ndim == 4) {
auto new_z =
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
z_inp = new_z.node();
@@ -1889,8 +1903,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
......
@@ -3088,6 +3088,88 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
TEST(TestGoptInference, ConvertFormatNCHW4FloatGPU) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
cn.activate();
REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(6, 1);
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(1.2f));
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
// conv1, with bias
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(1.3f)),
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::Float32());
auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias, {},
OperatorNodeConfig{dtype::Float32()});
// conv2, with bias and z
auto w2 = mkcvar("w2", {8, 4, 3, 3}, dtype::QuantizedS8(1.3f)),
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::Float32()),
z2 = mkcvar("z2", {2, 8, 16, 16}, dtype::Float32());
auto conv2 = opr::ConvBias::make(x, w2, b2, z2, param_conv_bias, {},
OperatorNodeConfig{dtype::Float32()});
// conv3, relu
param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
auto w3 = mkcvar("w3", {8, 4, 3, 3}, dtype::QuantizedS8(1.3f)),
b3 = mkcvar("b3", {1, 8, 1, 1}, dtype::Float32()),
z3 = mkcvar("z3", {2, 8, 16, 16}, dtype::Float32());
auto conv3 = opr::ConvBias::make(x, w3, b3, z3, param_conv_bias, {},
OperatorNodeConfig{dtype::Float32()});
auto y = conv1 + conv2 + conv3;
SymbolVar y_opt;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
bool succ = true;
auto cb = [&succ](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::ConvBias>()) {
auto& conv_bias = opr->cast_final_safe<opr::ConvBias>();
if (conv_bias.param().format !=
opr::ConvBias::Param::Format::NCHW4_NCHW) {
succ = false;
}
}
};
cg::DepOprIter{cb}.add(y_opt);
ASSERT_TRUE(succ);
HostTensorND host_y, host_y_opt;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
#endif
TEST(TestGoptInference, ConvertFormatNCHW4NonConvOpr) {
......