Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
74b8af4d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
74b8af4d
编写于
5月 11, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dtype): support complex dtype
GitOrigin-RevId: 8a8715b322b40e805dfb9f3da08d6fc31c1675ea
上级
bc9f9cd4
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
439 addition
and
10 deletion
+439
-10
dnn/include/megdnn/dtype.h
dnn/include/megdnn/dtype.h
+37
-4
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+31
-4
imperative/python/src/helper.cpp
imperative/python/src/helper.cpp
+2
-1
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+31
-0
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+3
-1
imperative/src/include/megbrain/imperative/transformations/complex.h
...src/include/megbrain/imperative/transformations/complex.h
+330
-0
src/opr/impl/loop/forward.cpp
src/opr/impl/loop/forward.cpp
+2
-0
src/opr/impl/loop/impl.cpp
src/opr/impl/loop/impl.cpp
+2
-0
src/serialization/impl/dtype.fbs
src/serialization/impl/dtype.fbs
+1
-0
未找到文件。
dnn/include/megdnn/dtype.h
浏览文件 @
74b8af4d
...
...
@@ -3,6 +3,7 @@
#include <stdint.h>
#include <cfloat>
#include <complex>
#include <cstddef>
#include <limits>
...
...
@@ -31,7 +32,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_DTYPE_NAME(cb) \
cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(IntB1) cb(IntB2) cb(IntB4) \
cb(Byte) DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) cb(Bool) cb(Uint16)
cb(UintB4) cb(Bool) cb(Uint16)
cb(Complex64)
/*!
* \brief iterate through each full byte dtype
...
...
@@ -39,7 +40,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \
cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(Byte) \
DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) cb(Bool) \
cb(Uint16)
cb(Uint16)
cb(Complex64)
/*!
* \brief iterate through each fractional byte dtype
...
...
@@ -314,6 +315,7 @@ typedef bool dt_bool;
typedef
uint16_t
dt_uint16
;
DNN_INC_FLOAT16
(
typedef
half_float
::
half
dt_float16
;)
DNN_INC_FLOAT16
(
typedef
half_bfloat16
::
bfloat16
dt_bfloat16
;)
typedef
std
::
complex
<
float
>
dt_complex64
;
#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
#if MEGDNN_CC_HOST
...
...
@@ -341,6 +343,7 @@ struct DTypeEnum {
#endif
Bool
=
12
,
Uint16
=
13
,
Complex64
=
14
,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2
(
FST
,
D
)
...
...
@@ -356,12 +359,28 @@ DTypeEnum(uint32_t e) : ev(e) {}
#if MEGDNN_CC_HOST
//! dtype numeric category fo
enum
class
DTypeCategory
:
int
{
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
,
BOOL
};
enum
class
DTypeCategory
:
int
{
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
,
BOOL
,
COMPLEX
,
};
//! dtype signedness
enum
class
DTypeSignedness
:
int
{
OTHER
,
UNSIGNED
,
SIGNED
};
#else
struct
DTypeCategory
{
enum
Ev
{
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
,
BOOL
};
enum
Ev
{
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
,
BOOL
,
COMPLEX
,
};
int
ev
;
};
struct
DTypeSignedness
{
...
...
@@ -447,6 +466,15 @@ public:
bool
is_low_bit
()
const
{
return
low_bit
()
!=
0
;
}
bool
is_complex
()
const
{
return
#if MEGDNN_CC_HOST
m_trait
->
category
==
DTypeCategory
::
COMPLEX
;
#else
m_trait
->
category
.
ev
==
DTypeCategory
::
Ev
::
COMPLEX
;
#endif
}
bool
is_quantized_lowbit
()
const
{
return
low_bit
()
!=
0
&&
#if MEGDNN_CC_HOST
...
...
@@ -665,6 +693,11 @@ struct DTypeTrait<dtype::Byte> {
MEGDNN_DEF_DT_BASIC_FIELDS
(
Byte
,
dt_byte
,
OTHER
,
OTHER
,
0
,
false
);
};
template
<
>
struct
DTypeTrait
<
dtype
::
Complex64
>
{
MEGDNN_DEF_DT_BASIC_FIELDS
(
Complex64
,
dt_complex64
,
COMPLEX
,
SIGNED
,
0
,
false
);
};
#define MEGDNN_DEF_FRACTION_DT(_name, b) \
template <> \
struct DTypeTrait<dtype::_name##b> { \
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
74b8af4d
...
...
@@ -9,8 +9,11 @@ from ..core._imperative_rt.core2 import (
Const
,
apply
,
broadcast_cpp
,
create_complex
,
dtype_promotion
,
expand_dims_cpp
,
get_imag
,
get_real
,
split_cpp
,
squeeze_cpp
,
)
...
...
@@ -20,13 +23,14 @@ from ..core.ops.builtin import Copy, Identity
from
..core.tensor.utils
import
astensor1d
,
convert_inputs
,
get_device
,
subgraph_fn
from
..device
import
get_default_device
from
..tensor
import
Tensor
from
.elemwise
import
ceil
from
.elemwise
import
ceil
,
cos
,
sin
__all__
=
[
"arange"
,
"broadcast_to"
,
"concat"
,
"cond_take"
,
"copy"
,
"cumsum"
,
"diag"
,
"expand_dims"
,
...
...
@@ -35,21 +39,24 @@ __all__ = [
"full"
,
"full_like"
,
"gather"
,
"imag"
,
"linspace"
,
"meshgrid"
,
"ones"
,
"ones_like"
,
"polar"
,
"repeat"
,
"reshape"
,
"roll"
,
"scatter"
,
"split"
,
"squeeze"
,
"stack"
,
"s
catter
"
,
"s
wapaxes
"
,
"tile"
,
"copy"
,
"transpose"
,
"swapaxes"
,
"complex"
,
"real"
,
"where"
,
"zeros"
,
"zeros_like"
,
...
...
@@ -417,6 +424,26 @@ def ones_like(inp: Tensor) -> Tensor:
return
full_like
(
inp
,
1.0
)
def
polar
(
abs
:
Tensor
,
angle
:
Tensor
)
->
Tensor
:
return
create_complex
(
abs
*
cos
(
angle
),
abs
*
sin
(
angle
))
def
complex
(
real
:
Tensor
,
imag
:
Tensor
)
->
Tensor
:
if
not
isinstance
(
real
,
Tensor
):
real
=
Tensor
(
real
)
if
not
isinstance
(
imag
,
Tensor
):
imag
=
Tensor
(
imag
)
return
create_complex
(
real
,
imag
)
def
real
(
complex
:
Tensor
)
->
Tensor
:
return
get_real
(
complex
)
def
imag
(
complex
:
Tensor
)
->
Tensor
:
return
get_imag
(
complex
)
def
full_like
(
inp
:
Tensor
,
value
:
Union
[
int
,
float
])
->
Tensor
:
r
"""Returns a tensor filled with given value with the same shape as input tensor.
...
...
imperative/python/src/helper.cpp
浏览文件 @
74b8af4d
...
...
@@ -160,7 +160,8 @@ int to_mgb_supported_dtype_raw(int dtype) {
#define FOREACH_NPY_DTYPE_PAIR(cb) \
cb(Uint8, NPY_UINT8) cb(Int8, NPY_INT8) cb(Uint16, NPY_UINT16) \
cb(Int16, NPY_INT16) cb(Int32, NPY_INT32) cb(Float16, NPY_FLOAT16) \
cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL)
cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL) \
cb(Complex64, NPY_COMPLEX64)
#define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
FOREACH_NPY_DTYPE_PAIR(cb) \
...
...
imperative/python/src/tensor.cpp
浏览文件 @
74b8af4d
...
...
@@ -2,11 +2,13 @@
#include "megbrain/dtype.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/cpp_cupti.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/transformations/complex.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
...
...
@@ -826,6 +828,10 @@ void init_tensor(py::module m) {
.
register_at
<
Segment
::
DimExpansion
>
(
std
::
make_shared
<
DimExpansionTransformation
>
())
.
release
());
MGB_MARK_USED_VAR
(
transformations
.
register_at
<
Segment
::
Complex
>
(
std
::
make_shared
<
ComplexTransformation
>
())
.
release
());
auto
format_trans
=
std
::
make_shared
<
FormatTransformation
>
();
MGB_MARK_USED_VAR
(
transformations
.
register_at
<
Segment
::
Format
>
(
format_trans
).
release
());
...
...
@@ -1460,6 +1466,31 @@ void init_tensor(py::module m) {
[
format_trans
]()
{
return
format_trans
->
get_auto_convert
();
});
py
::
register_exception
<
TraceError
>
(
m
,
"TraceError"
);
m
.
def
(
"create_complex"
,
[](
py
::
object
real
,
py
::
object
imag
)
{
return
TensorWrapper
::
make
(
py_tensor_type
,
imperative
::
apply
(
CreateComplex
(),
TensorWrapper
::
try_cast
(
real
.
ptr
())
->
m_tensor
->
data
(),
TensorWrapper
::
try_cast
(
imag
.
ptr
())
->
m_tensor
->
data
())[
0
]);
});
m
.
def
(
"get_real"
,
[](
py
::
object
complex
)
{
return
TensorWrapper
::
make
(
py_tensor_type
,
imperative
::
apply
(
GetReal
(),
TensorWrapper
::
try_cast
(
complex
.
ptr
())
->
m_tensor
->
data
())[
0
]);
});
m
.
def
(
"get_imag"
,
[](
py
::
object
complex
)
{
return
TensorWrapper
::
make
(
py_tensor_type
,
imperative
::
apply
(
GetImag
(),
TensorWrapper
::
try_cast
(
complex
.
ptr
())
->
m_tensor
->
data
())[
0
]);
});
}
#undef MGE_PY_INTERFACE
...
...
imperative/python/src/transformation.h
浏览文件 @
74b8af4d
...
...
@@ -19,15 +19,17 @@ public:
GroupComm
,
DTypePromote
,
DimExpansion
,
Complex
,
Format
,
Grad
,
Scalar
,
Symbol
,
Trace
,
Eval
,
SEGMENT_COUNT
,
};
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
10
>
segments
;
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
SEGMENT_COUNT
>
segments
;
private:
template
<
Segment
segment
>
...
...
imperative/src/include/megbrain/imperative/transformations/complex.h
0 → 100644
浏览文件 @
74b8af4d
#pragma once
#include <cstddef>
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/basic_values.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/value.h"
#include "megdnn/thin/small_vector.h"
namespace
mgb
{
namespace
imperative
{
class
ComplexTensor
final
:
public
ObjectValue
<
ComplexTensor
>
{
private:
ValueRef
m_real
;
ValueRef
m_imag
;
public:
ComplexTensor
(
ValueRef
real
,
ValueRef
imag
)
:
m_real
(
real
),
m_imag
(
imag
)
{}
std
::
string
to_string
()
const
override
{
return
ssprintf
(
"ComplexTensor{m_real=%s, m_imag=%s}"
,
m_real
.
to_string
().
c_str
(),
m_imag
.
to_string
().
c_str
());
}
DTypeValue
::
ref_t
dtype
()
const
{
auto
dtype
=
m_real
.
dtype
();
mgb_assert
(
dtype
==
m_imag
.
dtype
());
return
dtype
;
}
const
ValueRef
&
real
()
const
{
return
m_real
;
}
const
ValueRef
imag
()
const
{
return
m_imag
;
}
/**
* \brief clear all states of this value
*
*/
void
clear
()
override
{
m_real
=
{};
m_imag
=
{};
}
};
class
CreateComplex
final
:
public
OperatorImpl
<
CreateComplex
>
{
public:
std
::
string
to_string
()
const
override
{
return
"CreateComplex"
;
}
std
::
string
raw_type
()
const
override
{
return
"CreateComplex"
;
}
};
class
GetReal
final
:
public
OperatorImpl
<
GetReal
>
{
public:
std
::
string
to_string
()
const
override
{
return
"GetReal"
;
}
std
::
string
raw_type
()
const
override
{
return
"GetReal"
;
}
};
class
GetImag
final
:
public
OperatorImpl
<
GetImag
>
{
public:
std
::
string
to_string
()
const
override
{
return
"GetImag"
;
}
std
::
string
raw_type
()
const
override
{
return
"GetImag"
;
}
};
class
ComplexTransformation
final
:
public
Transformation
{
private:
ObjectType
<
ComplexTensor
>
m_complex_type
{
"Complex"
};
public:
std
::
string
name
()
const
override
{
return
"ComplexTransformation"
;
}
HostTensorND
make_complex_tensor
(
HostTensorND
real
,
HostTensorND
imag
)
{
mgb_assert
(
real
.
shape
().
eq_shape
(
imag
.
shape
()));
mgb_assert
(
real
.
dtype
()
==
dtype
::
Float32
()
&&
imag
.
dtype
()
==
dtype
::
Float32
());
mgb_assert
(
real
.
comp_node
()
==
imag
.
comp_node
());
HostTensorND
complex
{
real
.
comp_node
(),
real
.
shape
(),
dtype
::
Complex64
()};
TensorShape
f32_shape
=
complex
.
shape
();
f32_shape
[
f32_shape
.
ndim
++
]
=
2
;
TensorLayout
f32_layout
=
{
f32_shape
,
dtype
::
Float32
()};
f32_layout
.
init_contiguous_stride
();
HostTensorND
f32
{
complex
.
comp_node
(),
f32_layout
};
f32
.
storage
(
complex
.
storage
());
TensorLayout
real_layout
=
f32_layout
;
real_layout
.
ndim
--
;
TensorLayout
imag_layout
=
real_layout
;
// mgb_assert(!real_layout.is_contiguous());
// mgb_assert(!imag_layout.is_contiguous());
f32
.
sub
(
SubTensorSpec
::
make_from_layout
(
real_layout
)).
copy_from_fixlayout
(
real
);
f32
.
sub
(
SubTensorSpec
::
make_from_offset_elem
(
imag_layout
,
1
))
.
copy_from_fixlayout
(
imag
);
return
complex
;
}
ValueRefList
apply_complex_mask
(
const
ApplyOp
&
apply_op
,
Span
<
ValueRef
>
inputs
,
Span
<
bool
>
mask
)
{
ValueRefList
real_list
(
inputs
.
size
());
ValueRefList
imag_list
(
inputs
.
size
());
bool
any_complex
=
false
;
bool
all_complex
=
true
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
auto
*
complex
=
inputs
[
i
].
as
(
m_complex_type
))
{
mgb_assert
(
mask
[
i
],
"unexpected complex"
);
any_complex
=
true
;
real_list
[
i
]
=
complex
->
real
();
imag_list
[
i
]
=
complex
->
imag
();
}
else
{
real_list
[
i
]
=
inputs
[
i
];
if
(
mask
[
i
])
{
all_complex
=
false
;
}
else
{
imag_list
[
i
]
=
inputs
[
i
];
}
}
}
if
(
!
any_complex
)
{
// no complex
return
imperative
::
apply
(
apply_op
,
real_list
);
}
else
{
// all complex
mgb_assert
(
all_complex
,
"only serval inputs are complex"
);
auto
reals
=
imperative
::
apply
(
apply_op
,
real_list
);
auto
imags
=
imperative
::
apply
(
apply_op
,
imag_list
);
mgb_assert
(
reals
.
size
()
==
imags
.
size
());
ValueRefList
results
(
reals
.
size
());
for
(
size_t
i
=
0
;
i
<
results
.
size
();
++
i
)
{
results
[
i
]
=
m_complex_type
.
make
(
reals
[
i
],
imags
[
i
]);
}
return
results
;
}
}
ValueRefList
apply_complex_real
(
const
ApplyOp
&
apply_op
,
Span
<
ValueRef
>
inputs
)
{
ValueRefList
real_list
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
auto
*
complex
=
inputs
[
i
].
as
(
m_complex_type
))
{
real_list
[
i
]
=
complex
->
real
();
}
else
{
real_list
[
i
]
=
inputs
[
i
];
}
}
return
imperative
::
apply
(
apply_op
,
real_list
);
}
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
{
if
(
auto
*
create_complex
=
op
.
as
<
CreateComplex
>
())
{
auto
[
real
,
imag
]
=
inputs
.
as_array
<
2
>
();
auto
dtype_real
=
real
.
dtype
();
auto
dtype_imag
=
imag
.
dtype
();
mgb_assert
(
*
dtype_real
==
*
dtype_imag
,
"dtype mismatch: %s vs %s"
,
dtype_real
->
name
(),
dtype_imag
->
name
());
return
{
m_complex_type
.
make
(
real
,
imag
)};
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
if
(
create_tensor
->
dtype
().
is_complex
())
{
auto
args
=
create_tensor
->
parse
(
inputs
);
mgb_assert
(
!
args
.
device
);
auto
&
host
=
*
args
.
host
;
// reinterpret_cast to f32
mgb_assert
(
host
.
layout
().
is_physical_contiguous
());
mgb_assert
(
host
.
dtype
()
==
dtype
::
Complex64
());
TensorShape
f32_shape
=
host
.
shape
();
f32_shape
[
f32_shape
.
ndim
++
]
=
2
;
TensorLayout
f32_layout
=
{
f32_shape
,
dtype
::
Float32
()};
HostTensorND
f32_host
=
{
host
.
comp_node
(),
f32_layout
};
f32_host
.
storage
(
host
.
storage
());
// take real slice and imag slice
auto
real_layout
=
f32_layout
;
real_layout
[
real_layout
.
ndim
-
1
]
=
1
;
auto
imag_layout
=
real_layout
;
auto
real_host
=
f32_host
.
sub
(
SubTensorSpec
::
make_from_layout
(
real_layout
));
auto
imag_host
=
f32_host
.
sub
(
SubTensorSpec
::
make_from_offset_elem
(
imag_layout
,
1
));
// create real and imag
auto
real
=
imperative
::
apply
(
CreateTensor
(
create_tensor
->
kind
(),
create_tensor
->
device
(),
real_layout
),
HostStorage
::
make
(
real_host
.
storage
()))[
0
];
auto
imag
=
imperative
::
apply
(
CreateTensor
(
create_tensor
->
kind
(),
create_tensor
->
device
(),
imag_layout
),
HostStorage
::
make
(
imag_host
.
storage
()))[
0
];
return
{
m_complex_type
.
make
(
real
,
imag
)};
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
bool
any_complex
=
false
;
for
(
auto
&&
input
:
inputs
)
{
if
(
input
.
is
(
m_complex_type
))
{
any_complex
=
true
;
break
;
}
}
if
(
!
any_complex
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
// TODO: handle apply op
// see https://zhuanlan.zhihu.com/p/627536105
if
(
auto
*
elemwise
=
apply_op
->
op
().
try_cast_final
<
Elemwise
>
())
{
switch
(
elemwise
->
mode
)
{
case
Elemwise
::
Mode
::
MUL
:
{
auto
*
complex_a
=
inputs
[
0
].
as
(
m_complex_type
);
auto
*
complex_b
=
inputs
[
1
].
as
(
m_complex_type
);
auto
&
mul
=
*
apply_op
;
if
(
complex_a
&&
complex_b
)
{
auto
add
=
Elemwise
::
make
(
Elemwise
::
Mode
::
ADD
);
auto
sub
=
Elemwise
::
make
(
Elemwise
::
Mode
::
SUB
);
auto
real
=
imperative
::
apply
(
*
sub
,
imperative
::
apply
(
mul
,
complex_a
->
real
(),
complex_b
->
real
())[
0
],
imperative
::
apply
(
mul
,
complex_a
->
imag
(),
complex_b
->
imag
())[
0
])[
0
];
auto
imag
=
imperative
::
apply
(
*
add
,
imperative
::
apply
(
mul
,
complex_a
->
real
(),
complex_b
->
imag
())[
0
],
imperative
::
apply
(
mul
,
complex_a
->
imag
(),
complex_b
->
real
())[
0
])[
0
];
return
{
m_complex_type
.
make
(
real
,
imag
)};
}
else
if
(
complex_a
)
{
auto
real
=
imperative
::
apply
(
mul
,
complex_a
->
real
(),
inputs
[
1
])[
0
];
auto
imag
=
imperative
::
apply
(
mul
,
complex_a
->
imag
(),
inputs
[
1
])[
0
];
return
{
m_complex_type
.
make
(
real
,
imag
)};
}
else
if
(
complex_b
)
{
auto
real
=
imperative
::
apply
(
mul
,
complex_b
->
real
(),
inputs
[
0
])[
0
];
auto
imag
=
imperative
::
apply
(
mul
,
complex_b
->
imag
(),
inputs
[
0
])[
0
];
return
{
m_complex_type
.
make
(
real
,
imag
)};
}
else
{
mgb_assert
(
0
);
}
}
case
Elemwise
::
Mode
::
ADD
:
case
Elemwise
::
Mode
::
SUB
:
{
bool
mask
[
2
]
=
{
true
,
true
};
return
apply_complex_mask
(
*
apply_op
,
inputs
,
{
mask
,
2
});
}
case
Elemwise
::
Mode
::
NEGATE
:
{
bool
mask
[
1
]
=
{
true
};
return
apply_complex_mask
(
*
apply_op
,
inputs
,
{
mask
,
1
});
}
default:
{
mgb_assert
(
0
,
"unsupported elemwise mode"
);
}
}
}
else
if
(
auto
*
reshape
=
apply_op
->
op
().
try_cast_final
<
Reshape
>
())
{
SmallVector
<
bool
>
mask
(
inputs
.
size
(),
false
);
mask
[
0
]
=
true
;
return
apply_complex_mask
(
*
apply_op
,
inputs
,
mask
);
}
else
if
(
auto
*
subtensor
=
apply_op
->
op
().
try_cast_final
<
Subtensor
>
())
{
SmallVector
<
bool
>
mask
(
inputs
.
size
(),
false
);
mask
[
0
]
=
true
;
return
apply_complex_mask
(
*
apply_op
,
inputs
,
mask
);
}
else
if
(
auto
*
get_shape
=
apply_op
->
op
().
try_cast_final
<
GetVarShape
>
())
{
return
apply_complex_real
(
*
apply_op
,
inputs
);
}
else
{
mgb_assert
(
0
,
"unsupported operator"
);
}
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
// TODO: handle get attr
auto
&&
input
=
inputs
[
0
].
as_ref
(
m_complex_type
);
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
DType
:
switch
(
input
->
dtype
()
->
enumv
())
{
case
DTypeEnum
::
Float32
:
{
return
{
DTypeValue
::
make
(
dtype
::
Complex64
())};
}
default:
mgb_assert
(
0
,
"unsupported dtype %s"
,
input
->
dtype
()
->
name
());
}
case
GetAttr
::
Device
:
case
GetAttr
::
Shape
:
return
imperative
::
apply
(
op
,
input
->
real
());
case
GetAttr
::
Value
:
{
auto
complex
=
make_complex_tensor
(
input
->
real
().
numpy
()
->
as_nd
(),
input
->
imag
().
numpy
()
->
as_nd
());
return
{
HostValue
::
make
(
complex
)};
}
default:
mgb_throw
(
MegBrainError
,
"unsupported %s for complex"
,
get_attr
->
to_string
().
c_str
());
}
}
else
if
(
auto
*
as_real
=
op
.
as
<
GetReal
>
())
{
auto
&&
input
=
inputs
[
0
].
as_ref
(
m_complex_type
);
return
{
input
->
real
()};
}
else
if
(
auto
*
as_real
=
op
.
as
<
GetImag
>
())
{
auto
&&
input
=
inputs
[
0
].
as_ref
(
m_complex_type
);
return
{
input
->
imag
()};
}
mgb_throw
(
MegBrainError
,
"unsupported op for complex: %s"
,
op
.
to_string
().
c_str
());
}
ValueRef
unwrap
(
ValueRef
value
)
override
{
mgb_assert
(
!
value
.
is
(
m_complex_type
),
"cannot unwrap complex value"
);
return
value
;
}
};
}
// namespace imperative
}
// namespace mgb
src/opr/impl/loop/forward.cpp
浏览文件 @
74b8af4d
...
...
@@ -400,6 +400,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
break
;
case
DTypeEnum
::
Bool
:
break
;
case
DTypeEnum
::
Complex64
:
break
;
#define cb(x) \
case DTypeEnum::x: \
...
...
src/opr/impl/loop/impl.cpp
浏览文件 @
74b8af4d
...
...
@@ -235,6 +235,8 @@ public:
break
;
case
DTypeEnum
::
Uint16
:
break
;
case
DTypeEnum
::
Complex64
:
break
;
#define cb(_dt) \
case DTypeEnum::_dt: \
break;
...
...
src/serialization/impl/dtype.fbs
浏览文件 @
74b8af4d
...
...
@@ -24,6 +24,7 @@ enum DTypeEnum : byte {
Bool,
Uint16,
QuantizedS1,
Complex64,
}
table LinearQuantizationParam {
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录