提交 616352b0 编写于 作者: M Megvii Engine Team

fix(imperative): add dtype promote support for concat

GitOrigin-RevId: e743a6c99585b5072be68d33e937dbb80e57a27b
上级 95a30eb6
......@@ -238,7 +238,7 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs);
}
ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype = get_promoted_dtype(dtypes);
......@@ -258,12 +258,13 @@ ValueRefList convolution3d_rule(const OpDef& op, Span<ValueRef> inputs) {
struct DTypePromoteRuleRegistry {
DTypePromoteRuleRegistry() {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<Concat>(naive_promote_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule);
register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule);
register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
register_dtype_promote_rule<Convolution3D>(convolution3d_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(convolution3d_rule);
register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
}
} register_helper;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册