提交 729242f9 编写于 作者: M Megvii Engine Team

refactor(imperative): move typecvt code of several ops to c++

GitOrigin-RevId: 4ffaa376c1f8edaaff7a9c9fb14cbe3ddd186515
上级 3c3fc6f3
......@@ -320,12 +320,6 @@ def conv3d(
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3D(
pad_d=pad[D],
......@@ -389,15 +383,6 @@ def conv_transpose2d(
conv_mode.lower() == "cross_correlation"
or conv_mode.name == "CROSS_CORRELATION"
)
if amp._enabled:
compute_mode = "float32"
inp, weight, bias = cast_tensors(inp, weight, bias)
else:
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)
......@@ -418,6 +403,8 @@ def conv_transpose2d(
)
(output,) = apply(op, weight, inp)
if bias is not None:
if amp._enabled:
bias = cast_tensors(bias)
output += bias
return output
......@@ -591,12 +578,6 @@ def conv_transpose3d(
stride = _triple_nonzero(stride)
dilate = _triple_nonzero(dilation)
dtype = dtype_promotion(inp, weight)
if inp.dtype != dtype:
inp = inp.astype(dtype)
if weight.dtype != dtype:
weight = weight.astype(dtype)
sparse_type = "dense" if groups == 1 else "group"
op = builtin.Convolution3DBackwardData(
pad_d=pad[D],
......
......@@ -183,6 +183,38 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, converted);
}
// Unlike Convolution, ConvolutionBackwardData is shared by both
// functional.conv_transpose2d and quantize.conv_transpose2d, so the
// fully-quantized path must be left untouched here.
ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) {
    // NOTE: the op is mutated in place (compute_mode) under AMP autocast,
    // hence the const_cast — same pattern as convolution_rule.
    auto&& conv_op = const_cast<ConvolutionBackwardData&>(
            op.cast_final_safe<ConvolutionBackwardData>());
    SmallVector<DType> in_dtypes = get_value_dtypes(inputs);
    // Both weight and diff already quantized: no promotion, apply directly.
    bool all_quantized =
            is_quantized_dtype(in_dtypes[0]) && is_quantized_dtype(in_dtypes[1]);
    if (all_quantized) {
        return imperative::apply(op, inputs);
    }
    // Pick the target dtype: AMP low-precision dtype when autocast is on
    // (forcing FLOAT32 accumulation), otherwise the usual promoted dtype.
    mgb::DType promoted;
    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
        conv_op.compute_mode = ConvolutionBackwardData::ComputeMode::FLOAT32;
        promoted = DTypePromoteCfg::amp_low_prec_dtype;
    } else {
        promoted = get_promoted_dtype(in_dtypes);
    }
    // Cast only the inputs whose dtype differs from the target.
    const size_t nr_inputs = inputs.size();
    ValueRefList casted(nr_inputs);
    for (size_t idx = 0; idx < nr_inputs; ++idx) {
        casted[idx] = in_dtypes[idx] == promoted
                            ? inputs[idx]
                            : imperative::apply(
                                      ApplyOp(*TypeCvt::make(promoted)),
                                      inputs[idx])[0];
    }
    return imperative::apply(op, casted);
}
ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
mgb_assert(inputs.size() > 0);
......@@ -208,12 +240,32 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs);
}
// Promote all inputs of a (transposed) 3D convolution to a common dtype
// before applying the op; inputs already at the promoted dtype pass through.
ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
    SmallVector<DType> in_dtypes = get_value_dtypes(inputs);
    mgb::DType promoted = get_promoted_dtype(in_dtypes);
    const size_t nr_inputs = inputs.size();
    ValueRefList casted(nr_inputs);
    for (size_t idx = 0; idx < nr_inputs; ++idx) {
        if (in_dtypes[idx] == promoted) {
            casted[idx] = inputs[idx];
        } else {
            casted[idx] = imperative::apply(
                    ApplyOp(*TypeCvt::make(promoted)), inputs[idx])[0];
        }
    }
    return imperative::apply(op, casted);
}
// Static-initialization-time registry: constructing the file-scope
// `register_helper` instance below registers one dtype-promotion rule per op
// kind before any op is applied.
struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule);
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
// convolution3d_rule handles plain promotion only (no AMP branch), so it is
// shared by both the forward and the backward-data (transposed) 3D conv ops.
register_dtype_promote_rule<Convolution3D>(convolution3d_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule);
}
} register_helper;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册