Commit 74b8af4d authored by Megvii Engine Team

feat(dtype): support complex dtype

GitOrigin-RevId: 8a8715b322b40e805dfb9f3da08d6fc31c1675ea
Parent bc9f9cd4
@@ -3,6 +3,7 @@
#include <stdint.h>
#include <cfloat>
#include <complex>
#include <cstddef>
#include <limits>
@@ -31,7 +32,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_DTYPE_NAME(cb) \
    cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(IntB1) cb(IntB2) cb(IntB4) \
    cb(Byte) DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) \
    cb(UintB4) cb(Bool) cb(Uint16) cb(Complex64)
/*!
 * \brief iterate through each full byte dtype
@@ -39,7 +40,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \
    cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(Byte) \
    DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) cb(Bool) \
    cb(Uint16) cb(Complex64)
/*!
 * \brief iterate through each fractional byte dtype
@@ -314,6 +315,7 @@ typedef bool dt_bool;
typedef uint16_t dt_uint16;
DNN_INC_FLOAT16(typedef half_float::half dt_float16;)
DNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
typedef std::complex<float> dt_complex64;
#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
#if MEGDNN_CC_HOST
@@ -341,6 +343,7 @@ struct DTypeEnum {
#endif
    Bool = 12,
    Uint16 = 13,
    Complex64 = 14,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
    MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D)
@@ -356,12 +359,28 @@ DTypeEnum(uint32_t e) : ev(e) {}
#if MEGDNN_CC_HOST
//! dtype numeric category
enum class DTypeCategory : int {
    OTHER,
    FLOAT,
    INT,
    LOWBIT,
    QUANTIZED,
    BOOL,
    COMPLEX,
};
//! dtype signedness
enum class DTypeSignedness : int { OTHER, UNSIGNED, SIGNED };
#else
struct DTypeCategory {
    enum Ev {
        OTHER,
        FLOAT,
        INT,
        LOWBIT,
        QUANTIZED,
        BOOL,
        COMPLEX,
    };
    int ev;
};
struct DTypeSignedness {
@@ -447,6 +466,15 @@ public:
    bool is_low_bit() const { return low_bit() != 0; }
    bool is_complex() const {
        return
#if MEGDNN_CC_HOST
                m_trait->category == DTypeCategory::COMPLEX;
#else
                m_trait->category.ev == DTypeCategory::Ev::COMPLEX;
#endif
    }
    bool is_quantized_lowbit() const {
        return low_bit() != 0 &&
#if MEGDNN_CC_HOST
@@ -665,6 +693,11 @@ struct DTypeTrait<dtype::Byte> {
    MEGDNN_DEF_DT_BASIC_FIELDS(Byte, dt_byte, OTHER, OTHER, 0, false);
};
template <>
struct DTypeTrait<dtype::Complex64> {
    MEGDNN_DEF_DT_BASIC_FIELDS(Complex64, dt_complex64, COMPLEX, SIGNED, 0, false);
};
#define MEGDNN_DEF_FRACTION_DT(_name, b) \
    template <> \
    struct DTypeTrait<dtype::_name##b> { \
......
@@ -9,8 +9,11 @@ from ..core._imperative_rt.core2 import (
    Const,
    apply,
    broadcast_cpp,
    create_complex,
    dtype_promotion,
    expand_dims_cpp,
    get_imag,
    get_real,
    split_cpp,
    squeeze_cpp,
)
@@ -20,13 +23,14 @@ from ..core.ops.builtin import Copy, Identity
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil, cos, sin

__all__ = [
    "arange",
    "broadcast_to",
    "concat",
    "cond_take",
    "copy",
    "cumsum",
    "diag",
    "expand_dims",
@@ -35,21 +39,24 @@ __all__ = [
    "full",
    "full_like",
    "gather",
    "imag",
    "linspace",
    "meshgrid",
    "ones",
    "ones_like",
    "polar",
    "repeat",
    "reshape",
    "roll",
    "scatter",
    "split",
    "squeeze",
    "stack",
    "swapaxes",
    "tile",
    "transpose",
    "complex",
    "real",
    "where",
    "zeros",
    "zeros_like",
@@ -417,6 +424,26 @@ def ones_like(inp: Tensor) -> Tensor:
    return full_like(inp, 1.0)


def polar(abs: Tensor, angle: Tensor) -> Tensor:
    return create_complex(abs * cos(angle), abs * sin(angle))


def complex(real: Tensor, imag: Tensor) -> Tensor:
    if not isinstance(real, Tensor):
        real = Tensor(real)
    if not isinstance(imag, Tensor):
        imag = Tensor(imag)
    return create_complex(real, imag)


def real(complex: Tensor) -> Tensor:
    return get_real(complex)


def imag(complex: Tensor) -> Tensor:
    return get_imag(complex)


def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
    r"""Returns a tensor filled with given value with the same shape as input tensor.
......
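With these helpers the user-facing API mirrors the familiar PyTorch-style quartet. A minimal usage sketch, assuming the functions are exported from `megengine.functional` as the `__all__` list above suggests:

```python
import megengine.functional as F
from megengine import Tensor

z = F.complex(Tensor([1.0, 0.0]), Tensor([0.0, 1.0]))  # 1+0j, 0+1j
print(F.real(z))  # [1. 0.]
print(F.imag(z))  # [0. 1.]

# polar(abs, angle) builds abs * (cos(angle) + i * sin(angle))
w = F.polar(Tensor([2.0]), Tensor([0.0]))  # 2+0j
```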
@@ -160,7 +160,8 @@ int to_mgb_supported_dtype_raw(int dtype) {
#define FOREACH_NPY_DTYPE_PAIR(cb) \
    cb(Uint8, NPY_UINT8) cb(Int8, NPY_INT8) cb(Uint16, NPY_UINT16) \
    cb(Int16, NPY_INT16) cb(Int32, NPY_INT32) cb(Float16, NPY_FLOAT16) \
    cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL) \
    cb(Complex64, NPY_COMPLEX64)
#define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
    FOREACH_NPY_DTYPE_PAIR(cb) \
......
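The new `cb(Complex64, NPY_COMPLEX64)` pair wires `Complex64` into the ndarray conversion tables, so `np.complex64` arrays should round-trip through `Tensor`. A hedged sketch of the expected behavior (not taken from the commit's tests):

```python
import numpy as np
import megengine as mge

x = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
t = mge.Tensor(x)  # enters the complex CreateTensor path below
print(t.numpy())   # expected: [1.+2.j 3.+4.j]
```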
@@ -2,11 +2,13 @@
#include "megbrain/dtype.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/cpp_cupti.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/transformations/complex.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
@@ -826,6 +828,10 @@ void init_tensor(py::module m) {
                    .register_at<Segment::DimExpansion>(
                            std::make_shared<DimExpansionTransformation>())
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Complex>(
                                      std::make_shared<ComplexTransformation>())
                              .release());
    auto format_trans = std::make_shared<FormatTransformation>();
    MGB_MARK_USED_VAR(
            transformations.register_at<Segment::Format>(format_trans).release());
@@ -1460,6 +1466,31 @@ void init_tensor(py::module m) {
            [format_trans]() { return format_trans->get_auto_convert(); });
    py::register_exception<TraceError>(m, "TraceError");

    m.def("create_complex", [](py::object real, py::object imag) {
        return TensorWrapper::make(
                py_tensor_type,
                imperative::apply(
                        CreateComplex(),
                        TensorWrapper::try_cast(real.ptr())->m_tensor->data(),
                        TensorWrapper::try_cast(imag.ptr())->m_tensor->data())[0]);
    });

    m.def("get_real", [](py::object complex) {
        return TensorWrapper::make(
                py_tensor_type,
                imperative::apply(
                        GetReal(),
                        TensorWrapper::try_cast(complex.ptr())->m_tensor->data())[0]);
    });

    m.def("get_imag", [](py::object complex) {
        return TensorWrapper::make(
                py_tensor_type,
                imperative::apply(
                        GetImag(),
                        TensorWrapper::try_cast(complex.ptr())->m_tensor->data())[0]);
    });
}

#undef MGE_PY_INTERFACE
......
@@ -19,15 +19,17 @@ public:
        GroupComm,
        DTypePromote,
        DimExpansion,
        Complex,
        Format,
        Grad,
        Scalar,
        Symbol,
        Trace,
        Eval,
        SEGMENT_COUNT,
    };

    std::array<std::vector<std::shared_ptr<Transformation>>, SEGMENT_COUNT> segments;

private:
    template <Segment segment>
......
#pragma once
#include <cstddef>
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/basic_values.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/value.h"
#include "megdnn/thin/small_vector.h"
namespace mgb {
namespace imperative {
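// A complex value is stored as a pair of float32 tensors (the real and
// imaginary planes); ComplexTransformation below dispatches ops onto the
// two planes.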
class ComplexTensor final : public ObjectValue<ComplexTensor> {
private:
    ValueRef m_real;
    ValueRef m_imag;

public:
    ComplexTensor(ValueRef real, ValueRef imag) : m_real(real), m_imag(imag) {}

    std::string to_string() const override {
        return ssprintf(
                "ComplexTensor{m_real=%s, m_imag=%s}", m_real.to_string().c_str(),
                m_imag.to_string().c_str());
    }

    DTypeValue::ref_t dtype() const {
        auto dtype = m_real.dtype();
        mgb_assert(dtype == m_imag.dtype());
        return dtype;
    }

    const ValueRef& real() const { return m_real; }
    const ValueRef imag() const { return m_imag; }

    /**
     * \brief clear all states of this value
     */
    void clear() override {
        m_real = {};
        m_imag = {};
    }
};
class CreateComplex final : public OperatorImpl<CreateComplex> {
public:
    std::string to_string() const override { return "CreateComplex"; }
    std::string raw_type() const override { return "CreateComplex"; }
};

class GetReal final : public OperatorImpl<GetReal> {
public:
    std::string to_string() const override { return "GetReal"; }
    std::string raw_type() const override { return "GetReal"; }
};

class GetImag final : public OperatorImpl<GetImag> {
public:
    std::string to_string() const override { return "GetImag"; }
    std::string raw_type() const override { return "GetImag"; }
};
class ComplexTransformation final : public Transformation {
private:
    ObjectType<ComplexTensor> m_complex_type{"Complex"};

public:
    std::string name() const override { return "ComplexTransformation"; }
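    // Pack two float32 host tensors into one Complex64 host tensor: the
    // complex storage is aliased as float32 with a trailing axis of 2, so
    // real parts land at even element offsets and imaginary parts at odd
    // offsets, matching the interleaved layout of std::complex<float>.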
    HostTensorND make_complex_tensor(HostTensorND real, HostTensorND imag) {
        mgb_assert(real.shape().eq_shape(imag.shape()));
        mgb_assert(
                real.dtype() == dtype::Float32() && imag.dtype() == dtype::Float32());
        mgb_assert(real.comp_node() == imag.comp_node());
        HostTensorND complex{real.comp_node(), real.shape(), dtype::Complex64()};
        TensorShape f32_shape = complex.shape();
        f32_shape[f32_shape.ndim++] = 2;
        TensorLayout f32_layout = {f32_shape, dtype::Float32()};
        f32_layout.init_contiguous_stride();
        HostTensorND f32{complex.comp_node(), f32_layout};
        f32.storage(complex.storage());
        TensorLayout real_layout = f32_layout;
        real_layout.ndim--;
        TensorLayout imag_layout = real_layout;
        // mgb_assert(!real_layout.is_contiguous());
        // mgb_assert(!imag_layout.is_contiguous());
        f32.sub(SubTensorSpec::make_from_layout(real_layout)).copy_from_fixlayout(real);
        f32.sub(SubTensorSpec::make_from_offset_elem(imag_layout, 1))
                .copy_from_fixlayout(imag);
        return complex;
    }
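    // Apply an op plane-wise. mask[i] marks the inputs that may legally be
    // complex: either none of the masked inputs actually is (the op is
    // forwarded unchanged), or all of them must be, in which case the op is
    // applied once to the real planes and once to the imaginary planes.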
    ValueRefList apply_complex_mask(
            const ApplyOp& apply_op, Span<ValueRef> inputs, Span<bool> mask) {
        ValueRefList real_list(inputs.size());
        ValueRefList imag_list(inputs.size());
        bool any_complex = false;
        bool all_complex = true;
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (auto* complex = inputs[i].as(m_complex_type)) {
                mgb_assert(mask[i], "unexpected complex");
                any_complex = true;
                real_list[i] = complex->real();
                imag_list[i] = complex->imag();
            } else {
                real_list[i] = inputs[i];
                if (mask[i]) {
                    all_complex = false;
                } else {
                    imag_list[i] = inputs[i];
                }
            }
        }
        if (!any_complex) {
            // no complex input, forward as-is
            return imperative::apply(apply_op, real_list);
        } else {
            // all masked inputs must be complex
            mgb_assert(all_complex, "only some of the inputs are complex");
            auto reals = imperative::apply(apply_op, real_list);
            auto imags = imperative::apply(apply_op, imag_list);
            mgb_assert(reals.size() == imags.size());
            ValueRefList results(reals.size());
            for (size_t i = 0; i < results.size(); ++i) {
                results[i] = m_complex_type.make(reals[i], imags[i]);
            }
            return results;
        }
    }
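    // Forward an op that only inspects metadata (e.g. GetVarShape): a complex
    // input is replaced by its real plane, which carries the same shape and
    // device.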
    ValueRefList apply_complex_real(const ApplyOp& apply_op, Span<ValueRef> inputs) {
        ValueRefList real_list(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (auto* complex = inputs[i].as(m_complex_type)) {
                real_list[i] = complex->real();
            } else {
                real_list[i] = inputs[i];
            }
        }
        return imperative::apply(apply_op, real_list);
    }
    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override {
        if (auto* create_complex = op.as<CreateComplex>()) {
            auto [real, imag] = inputs.as_array<2>();
            auto dtype_real = real.dtype();
            auto dtype_imag = imag.dtype();
            mgb_assert(
                    *dtype_real == *dtype_imag, "dtype mismatch: %s vs %s",
                    dtype_real->name(), dtype_imag->name());
            return {m_complex_type.make(real, imag)};
        } else if (auto* create_tensor = op.as<CreateTensor>()) {
            if (create_tensor->dtype().is_complex()) {
                auto args = create_tensor->parse(inputs);
                mgb_assert(!args.device);
                auto& host = *args.host;
                // reinterpret_cast to f32
                mgb_assert(host.layout().is_physical_contiguous());
                mgb_assert(host.dtype() == dtype::Complex64());
                TensorShape f32_shape = host.shape();
                f32_shape[f32_shape.ndim++] = 2;
                TensorLayout f32_layout = {f32_shape, dtype::Float32()};
                HostTensorND f32_host = {host.comp_node(), f32_layout};
                f32_host.storage(host.storage());
                // take real slice and imag slice
                auto real_layout = f32_layout;
                real_layout[real_layout.ndim - 1] = 1;
                auto imag_layout = real_layout;
                auto real_host =
                        f32_host.sub(SubTensorSpec::make_from_layout(real_layout));
                auto imag_host = f32_host.sub(
                        SubTensorSpec::make_from_offset_elem(imag_layout, 1));
                // create real and imag
                auto real = imperative::apply(
                        CreateTensor(
                                create_tensor->kind(), create_tensor->device(),
                                real_layout),
                        HostStorage::make(real_host.storage()))[0];
                auto imag = imperative::apply(
                        CreateTensor(
                                create_tensor->kind(), create_tensor->device(),
                                imag_layout),
                        HostStorage::make(imag_host.storage()))[0];
                return {m_complex_type.make(real, imag)};
            } else {
                return imperative::apply(op, inputs);
            }
        }
        bool any_complex = false;
        for (auto&& input : inputs) {
            if (input.is(m_complex_type)) {
                any_complex = true;
                break;
            }
        }
        if (!any_complex) {
            return imperative::apply(op, inputs);
        }
        if (auto* apply_op = op.as<ApplyOp>()) {
            // TODO: handle apply op
            // see https://zhuanlan.zhihu.com/p/627536105
            if (auto* elemwise = apply_op->op().try_cast_final<Elemwise>()) {
                switch (elemwise->mode) {
                    case Elemwise::Mode::MUL: {
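                        // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i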
                        auto* complex_a = inputs[0].as(m_complex_type);
                        auto* complex_b = inputs[1].as(m_complex_type);
                        auto& mul = *apply_op;
                        if (complex_a && complex_b) {
                            auto add = Elemwise::make(Elemwise::Mode::ADD);
                            auto sub = Elemwise::make(Elemwise::Mode::SUB);
                            auto real = imperative::apply(
                                    *sub,
                                    imperative::apply(
                                            mul, complex_a->real(),
                                            complex_b->real())[0],
                                    imperative::apply(
                                            mul, complex_a->imag(),
                                            complex_b->imag())[0])[0];
                            auto imag = imperative::apply(
                                    *add,
                                    imperative::apply(
                                            mul, complex_a->real(),
                                            complex_b->imag())[0],
                                    imperative::apply(
                                            mul, complex_a->imag(),
                                            complex_b->real())[0])[0];
                            return {m_complex_type.make(real, imag)};
                        } else if (complex_a) {
                            auto real = imperative::apply(
                                    mul, complex_a->real(), inputs[1])[0];
                            auto imag = imperative::apply(
                                    mul, complex_a->imag(), inputs[1])[0];
                            return {m_complex_type.make(real, imag)};
                        } else if (complex_b) {
                            auto real = imperative::apply(
                                    mul, complex_b->real(), inputs[0])[0];
                            auto imag = imperative::apply(
                                    mul, complex_b->imag(), inputs[0])[0];
                            return {m_complex_type.make(real, imag)};
                        } else {
                            mgb_assert(0);
                        }
                    }
                    case Elemwise::Mode::ADD:
                    case Elemwise::Mode::SUB: {
                        bool mask[2] = {true, true};
                        return apply_complex_mask(*apply_op, inputs, {mask, 2});
                    }
                    case Elemwise::Mode::NEGATE: {
                        bool mask[1] = {true};
                        return apply_complex_mask(*apply_op, inputs, {mask, 1});
                    }
                    default: {
                        mgb_assert(0, "unsupported elemwise mode");
                    }
                }
            } else if (apply_op->op().try_cast_final<Reshape>()) {
                SmallVector<bool> mask(inputs.size(), false);
                mask[0] = true;
                return apply_complex_mask(*apply_op, inputs, mask);
            } else if (apply_op->op().try_cast_final<Subtensor>()) {
                SmallVector<bool> mask(inputs.size(), false);
                mask[0] = true;
                return apply_complex_mask(*apply_op, inputs, mask);
            } else if (apply_op->op().try_cast_final<GetVarShape>()) {
                return apply_complex_real(*apply_op, inputs);
            } else {
                mgb_assert(0, "unsupported operator");
            }
        } else if (auto* get_attr = op.as<GetAttr>()) {
            // TODO: handle get attr
            auto&& input = inputs[0].as_ref(m_complex_type);
            switch (get_attr->attr()) {
                case GetAttr::DType:
                    switch (input->dtype()->enumv()) {
                        case DTypeEnum::Float32: {
                            return {DTypeValue::make(dtype::Complex64())};
                        }
                        default:
                            mgb_assert(
                                    0, "unsupported dtype %s", input->dtype()->name());
                    }
                case GetAttr::Device:
                case GetAttr::Shape:
                    return imperative::apply(op, input->real());
                case GetAttr::Value: {
                    auto complex = make_complex_tensor(
                            input->real().numpy()->as_nd(),
                            input->imag().numpy()->as_nd());
                    return {HostValue::make(complex)};
                }
                default:
                    mgb_throw(
                            MegBrainError, "unsupported %s for complex",
                            get_attr->to_string().c_str());
            }
        } else if (op.as<GetReal>()) {
            auto&& input = inputs[0].as_ref(m_complex_type);
            return {input->real()};
        } else if (op.as<GetImag>()) {
            auto&& input = inputs[0].as_ref(m_complex_type);
            return {input->imag()};
        }
        mgb_throw(
                MegBrainError, "unsupported op for complex: %s",
                op.to_string().c_str());
    }
    ValueRef unwrap(ValueRef value) override {
        mgb_assert(!value.is(m_complex_type), "cannot unwrap complex value");
        return value;
    }
};
} // namespace imperative
} // namespace mgb
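As a sanity reference for the layout logic above: `Complex64` storage is interleaved `(real, imag)` float32 pairs, which is also what numpy's `complex64` uses, so the same view trick can be reproduced directly in numpy:

```python
import numpy as np

c = np.zeros(3, dtype=np.complex64)
f = c.view(np.float32).reshape(3, 2)  # same buffer, trailing axis of 2
f[:, 0] = [1.0, 2.0, 3.0]             # real plane (even element offsets)
f[:, 1] = [4.0, 5.0, 6.0]             # imag plane (odd element offsets)
print(c)                              # [1.+4.j 2.+5.j 3.+6.j]
```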
@@ -400,6 +400,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
            break;
        case DTypeEnum::Bool:
            break;
        case DTypeEnum::Complex64:
            break;
#define cb(x) \
    case DTypeEnum::x: \
......
@@ -235,6 +235,8 @@ public:
            break;
        case DTypeEnum::Uint16:
            break;
        case DTypeEnum::Complex64:
            break;
#define cb(_dt) \
    case DTypeEnum::_dt: \
        break;
......
@@ -24,6 +24,7 @@ enum DTypeEnum : byte {
    Bool,
    Uint16,
    QuantizedS1,
    Complex64,
}

table LinearQuantizationParam {
......