Commit 74b8af4d authored by Megvii Engine Team

feat(dtype): support complex dtype

GitOrigin-RevId: 8a8715b322b40e805dfb9f3da08d6fc31c1675ea
Parent bc9f9cd4
@@ -3,6 +3,7 @@
#include <stdint.h>
#include <cfloat>
#include <complex>
#include <cstddef>
#include <limits>
@@ -31,7 +32,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_DTYPE_NAME(cb) \
cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(IntB1) cb(IntB2) cb(IntB4) \
cb(Byte) DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) cb(Bool) cb(Uint16)
cb(UintB4) cb(Bool) cb(Uint16) cb(Complex64)
/*!
* \brief iterate through each full byte dtype
@@ -39,7 +40,7 @@ namespace megdnn {
#define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \
cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(Byte) \
DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) cb(Bool) \
cb(Uint16)
cb(Uint16) cb(Complex64)
/*!
* \brief iterate through each fractional byte dtype
@@ -314,6 +315,7 @@ typedef bool dt_bool;
typedef uint16_t dt_uint16;
DNN_INC_FLOAT16(typedef half_float::half dt_float16;)
DNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
typedef std::complex<float> dt_complex64;
#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000
#if MEGDNN_CC_HOST
@@ -341,6 +343,7 @@ struct DTypeEnum {
#endif
Bool = 12,
Uint16 = 13,
Complex64 = 14,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D)
@@ -356,12 +359,28 @@ DTypeEnum(uint32_t e) : ev(e) {}
#if MEGDNN_CC_HOST
//! dtype numeric category
enum class DTypeCategory : int { OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL };
enum class DTypeCategory : int {
OTHER,
FLOAT,
INT,
LOWBIT,
QUANTIZED,
BOOL,
COMPLEX,
};
//! dtype signedness
enum class DTypeSignedness : int { OTHER, UNSIGNED, SIGNED };
#else
struct DTypeCategory {
enum Ev { OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL };
enum Ev {
OTHER,
FLOAT,
INT,
LOWBIT,
QUANTIZED,
BOOL,
COMPLEX,
};
int ev;
};
struct DTypeSignedness {
@@ -447,6 +466,15 @@ public:
bool is_low_bit() const { return low_bit() != 0; }
bool is_complex() const {
return
#if MEGDNN_CC_HOST
m_trait->category == DTypeCategory::COMPLEX;
#else
m_trait->category.ev == DTypeCategory::Ev::COMPLEX;
#endif
}
bool is_quantized_lowbit() const {
return low_bit() != 0 &&
#if MEGDNN_CC_HOST
@@ -665,6 +693,11 @@ struct DTypeTrait<dtype::Byte> {
MEGDNN_DEF_DT_BASIC_FIELDS(Byte, dt_byte, OTHER, OTHER, 0, false);
};
template <>
struct DTypeTrait<dtype::Complex64> {
MEGDNN_DEF_DT_BASIC_FIELDS(Complex64, dt_complex64, COMPLEX, SIGNED, 0, false);
};
#define MEGDNN_DEF_FRACTION_DT(_name, b) \
template <> \
struct DTypeTrait<dtype::_name##b> { \
......
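
A minimal numpy sketch (not part of the commit) of the storage model these traits assume: one Complex64 element is std::complex<float>, i.e. two packed float32 values (real, then imag), 8 bytes in total. The transformation code later in this diff builds exactly this trailing-axis-of-2 float32 view.

    import numpy as np

    # complex64 = float32 real part + float32 imag part, 8 bytes per element
    assert np.dtype(np.complex64).itemsize == 8

    z = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
    # Reinterpreting the buffer as float32 exposes interleaved (real, imag) pairs.
    f32 = z.view(np.float32).reshape(-1, 2)
    assert (f32[:, 0] == [1.0, 3.0]).all()  # real parts: stride 2, offset 0
    assert (f32[:, 1] == [2.0, 4.0]).all()  # imag parts: stride 2, offset 1
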
@@ -9,8 +9,11 @@ from ..core._imperative_rt.core2 import (
Const,
apply,
broadcast_cpp,
create_complex,
dtype_promotion,
expand_dims_cpp,
get_imag,
get_real,
split_cpp,
squeeze_cpp,
)
@@ -20,13 +23,14 @@ from ..core.ops.builtin import Copy, Identity
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
from .elemwise import ceil, cos, sin
__all__ = [
"arange",
"broadcast_to",
"concat",
"cond_take",
"copy",
"cumsum",
"diag",
"expand_dims",
@@ -35,21 +39,24 @@ __all__ = [
"full",
"full_like",
"gather",
"imag",
"linspace",
"meshgrid",
"ones",
"ones_like",
"polar",
"repeat",
"reshape",
"roll",
"scatter",
"split",
"squeeze",
"stack",
"scatter",
"swapaxes",
"tile",
"copy",
"transpose",
"swapaxes",
"complex",
"real",
"where",
"zeros",
"zeros_like",
@@ -417,6 +424,26 @@ def ones_like(inp: Tensor) -> Tensor:
return full_like(inp, 1.0)
def polar(abs: Tensor, angle: Tensor) -> Tensor:
return create_complex(abs * cos(angle), abs * sin(angle))
def complex(real: Tensor, imag: Tensor) -> Tensor:
if not isinstance(real, Tensor):
real = Tensor(real)
if not isinstance(imag, Tensor):
imag = Tensor(imag)
return create_complex(real, imag)
def real(complex: Tensor) -> Tensor:
return get_real(complex)
def imag(complex: Tensor) -> Tensor:
return get_imag(complex)
def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
r"""Returns a tensor filled with given value with the same shape as input tensor.
......
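
A usage sketch for the new Python helpers defined above (assuming a MegEngine build that includes this commit and re-exports them from megengine.functional): polar builds a complex tensor from magnitude and angle, complex builds one from the two components, and real/imag recover them.

    import numpy as np
    import megengine.functional as F
    from megengine import Tensor

    r = Tensor([1.0, 2.0])
    theta = Tensor([0.0, np.pi / 2])

    z = F.polar(r, theta)            # r * (cos(theta) + i * sin(theta))
    re, im = F.real(z), F.imag(z)    # the two float32 components

    np.testing.assert_allclose(re.numpy(), [1.0, 0.0], atol=1e-6)
    np.testing.assert_allclose(im.numpy(), [0.0, 2.0], atol=1e-6)

    z2 = F.complex(re, im)           # rebuilds the same complex tensor
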
@@ -160,7 +160,8 @@ int to_mgb_supported_dtype_raw(int dtype) {
#define FOREACH_NPY_DTYPE_PAIR(cb) \
cb(Uint8, NPY_UINT8) cb(Int8, NPY_INT8) cb(Uint16, NPY_UINT16) \
cb(Int16, NPY_INT16) cb(Int32, NPY_INT32) cb(Float16, NPY_FLOAT16) \
cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL)
cb(Float32, NPY_FLOAT32) cb(Bool, NPY_BOOL) \
cb(Complex64, NPY_COMPLEX64)
#define FOREACH_NPY_MGB_DTYPE_PAIR(cb) \
FOREACH_NPY_DTYPE_PAIR(cb) \
......
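
With the Complex64 / NPY_COMPLEX64 pair registered here, a numpy complex64 array should round-trip through a Tensor unchanged. A hedged sketch:

    import numpy as np
    import megengine as mge

    x = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)
    t = mge.Tensor(x)               # enters as dtype::Complex64
    assert t.dtype == np.complex64  # the mapping is preserved on the way out
    np.testing.assert_allclose(t.numpy(), x)
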
@@ -2,11 +2,13 @@
#include "megbrain/dtype.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/cpp_cupti.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/transformations/complex.h"
#include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
@@ -826,6 +828,10 @@ void init_tensor(py::module m) {
.register_at<Segment::DimExpansion>(
std::make_shared<DimExpansionTransformation>())
.release());
MGB_MARK_USED_VAR(transformations
.register_at<Segment::Complex>(
std::make_shared<ComplexTransformation>())
.release());
auto format_trans = std::make_shared<FormatTransformation>();
MGB_MARK_USED_VAR(
transformations.register_at<Segment::Format>(format_trans).release());
@@ -1460,6 +1466,31 @@ void init_tensor(py::module m) {
[format_trans]() { return format_trans->get_auto_convert(); });
py::register_exception<TraceError>(m, "TraceError");
m.def("create_complex", [](py::object real, py::object imag) {
return TensorWrapper::make(
py_tensor_type,
imperative::apply(
CreateComplex(),
TensorWrapper::try_cast(real.ptr())->m_tensor->data(),
TensorWrapper::try_cast(imag.ptr())->m_tensor->data())[0]);
});
m.def("get_real", [](py::object complex) {
return TensorWrapper::make(
py_tensor_type,
imperative::apply(
GetReal(),
TensorWrapper::try_cast(complex.ptr())->m_tensor->data())[0]);
});
m.def("get_imag", [](py::object complex) {
return TensorWrapper::make(
py_tensor_type,
imperative::apply(
GetImag(),
TensorWrapper::try_cast(complex.ptr())->m_tensor->data())[0]);
});
}
#undef MGE_PY_INTERFACE
......
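
These three bindings are what the functional wrappers earlier in the diff call. Used directly (a sketch, assuming the private core2 module path shown in the Python imports above):

    import numpy as np
    from megengine import Tensor
    from megengine.core._imperative_rt.core2 import (
        create_complex,
        get_imag,
        get_real,
    )

    z = create_complex(Tensor([1.0]), Tensor([2.0]))  # dispatches CreateComplex
    assert get_real(z).item() == 1.0                  # dispatches GetReal
    assert get_imag(z).item() == 2.0                  # dispatches GetImag
    assert z.numpy().dtype == np.complex64            # GetAttr::Value packs to complex64
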
@@ -19,15 +19,17 @@ public:
GroupComm,
DTypePromote,
DimExpansion,
Complex,
Format,
Grad,
Scalar,
Symbol,
Trace,
Eval,
SEGMENT_COUNT,
};
std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments;
std::array<std::vector<std::shared_ptr<Transformation>>, SEGMENT_COUNT> segments;
private:
template <Segment segment>
......
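
The segments array was previously hard-coded to 10 slots, which stopped matching the enum once the Complex segment was inserted; sizing it with a trailing SEGMENT_COUNT sentinel keeps the two in sync automatically. The same pattern in Python (the diff elides entries before GroupComm, so the first name here is assumed):

    from enum import IntEnum, auto

    class Segment(IntEnum):
        MODULE_TRACE = 0        # assumed; the diff elides entries before GroupComm
        GROUP_COMM = auto()
        DTYPE_PROMOTE = auto()
        DIM_EXPANSION = auto()
        COMPLEX = auto()        # the newly inserted segment
        FORMAT = auto()
        GRAD = auto()
        SCALAR = auto()
        SYMBOL = auto()
        TRACE = auto()
        EVAL = auto()
        SEGMENT_COUNT = auto()  # sentinel: always equals the number of segments

    # one transformation list per segment; grows automatically with the enum
    segments = [[] for _ in range(Segment.SEGMENT_COUNT)]
    assert len(segments) == Segment.SEGMENT_COUNT
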
#pragma once
#include <cstddef>
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/basic_values.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/value.h"
#include "megdnn/thin/small_vector.h"
namespace mgb {
namespace imperative {
class ComplexTensor final : public ObjectValue<ComplexTensor> {
private:
ValueRef m_real;
ValueRef m_imag;
public:
ComplexTensor(ValueRef real, ValueRef imag) : m_real(real), m_imag(imag) {}
std::string to_string() const override {
return ssprintf(
"ComplexTensor{m_real=%s, m_imag=%s}", m_real.to_string().c_str(),
m_imag.to_string().c_str());
}
DTypeValue::ref_t dtype() const {
auto dtype = m_real.dtype();
mgb_assert(dtype == m_imag.dtype());
return dtype;
}
const ValueRef& real() const { return m_real; }
const ValueRef& imag() const { return m_imag; }
/**
* \brief clear all states of this value
*
*/
void clear() override {
m_real = {};
m_imag = {};
}
};
class CreateComplex final : public OperatorImpl<CreateComplex> {
public:
std::string to_string() const override { return "CreateComplex"; }
std::string raw_type() const override { return "CreateComplex"; }
};
class GetReal final : public OperatorImpl<GetReal> {
public:
std::string to_string() const override { return "GetReal"; }
std::string raw_type() const override { return "GetReal"; }
};
class GetImag final : public OperatorImpl<GetImag> {
public:
std::string to_string() const override { return "GetImag"; }
std::string raw_type() const override { return "GetImag"; }
};
class ComplexTransformation final : public Transformation {
private:
ObjectType<ComplexTensor> m_complex_type{"Complex"};
public:
std::string name() const override { return "ComplexTransformation"; }
HostTensorND make_complex_tensor(HostTensorND real, HostTensorND imag) {
mgb_assert(real.shape().eq_shape(imag.shape()));
mgb_assert(
real.dtype() == dtype::Float32() && imag.dtype() == dtype::Float32());
mgb_assert(real.comp_node() == imag.comp_node());
HostTensorND complex{real.comp_node(), real.shape(), dtype::Complex64()};
TensorShape f32_shape = complex.shape();
f32_shape[f32_shape.ndim++] = 2;
TensorLayout f32_layout = {f32_shape, dtype::Float32()};
f32_layout.init_contiguous_stride();
HostTensorND f32{complex.comp_node(), f32_layout};
f32.storage(complex.storage());
TensorLayout real_layout = f32_layout;
real_layout.ndim--;
TensorLayout imag_layout = real_layout;
// mgb_assert(!real_layout.is_contiguous());
// mgb_assert(!imag_layout.is_contiguous());
f32.sub(SubTensorSpec::make_from_layout(real_layout)).copy_from_fixlayout(real);
f32.sub(SubTensorSpec::make_from_offset_elem(imag_layout, 1))
.copy_from_fixlayout(imag);
return complex;
}
ValueRefList apply_complex_mask(
const ApplyOp& apply_op, Span<ValueRef> inputs, Span<bool> mask) {
ValueRefList real_list(inputs.size());
ValueRefList imag_list(inputs.size());
bool any_complex = false;
bool all_complex = true;
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto* complex = inputs[i].as(m_complex_type)) {
mgb_assert(mask[i], "unexpected complex");
any_complex = true;
real_list[i] = complex->real();
imag_list[i] = complex->imag();
} else {
real_list[i] = inputs[i];
if (mask[i]) {
all_complex = false;
} else {
imag_list[i] = inputs[i];
}
}
}
if (!any_complex) {
// no complex
return imperative::apply(apply_op, real_list);
} else {
// all complex
mgb_assert(all_complex, "only some of the inputs are complex");
auto reals = imperative::apply(apply_op, real_list);
auto imags = imperative::apply(apply_op, imag_list);
mgb_assert(reals.size() == imags.size());
ValueRefList results(reals.size());
for (size_t i = 0; i < results.size(); ++i) {
results[i] = m_complex_type.make(reals[i], imags[i]);
}
return results;
}
}
ValueRefList apply_complex_real(const ApplyOp& apply_op, Span<ValueRef> inputs) {
ValueRefList real_list(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto* complex = inputs[i].as(m_complex_type)) {
real_list[i] = complex->real();
} else {
real_list[i] = inputs[i];
}
}
return imperative::apply(apply_op, real_list);
}
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (auto* create_complex = op.as<CreateComplex>()) {
auto [real, imag] = inputs.as_array<2>();
auto dtype_real = real.dtype();
auto dtype_imag = imag.dtype();
mgb_assert(
*dtype_real == *dtype_imag, "dtype mismatch: %s vs %s",
dtype_real->name(), dtype_imag->name());
return {m_complex_type.make(real, imag)};
} else if (auto* create_tensor = op.as<CreateTensor>()) {
if (create_tensor->dtype().is_complex()) {
auto args = create_tensor->parse(inputs);
mgb_assert(!args.device);
auto& host = *args.host;
// reinterpret_cast to f32
mgb_assert(host.layout().is_physical_contiguous());
mgb_assert(host.dtype() == dtype::Complex64());
TensorShape f32_shape = host.shape();
f32_shape[f32_shape.ndim++] = 2;
TensorLayout f32_layout = {f32_shape, dtype::Float32()};
HostTensorND f32_host = {host.comp_node(), f32_layout};
f32_host.storage(host.storage());
// take real slice and imag slice
auto real_layout = f32_layout;
real_layout[real_layout.ndim - 1] = 1;
auto imag_layout = real_layout;
auto real_host =
f32_host.sub(SubTensorSpec::make_from_layout(real_layout));
auto imag_host = f32_host.sub(
SubTensorSpec::make_from_offset_elem(imag_layout, 1));
// create real and imag
auto real = imperative::apply(
CreateTensor(
create_tensor->kind(), create_tensor->device(),
real_layout),
HostStorage::make(real_host.storage()))[0];
auto imag = imperative::apply(
CreateTensor(
create_tensor->kind(), create_tensor->device(),
imag_layout),
HostStorage::make(imag_host.storage()))[0];
return {m_complex_type.make(real, imag)};
} else {
return imperative::apply(op, inputs);
}
}
bool any_complex = false;
for (auto&& input : inputs) {
if (input.is(m_complex_type)) {
any_complex = true;
break;
}
}
if (!any_complex) {
return imperative::apply(op, inputs);
}
if (auto* apply_op = op.as<ApplyOp>()) {
// TODO: handle apply op
// see https://zhuanlan.zhihu.com/p/627536105
if (auto* elemwise = apply_op->op().try_cast_final<Elemwise>()) {
switch (elemwise->mode) {
case Elemwise::Mode::MUL: {
auto* complex_a = inputs[0].as(m_complex_type);
auto* complex_b = inputs[1].as(m_complex_type);
auto& mul = *apply_op;
if (complex_a && complex_b) {
auto add = Elemwise::make(Elemwise::Mode::ADD);
auto sub = Elemwise::make(Elemwise::Mode::SUB);
auto real = imperative::apply(
*sub,
imperative::apply(
mul, complex_a->real(),
complex_b->real())[0],
imperative::apply(
mul, complex_a->imag(),
complex_b->imag())[0])[0];
auto imag = imperative::apply(
*add,
imperative::apply(
mul, complex_a->real(),
complex_b->imag())[0],
imperative::apply(
mul, complex_a->imag(),
complex_b->real())[0])[0];
return {m_complex_type.make(real, imag)};
} else if (complex_a) {
auto real = imperative::apply(
mul, complex_a->real(), inputs[1])[0];
auto imag = imperative::apply(
mul, complex_a->imag(), inputs[1])[0];
return {m_complex_type.make(real, imag)};
} else if (complex_b) {
auto real = imperative::apply(
mul, complex_b->real(), inputs[0])[0];
auto imag = imperative::apply(
mul, complex_b->imag(), inputs[0])[0];
return {m_complex_type.make(real, imag)};
} else {
mgb_assert(0);
}
}
case Elemwise::Mode::ADD:
case Elemwise::Mode::SUB: {
bool mask[2] = {true, true};
return apply_complex_mask(*apply_op, inputs, {mask, 2});
}
case Elemwise::Mode::NEGATE: {
bool mask[1] = {true};
return apply_complex_mask(*apply_op, inputs, {mask, 1});
}
default: {
mgb_assert(0, "unsupported elemwise mode");
}
}
} else if (auto* reshape = apply_op->op().try_cast_final<Reshape>()) {
SmallVector<bool> mask(inputs.size(), false);
mask[0] = true;
return apply_complex_mask(*apply_op, inputs, mask);
} else if (auto* subtensor = apply_op->op().try_cast_final<Subtensor>()) {
SmallVector<bool> mask(inputs.size(), false);
mask[0] = true;
return apply_complex_mask(*apply_op, inputs, mask);
} else if (auto* get_shape = apply_op->op().try_cast_final<GetVarShape>()) {
return apply_complex_real(*apply_op, inputs);
} else {
mgb_assert(0, "unsupported operator");
}
} else if (auto* get_attr = op.as<GetAttr>()) {
// TODO: handle get attr
auto&& input = inputs[0].as_ref(m_complex_type);
switch (get_attr->attr()) {
case GetAttr::DType:
switch (input->dtype()->enumv()) {
case DTypeEnum::Float32: {
return {DTypeValue::make(dtype::Complex64())};
}
default:
mgb_assert(
0, "unsupported dtype %s", input->dtype()->name());
}
case GetAttr::Device:
case GetAttr::Shape:
return imperative::apply(op, input->real());
case GetAttr::Value: {
auto complex = make_complex_tensor(
input->real().numpy()->as_nd(),
input->imag().numpy()->as_nd());
return {HostValue::make(complex)};
}
default:
mgb_throw(
MegBrainError, "unsupported %s for complex",
get_attr->to_string().c_str());
}
} else if (auto* as_real = op.as<GetReal>()) {
auto&& input = inputs[0].as_ref(m_complex_type);
return {input->real()};
} else if (auto* as_imag = op.as<GetImag>()) {
auto&& input = inputs[0].as_ref(m_complex_type);
return {input->imag()};
}
mgb_throw(
MegBrainError, "unsupported op for complex: %s",
op.to_string().c_str());
}
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is(m_complex_type), "cannot unwrap complex value");
return value;
}
};
} // namespace imperative
} // namespace mgb
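
The transformation above never materializes a complex buffer for arithmetic: MUL is rewritten into four real multiplies via (a+bi)(c+di) = (ac-bd) + (ad+bc)i, and ADD/SUB/NEGATE are applied to the real and imaginary parts independently through apply_complex_mask. A numpy check (not part of the commit) of the identities the MUL branch relies on:

    import numpy as np

    rng = np.random.default_rng(0)
    ar = rng.standard_normal(4, dtype=np.float32)
    ai = rng.standard_normal(4, dtype=np.float32)
    br = rng.standard_normal(4, dtype=np.float32)
    bi = rng.standard_normal(4, dtype=np.float32)

    # MUL branch: real = ar*br - ai*bi, imag = ar*bi + ai*br
    real = ar * br - ai * bi
    imag = ar * bi + ai * br
    np.testing.assert_allclose(real + 1j * imag,
                               (ar + 1j * ai) * (br + 1j * bi), atol=1e-5)

    # ADD/SUB branch: each component handled independently
    np.testing.assert_allclose((ar + br) + 1j * (ai + bi),
                               (ar + 1j * ai) + (br + 1j * bi), atol=1e-5)
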
@@ -400,6 +400,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
break;
case DTypeEnum::Bool:
break;
case DTypeEnum::Complex64:
break;
#define cb(x) \
case DTypeEnum::x: \
......
@@ -235,6 +235,8 @@ public:
break;
case DTypeEnum::Uint16:
break;
case DTypeEnum::Complex64:
break;
#define cb(_dt) \
case DTypeEnum::_dt: \
break;
......
@@ -24,6 +24,7 @@ enum DTypeEnum : byte {
Bool,
Uint16,
QuantizedS1,
Complex64,
}
table LinearQuantizationParam {
......