From 729242f9f8e1669ac31a7742cc2682e3627b415e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 7 Mar 2022 15:47:59 +0800
Subject: [PATCH] refactor(imperative): move typecvt code of several ops to c++

GitOrigin-RevId: 4ffaa376c1f8edaaff7a9c9fb14cbe3ddd186515
---
 imperative/python/megengine/functional/nn.py | 23 +-------
 .../impl/transformations/dtype_promote.cpp   | 52 +++++++++++++++++++
 2 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 7ca4395dc..f0d151054 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -320,12 +320,6 @@ def conv3d(
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
 
-    dtype = dtype_promotion(inp, weight)
-    if inp.dtype != dtype:
-        inp = inp.astype(dtype)
-    if weight.dtype != dtype:
-        weight = weight.astype(dtype)
-
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
         pad_d=pad[D],
@@ -389,15 +383,6 @@ def conv_transpose2d(
         conv_mode.lower() == "cross_correlation"
         or conv_mode.name == "CROSS_CORRELATION"
     )
-    if amp._enabled:
-        compute_mode = "float32"
-        inp, weight, bias = cast_tensors(inp, weight, bias)
-    else:
-        dtype = dtype_promotion(inp, weight)
-        if inp.dtype != dtype:
-            inp = inp.astype(dtype)
-        if weight.dtype != dtype:
-            weight = weight.astype(dtype)
 
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -418,6 +403,8 @@ def conv_transpose2d(
     )
     (output,) = apply(op, weight, inp)
     if bias is not None:
+        if amp._enabled:
+            bias = cast_tensors(bias)
         output += bias
     return output
 
@@ -591,12 +578,6 @@ def conv_transpose3d(
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
 
-    dtype = dtype_promotion(inp, weight)
-    if inp.dtype != dtype:
-        inp = inp.astype(dtype)
-    if weight.dtype != dtype:
-        weight = weight.astype(dtype)
-
     sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3DBackwardData(
         pad_d=pad[D],
diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp
index bc4d9f2dd..5036c7276 100644
--- a/imperative/src/impl/transformations/dtype_promote.cpp
+++ b/imperative/src/impl/transformations/dtype_promote.cpp
@@ -183,6 +183,38 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, converted);
 }
 
+// Unlike Convolution, ConvolutionBackwardData is used by both
+// functional.conv_transpose2d and quantize.conv_transpose2d
+ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {
+    auto&& conv_op = const_cast<ConvolutionBackwardData&>(
+            op.cast_final_safe<ConvolutionBackwardData>());
+    SmallVector<DType> dtypes = get_value_dtypes(inputs);
+
+    if (is_quantized_dtype(dtypes[0]) && is_quantized_dtype(dtypes[1])) {
+        return imperative::apply(op, inputs);
+    }
+
+    mgb::DType target_dtype;
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        conv_op.compute_mode = ConvolutionBackwardData::ComputeMode::FLOAT32;
+        target_dtype = DTypePromoteCfg::amp_low_prec_dtype;
+    } else {
+        target_dtype = get_promoted_dtype(dtypes);
+    }
+
+    ValueRefList converted(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (dtypes[i] != target_dtype) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
         mgb_assert(inputs.size() > 0);
@@ -208,12 +240,32 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, inputs);
 }
 
+ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
+    SmallVector<DType> dtypes = get_value_dtypes(inputs);
+    mgb::DType target_dtype = get_promoted_dtype(dtypes);
+
+    ValueRefList converted(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (dtypes[i] != target_dtype) {
+            converted[i] = imperative::apply(
+                    ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+        } else {
+            converted[i] = inputs[i];
+        }
+    }
+
+    return imperative::apply(op, converted);
+}
+
 struct DTypePromoteRuleRegistry {
     DTypePromoteRuleRegistry() {
         register_dtype_promote_rule<Elemwise>(elemwise_rule);
         register_dtype_promote_rule<Reduce>(reduce_rule);
         register_dtype_promote_rule<Convolution>(convolution_rule);
+        register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
         register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
+        register_dtype_promote_rule<Convolution3D>(convolution3d_rule);
+        register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule);
     }
 } register_helper;
-- 
GitLab
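
Note on the added rules: the cast-and-apply loop at the end of convolution_backward_rule and convolution3d_rule is identical to the one already present in convolution_rule. A minimal sketch of how that duplication could be factored out; the helper name convert_inputs_to is hypothetical and not part of this patch:

    // Hypothetical helper (not in this patch): cast each input whose dtype
    // differs from `target_dtype`, passing matching inputs through unchanged.
    ValueRefList convert_inputs_to(
            mgb::DType target_dtype, Span<ValueRef> inputs,
            const SmallVector<DType>& dtypes) {
        ValueRefList converted(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (dtypes[i] != target_dtype) {
                // Same conversion path the rules use today: apply a TypeCvt
                // op to the mismatched input and keep its single output.
                converted[i] = imperative::apply(
                        ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
            } else {
                converted[i] = inputs[i];
            }
        }
        return converted;
    }

With such a helper, convolution3d_rule would reduce to computing target_dtype via get_promoted_dtype and returning imperative::apply(op, convert_inputs_to(target_dtype, inputs, dtypes)), and convolution_backward_rule would keep only its quantized-input early return and AMP handling.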