From 616352b0093f9ed0a00384610d8bf388aa383c73 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 25 Mar 2022 13:25:29 +0800 Subject: [PATCH] fix(imperative): add dtype promote support for concat GitOrigin-RevId: e743a6c99585b5072be68d33e937dbb80e57a27b --- imperative/src/impl/transformations/dtype_promote.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 4cdcb0b15..d39b6565a 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -238,7 +238,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span inputs) { return imperative::apply(op, inputs); } -ValueRefList convolution3d_rule(const OpDef& op, Span inputs) { +ValueRefList naive_promote_rule(const OpDef& op, Span inputs) { SmallVector dtypes = get_value_dtypes(inputs); mgb::DType target_dtype = get_promoted_dtype(dtypes); @@ -258,12 +258,13 @@ ValueRefList convolution3d_rule(const OpDef& op, Span inputs) { struct DTypePromoteRuleRegistry { DTypePromoteRuleRegistry() { register_dtype_promote_rule(elemwise_rule); + register_dtype_promote_rule(naive_promote_rule); register_dtype_promote_rule(reduce_rule); register_dtype_promote_rule(convolution_rule); register_dtype_promote_rule(convolution_backward_rule); register_dtype_promote_rule(batch_norm_rule); - register_dtype_promote_rule(convolution3d_rule); - register_dtype_promote_rule(convolution3d_rule); + register_dtype_promote_rule(naive_promote_rule); + register_dtype_promote_rule(naive_promote_rule); } } register_helper; -- GitLab