提交 b3e958d0 编写于 作者: M Megvii Engine Team

fix(src): fix the warnings and copy.bara.sky in custom op

GitOrigin-RevId: 4ade45589798cce48286ddb0f3a4298f7e03cd49
上级 cdb692d2
......@@ -613,17 +613,17 @@ void init_ops(py::module m) {
}
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \
case mgb::custom::ParamDynType::dyn_type: { \
case custom::ParamDynType::dyn_type: { \
param_val = py::handle(kv.second).cast<static_type>(); \
break; \
}
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \
case mgb::custom::ParamDynType::dyn_type: { \
case custom::ParamDynType::dyn_type: { \
auto pyvals = py::handle(kv.second).cast<py::list>(); \
static_type vals; \
using basic_type = \
mgb::custom::get_vector_template_arg_type<static_type>::type; \
custom::get_vector_template_arg_type<static_type>::type; \
for (auto &pyval: pyvals) { \
vals.push_back(py::handle(pyval).cast<basic_type>()); \
} \
......@@ -631,7 +631,7 @@ void init_ops(py::module m) {
break; \
}
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyObject *kwnames) {
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs) {
auto op_name = py::handle(args[0]).cast<std::string>();
auto kwargs = py::handle(args[1]).cast<py::dict>();
......@@ -680,7 +680,7 @@ PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyOb
py::list install_custom(const std::string &name, const std::string &path) {
py::list ret;
const auto &ops_in_lib = mgb::custom::LibManager::inst()->install(name, path);
const auto &ops_in_lib = custom::LibManager::inst()->install(name, path);
for (const auto &op: ops_in_lib) {
ret.append(op);
}
......@@ -688,7 +688,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
}
bool uninstall_custom(const std::string &name) {
return mgb::custom::LibManager::inst()->uninstall(name);
return custom::LibManager::inst()->uninstall(name);
}
py::list get_custom_op_list(void) {
......@@ -697,16 +697,28 @@ py::list get_custom_op_list(void) {
for (auto &op: all_ops) {
ret.append(op);
}
return std::move(ret);
return ret;
}
#ifndef METH_FASTCALL
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
auto* arr = &PyTuple_GET_ITEM(args, 0);
auto size = PyTuple_GET_SIZE(args);
return make_custom_op(self, arr, size);
};
#endif
void init_custom(pybind11::module m) {
m.def("_install", &install_custom);
m.def("_uninstall", &uninstall_custom);
m.def("_get_custom_op_list", &get_custom_op_list);
static PyMethodDef method_def = {
#ifdef METH_FASTCALL
"_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
#else
"_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
#endif
};
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
pybind11::setattr(m, method_def.ml_name, func);
......
......@@ -70,7 +70,7 @@ void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs,
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
const SmallVector<TensorPtr> &inputs) const {
SmallVector<LogicalTensorDesc> input_descs(inputs.size());
for (int i=0; i<inputs.size(); i++) {
for (size_t i=0; i<inputs.size(); i++) {
input_descs[i].comp_node = inputs[i]->comp_node();
input_descs[i].layout = inputs[i]->layout();
}
......@@ -84,7 +84,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
SmallVector<megdnn::DType> i_dtypes(inputs.size());
SmallVector<TensorFormat> i_formats(inputs.size());
for (int i=0; i<inputs.size(); i++) {
for (size_t i=0; i<inputs.size(); i++) {
i_devices[i] = inputs[i].comp_node;
i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape
i_dtypes[i] = inputs[i].layout.dtype;
......@@ -132,7 +132,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
}
SmallVector<LogicalTensorDesc> outputs(this->output_num());
for (int i=0; i<this->output_num(); i++) {
for (size_t i=0; i<this->output_num(); i++) {
outputs[i].comp_node = std::move(o_devices[i]);
outputs[i].layout = std::move(
TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i])
......
/**
* \file src/custom/impl/manager.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/manager.h"
#include "megbrain/common.h"
#include <unordered_set>
#ifndef _WIN32
#include <dlfcn.h>
#endif
using namespace mgb;
namespace custom {
CustomOpManager *CustomOpManager::inst(void) {
static CustomOpManager op_manager;
return &op_manager;
}
CustomOpManager::~CustomOpManager() {
mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
LibManager::inst()->m_custom_libs.clear();
}
std::shared_ptr<CustomOp> CustomOpManager::insert(const std::string &name, uint32_t version) {
MGB_LOCK_GUARD(m_mtx);
auto iter = m_name2op.find(name);
if (iter != m_name2op.end()) {
mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str());
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
}
std::shared_ptr<const CustomOp> op = std::make_shared<const CustomOp>(name, version);
m_name2op[op->op_type()] = op;
m_id2op[op->runtime_id()] = op;
return std::const_pointer_cast<CustomOp, const CustomOp>(op);
}
bool CustomOpManager::erase(const std::string &name) {
MGB_LOCK_GUARD(m_mtx);
auto iter = m_name2op.find(name);
if (iter == m_name2op.end()) {
mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str());
return false;
}
std::shared_ptr<const CustomOp> op = iter->second;
m_id2op.erase(op->runtime_id());
m_name2op.erase(op->op_type());
return true;
}
bool CustomOpManager::erase(const RunTimeId &id) {
MGB_LOCK_GUARD(m_mtx);
auto iter = m_id2op.find(id);
if (iter == m_id2op.end()) {
mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
return false;
}
std::shared_ptr<const CustomOp> op = iter->second;
m_id2op.erase(op->runtime_id());
m_name2op.erase(op->op_type());
return true;
}
std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(const std::string &name, uint32_t version) {
auto iter = m_name2op.find(name);
if (iter == m_name2op.end()) {
return insert(name, version);
}
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
}
RunTimeId CustomOpManager::to_id(const std::string &name) const {
std::shared_ptr<const CustomOp> op = find(name);
return op->runtime_id();
}
std::string CustomOpManager::to_name(const RunTimeId &id) const {
std::shared_ptr<const CustomOp> op = find(id);
return op->op_type();
}
std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string &name) const {
auto ret = m_name2op.find(name);
mgb_assert(ret != m_name2op.end(),
"Find Custom Op Failed! Op %s has not been registered", name.c_str()
);
return ret->second;
}
std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId &id) const {
auto ret = m_id2op.find(id);
mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered");
return ret->second;
}
std::vector<std::string> CustomOpManager::op_name_list(void) {
std::vector<std::string> ret;
for (auto kv: m_name2op) {
ret.emplace_back(kv.first);
}
return ret;
}
std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
std::vector<RunTimeId> ret;
for (auto kv: m_id2op) {
ret.emplace_back(kv.first);
}
return ret;
}
#ifndef _WIN32
CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY)
: m_handle(nullptr, [](void* handle) {dlclose(handle);}) {
auto op_list_before_load = CustomOpManager::inst()->op_name_list();
std::unordered_set<std::string> op_set_before_load(
op_list_before_load.begin(), op_list_before_load.end());
m_handle.reset(dlopen(path.c_str(), mode));
mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror());
auto op_list_after_load = CustomOpManager::inst()->op_name_list();
for (auto &op: op_list_after_load) {
if (op_set_before_load.find(op) == op_set_before_load.end()) {
m_ops.emplace_back(op);
}
}
}
#else
CustomLib::CustomLib(const std::string &path, int mode = 0)
: m_handle(nullptr, [](void* handle) {}) {
mgb_assert(false, "custom op is only supported on Linux now");
}
#endif
const std::vector<std::string> &CustomLib::ops_in_lib(void) const {
return m_ops;
}
CustomLib::~CustomLib() {
for (auto &op: m_ops) {
CustomOpManager::inst()->erase(op);
}
}
bool CustomLib::valid() const {
return m_handle != nullptr;
}
LibManager *LibManager::inst(void) {
static LibManager custom_libs;
return &custom_libs;
}
const std::vector<std::string> &LibManager::install(const std::string &name, const std::string &path) {
MGB_LOCK_GUARD(m_mtx);;
LibHandle handle = std::make_shared<CustomLib>(path);
m_custom_libs.insert({name, handle});
return m_custom_libs[name]->ops_in_lib();
}
bool LibManager::uninstall(const std::string &name) {
MGB_LOCK_GUARD(m_mtx);;
mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
return true;
}
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
return CustomOpManager::inst()->insert(opname, version);
}
}
/**
* \file src/custom/impl/op.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/common.h"
#include "megbrain/custom/op.h"
#include "megbrain/custom/utils.h"
#include <unordered_set>
#include <sstream>
using namespace mgb;
namespace custom {
class ArgInfoImpl {
std::string m_name;
std::string m_desc;
std::unordered_set<std::string> m_dtypes;
int m_ndim; // use int rather than size_t for representing m_dims = -1
std::string m_mem_stgy;
friend class ArgInfo;
};
CUSTOM_PIMPL_CLS_DEFINE(ArgInfo)
ArgInfo::ArgInfo(const std::string &name,
const std::string &desc,
const std::unordered_set<std::string> &dtypes,
const int &ndim,
const std::string &mem_stgy): m_impl(new ArgInfoImpl(), impl_deleter<ArgInfoImpl>) {
for (auto &&dtype: dtypes) {
mgb_assert(DType::is_legal(dtype), "unsupported tensor data type: %s", dtype.c_str());
}
mgb_assert(mem_stgy == "default", "only default mem strategy is supported now!");
TypedRef(ArgInfoImpl, m_impl.get()).m_name = name;
TypedRef(ArgInfoImpl, m_impl.get()).m_desc = desc;
TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes = dtypes;
TypedRef(ArgInfoImpl, m_impl.get()).m_ndim = ndim;
TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy = mem_stgy;
}
const std::string &ArgInfo::name(void) const {
return TypedRef(ArgInfoImpl, m_impl.get()).m_name;
}
const std::string &ArgInfo::desc(void) const {
return TypedRef(ArgInfoImpl, m_impl.get()).m_desc;
}
const std::unordered_set<std::string> &ArgInfo::dtypes(void) const {
return TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes;
}
int ArgInfo::ndim(void) const {
return TypedRef(ArgInfoImpl, m_impl.get()).m_ndim;
}
const std::string &ArgInfo::mem_strategy(void) const {
return TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy;
}
std::string ArgInfo::str() const {
std::stringstream ss;
ss << "name: " << TypedRef(ArgInfoImpl, m_impl.get()).m_name << "\n"
<< "desc: " << TypedRef(ArgInfoImpl, m_impl.get()).m_desc << "\nlegal_dtypes: {";
for (auto &val: TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes) {
ss << val << ", ";
}
if (TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes.size() != 0) {
ss.seekp(ss.tellp()-std::streampos(2));
}
ss << "}\ndims: " << TypedRef(ArgInfoImpl, m_impl.get()).m_ndim << "\n"
<< "memory_strategy: " << TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy;
return ss.str();
}
#define assert_inputs_size_right(inputs_vec) \
mgb_assert( \
inputs_vec.size() == input_num(), \
"op %s need %lu inputs but given %lu", \
op_type().c_str(), static_cast<unsigned long>(input_num()), \
static_cast<unsigned long>(inputs_vec.size()) \
)
#define assert_outputs_size_right(outputs_vec) \
mgb_assert( \
outputs_vec.size() == output_num(), \
"op %s have %lu outputs but given %lu", \
op_type().c_str(), static_cast<unsigned long>(output_num()), \
static_cast<unsigned long>(outputs_vec.size()) \
)
#define assert_arg_shape_dim_right(real_shape, arg_info) \
mgb_assert( \
(arg_info).ndim() == -1 || static_cast<int>((real_shape).ndim()) == \
static_cast<int>((arg_info).ndim()), \
"%s's args: %s dim match error, need %d but given %d", op_type().c_str(), \
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \
static_cast<int>((real_shape).ndim()) \
)
template <typename T>
class Function;
template<typename RType, typename... Args>
class Function<RType(Args...)> {
public:
using Functor = RType (*)(Args...);
Function() = default;
Function(Functor f): m_f(f) {}
Function(const Function &rhs) {
m_f = rhs.m_f;
}
RType operator()(Args... args) {
custom_assert(m_f != nullptr, "invalid function ptr\n");
return m_f(std::forward<Args>(args)...);
}
void operator=(const Function &rhs) { // not allowed continuous assignment
m_f = rhs.m_f;
}
void operator=(const Functor f) {
m_f = f;
}
private:
Functor m_f = nullptr;
};
template <typename Functions>
class FuncWithSig: public Functions {
public:
using Functions::operator();
using Functions::operator=;
};
class CustomOpImpl {
static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION;
const uint32_t m_version;
const std::string m_op_type;
std::string m_op_desc;
std::vector<ArgInfo> m_input_infos;
std::vector<ArgInfo> m_output_infos;
ParamInfo m_param_infos;
using DeviceInfer = FuncWithSig<Function<void(const std::vector<Device>&, const Param&, std::vector<Device>&)>>;
using ShapeInfer = FuncWithSig<Function<void(const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>;
using DTypeInfer = FuncWithSig<Function<void(const std::vector<DType>&, const Param&, std::vector<DType>&)>>;
using FormatInfer = FuncWithSig<Function<void(const std::vector<Format>&, const Param&, std::vector<Format>&)>>;
using Preprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
using Postprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
using Compute = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
DeviceInfer infer_output_device_func;
ShapeInfer infer_output_shape_func;
DTypeInfer infer_output_dtype_func;
FormatInfer infer_output_format_func;
std::unordered_map<std::string, Compute> compute_funcs;
std::unordered_map<std::string, Preprocess> preprocess_funcs;
std::unordered_map<std::string, Postprocess> postprocess_funcs;
public:
CustomOpImpl(const std::string&, uint32_t version);
PREVENT_COPY_AND_ASSIGN(CustomOpImpl);
friend CustomOp;
};
CustomOpImpl::CustomOpImpl(const std::string &op_type, uint32_t version)
: m_version(version), m_op_type(op_type) {
if (m_version != CURRENT_VERSION) {
mgb_log_warn(
"the version of loaded custom op %s is %u, but custom op version "
"of the system is %u\n", op_type.c_str(), m_version, CURRENT_VERSION
);
}
infer_output_device_func = [](const std::vector<Device> &inputs,
const Param&,
std::vector<Device> &outputs) -> void {
static UnImpleWarnLog log_once("output_device_infer", "device", "x86");
for (size_t i=0; i<outputs.size(); ++i) {
outputs[i] = inputs.size() > 0 ? inputs[0] : Device("x86");
}
};
infer_output_shape_func = [](const std::vector<Shape> &inputs,
const Param&,
std::vector<Shape> &outputs) -> void {
static UnImpleWarnLog log_once("output_shape_infer", "shape", "{1}");
for (size_t i=0; i<outputs.size(); ++i) {
outputs[i] = inputs.size() > 0 ? inputs[0] : Shape({1});
}
};
infer_output_dtype_func = [](const std::vector<DType> &inputs,
const Param&,
std::vector<DType> &outputs) -> void {
static UnImpleWarnLog log_once("output_dtype_infer", "dtype", "float32");
for (size_t i=0; i<outputs.size(); ++i) {
outputs[i] = inputs.size() > 0 ? inputs[0] : DType("float32");
}
};
infer_output_format_func = [](const std::vector<Format> &inputs,
const Param&,
std::vector<Format> &outputs) -> void {
for (size_t i=0; i<outputs.size(); ++i) {
outputs[i] = inputs.size() > 0 ? inputs[0] : Format("default");
}
};
for (const auto &device: Device::legal_devices()) {
compute_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor> &outputs) -> void {
auto device = outputs[0].device();
mgb_assert(false, "There is no forward function for your op on device `%s`. "
"Please implement this function and register it.", device.str().c_str());
};
preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void {
return;
};
postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void {
return;
};
}
m_param_infos.set_tag(op_type);
}
CustomOp::CustomOp(const std::string &op_type, uint32_t version)
: m_impl(new CustomOpImpl(op_type, version), impl_deleter<CustomOpImpl>) {
}
#define OpImplRef(raw_ptr) reinterpret_cast<CustomOpImpl*>(raw_ptr)
CustomOp &CustomOp::set_device_infer(DeviceInferFuncPtr func) {
OpImplRef(m_impl.get())->infer_output_device_func = func;
return *this;
}
CustomOp &CustomOp::set_shape_infer(ShapeInferFuncPtr func) {
OpImplRef(m_impl.get())->infer_output_shape_func = func;
return *this;
}
CustomOp &CustomOp::set_dtype_infer(DTypeInferFuncPtr func) {
OpImplRef(m_impl.get())->infer_output_dtype_func = func;
return *this;
}
CustomOp &CustomOp::set_format_infer(FormatInferFuncPtr func) {
OpImplRef(m_impl.get())->infer_output_format_func = func;
return *this;
}
CustomOp &CustomOp::set_preprocess(PreprocessFuncPtr func) {
set_preprocess("x86", func);
return *this;
}
CustomOp &CustomOp::set_preprocess(const std::string &device, PreprocessFuncPtr func) {
OpImplRef(m_impl.get())->preprocess_funcs[device] = func;
return *this;
}
CustomOp &CustomOp::set_postprocess(PostprocessFuncPtr func) {
set_postprocess("x86", func);
return *this;
}
CustomOp &CustomOp::set_postprocess(const std::string &device, PostprocessFuncPtr func) {
OpImplRef(m_impl.get())->postprocess_funcs[device] = func;
return *this;
}
CustomOp &CustomOp::set_compute(ComputeFuncPtr func) {
set_compute("x86", func);
return *this;
}
CustomOp &CustomOp::set_compute(const std::string &device, ComputeFuncPtr func) {
OpImplRef(m_impl.get())->compute_funcs[device] = func;
return *this;
}
CustomOp &CustomOp::set_description(const std::string &op_desc) {
OpImplRef(m_impl.get())->m_op_desc = op_desc;
return *this;
}
CustomOp &CustomOp::add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
auto &ref = OpImplRef(m_impl.get())->m_input_infos;
for (const auto &input: ref) {
mgb_assert(input.name() != name, "input %s has been registered", name.c_str());
}
ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy);
return *this;
}
CustomOp &CustomOp::add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
auto &ref = OpImplRef(m_impl.get())->m_output_infos;
for (const auto &output: ref) {
mgb_assert(output.name() != name, "output %s has been registered", name.c_str());
}
ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy);
return *this;
}
CustomOp &CustomOp::add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
add_input(name, name, legal_dtypes, dims, mem_stgy);
return *this;
}
CustomOp &CustomOp::add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) {
add_output(name, name, legal_dtypes, dims, mem_stgy);
return *this;
}
CustomOp &CustomOp::add_inputs(const size_t &num) {
size_t cur_inp_num = input_num();
for (size_t i=cur_inp_num; i<cur_inp_num+num; i++) {
add_input(op_type() + "_Input_" + std::to_string(i));
}
return *this;
}
CustomOp &CustomOp::add_outputs(const size_t &num) {
size_t cur_oup_num = output_num();
for (size_t i=cur_oup_num; i<cur_oup_num+num; i++) {
add_output(op_type() + "_Output_" + std::to_string(i));
}
return *this;
}
CustomOp &CustomOp::add_param(const std::string &name, const ParamVal &default_val) {
add_param(name, name, default_val);
return *this;
}
CustomOp &CustomOp::add_param(const std::string &name, const std::string &desc, const ParamVal &default_val) {
auto &meta = OpImplRef(m_impl.get())->m_param_infos.meta();
for(const auto &schema: meta) {
mgb_assert(name != schema.name(), "param %s has been registered\n", name.c_str());
}
ParamSchema sch = ParamSchema(name, default_val, desc);
meta.emplace_back(sch);
return *this;
}
std::string CustomOp::op_type(void) const {
return OpImplRef(m_impl.get())->m_op_type;
}
std::string CustomOp::op_desc(void) const {
return OpImplRef(m_impl.get())->m_op_desc;
}
RunTimeId CustomOp::runtime_id(void) const {
return (RunTimeId)(this);
}
size_t CustomOp::input_num(void) const {
return OpImplRef(m_impl.get())->m_input_infos.size();
}
size_t CustomOp::output_num(void) const {
return OpImplRef(m_impl.get())->m_output_infos.size();
}
std::string CustomOp::str(void) const {
std::stringstream ss;
ss << "op name: " << op_type() << "\nop desc: " << op_desc() << "\n\ninputs:\n";
for (const auto &input: inputs_info()) {
ss << input.str();
ss << "\n--------------------\n";
}
ss << "\noutputs:\n";
for (const auto &output: outputs_info()) {
ss << output.str();
ss << "\n--------------------\n";
}
ss << "\nparams:\n";
for (const auto &param: param_info().meta()) {
ss << param.str();
ss << "\n--------------------\n";
}
return ss.str();
}
const ParamInfo &CustomOp::param_info(void) const {
return OpImplRef(m_impl.get())->m_param_infos;
}
ArgInfo CustomOp::input_info(size_t idx) const {
return OpImplRef(m_impl.get())->m_input_infos[idx];
}
ArgInfo CustomOp::output_info(size_t idx) const {
return OpImplRef(m_impl.get())->m_output_infos[idx];
}
const std::vector<ArgInfo> &CustomOp::inputs_info(void) const {
return OpImplRef(m_impl.get())->m_input_infos;
}
const std::vector<ArgInfo> &CustomOp::outputs_info(void) const {
return OpImplRef(m_impl.get())->m_output_infos;
}
std::vector<Device> CustomOp::infer_output_device(const std::vector<Device> &inputs, const Param &param) const {
assert_inputs_size_right(inputs);
std::vector<Device> outputs(output_num());
OpImplRef(m_impl.get())->infer_output_device_func(inputs, param, outputs);
assert_outputs_size_right(outputs);
return outputs;
}
std::vector<Shape> CustomOp::infer_output_shape(const std::vector<Shape> &inputs, const Param &param) const {
assert_inputs_size_right(inputs);
for (size_t i=0; i<inputs_info().size(); i++) {
assert_arg_shape_dim_right(inputs[i], input_info(i));
}
std::vector<Shape> outputs(output_num());
OpImplRef(m_impl.get())->infer_output_shape_func(inputs, param, outputs);
for (size_t i=0; i<outputs_info().size(); i++) {
assert_arg_shape_dim_right(outputs[i], output_info(i));
}
assert_outputs_size_right(outputs);
return outputs;
}
std::vector<DType> CustomOp::infer_output_dtype(const std::vector<DType> &inputs, const Param &param) const {
assert_inputs_size_right(inputs);
for (size_t i=0; i<inputs_info().size(); i++) {
std::unordered_set<std::string> legal_input_dtypes_i = input_info(i).dtypes();
mgb_assert(
legal_input_dtypes_i.find(inputs[i].str()) != legal_input_dtypes_i.end(),
"dtypes of input: %s(%s) is not allowed, the info of this input is:\n%s",
input_info(i).name().c_str(), inputs[i].str().c_str(),
input_info(i).str().c_str()
);
}
std::vector<DType> outputs(output_num());
OpImplRef(m_impl.get())->infer_output_dtype_func(inputs, param, outputs);
for (size_t i=0; i<outputs_info().size(); i++) {
std::unordered_set<std::string> legal_output_dtypes_i = output_info(i).dtypes();
mgb_assert(
legal_output_dtypes_i.find(outputs[i].str()) != legal_output_dtypes_i.end(),
"dtypes of output: %s is %s, the info of this output is:\n%s",
output_info(i).name().c_str(), outputs[i].str().c_str(),
output_info(i).str().c_str()
);
}
assert_outputs_size_right(outputs);
return outputs;
}
std::vector<Format> CustomOp::infer_output_format(const std::vector<Format> &inputs, const Param &param) const {
assert_inputs_size_right(inputs);
for (size_t i=0; i<inputs.size(); i++) {
mgb_assert(
inputs[i].is_default(),
"the tensor format of %s:%s is not default",
op_type().c_str(), input_info(i).name().c_str()
);
}
std::vector<Format> outputs(output_num());
OpImplRef(m_impl.get())->infer_output_format_func(inputs, param, outputs);
for (size_t i=0; i<outputs.size(); i++) {
mgb_assert(
outputs[i].is_default(),
"the tensor format of %s:%s is not default",
op_type().c_str(), output_info(i).name().c_str()
);
}
assert_outputs_size_right(outputs);
return outputs;
}
void CustomOp::compute(const std::vector<Tensor> &inputs, const Param &param, std::vector<Tensor> &outputs) const {
assert_inputs_size_right(inputs);
assert_outputs_size_right(outputs);
if (outputs.size() == 0) {
return;
}
std::string device = outputs[0].device().str();
for (size_t i=1; i<outputs.size(); ++i) {
mgb_assert(
outputs[i].device().str() == device,
"all output tensors should have the same device attribute"
);
}
// need to add other input/output check
mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str());
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device];
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device];
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device];
preprocess_func(inputs, param, outputs);
forward_func(inputs, param, outputs);
postprocess_func(outputs, param, outputs);
assert_outputs_size_right(outputs);
}
}
/**
* \file src/custom/impl/param.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/param.h"
#include "megbrain/common.h"
#include "megbrain/utils/hash.h"
#include <limits>
#include <sstream>
#include <map>
using namespace mgb;
namespace custom {
class ParamSchemaImpl {
std::string m_name;
std::string m_desc;
ParamVal m_default;
friend ParamSchema;
};
class ParamInfoImpl {
std::vector<ParamSchema> m_meta;
uint32_t TAG;
friend ParamInfo;
};
class ParamImpl {
std::unordered_map<std::string, ParamVal> m_vals;
ParamImpl() = default;
ParamImpl(const ParamImpl &rhs) = default;
ParamImpl &operator=(const ParamImpl &rhs) {
mgb_assert(
m_vals.size() == rhs.m_vals.size(),
"params of different op, assignment failed!"
);
for (const auto &kv: rhs.m_vals) {
auto iter = m_vals.find(kv.first);
mgb_assert(iter != m_vals.end(), "params of different op, assignment failed!");
iter->second = kv.second;
}
return *this;
}
friend Param;
};
CUSTOM_PIMPL_CLS_DEFINE(ParamSchema)
ParamSchema::ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc)
: m_impl(new ParamSchemaImpl(), impl_deleter<ParamSchemaImpl>) {
TypedRef(ParamSchemaImpl, m_impl.get()).m_name = name;
TypedRef(ParamSchemaImpl, m_impl.get()).m_default = value;
TypedRef(ParamSchemaImpl, m_impl.get()).m_desc = desc;
}
const std::string &ParamSchema::name(void) const {
return TypedRef(ParamSchemaImpl, m_impl.get()).m_name;
}
const std::string &ParamSchema::desc(void) const {
return TypedRef(ParamSchemaImpl, m_impl.get()).m_desc;
}
const ParamVal &ParamSchema::default_val(void) const {
return TypedRef(ParamSchemaImpl, m_impl.get()).m_default;
}
ParamDynType ParamSchema::type(void) const {
return TypedRef(ParamSchemaImpl, m_impl.get()).m_default.type();
}
std::string ParamSchema::str(void) const {
std::stringstream ss;
ss << "name: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_name
<< "\ndesc: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_desc
<< "\n" << TypedRef(ParamSchemaImpl, m_impl.get()).m_default.str();
return ss.str();
}
CUSTOM_PIMPL_CLS_DEFINE(ParamInfo)
void ParamInfo::set_tag(const std::string &hash_str) {
const char *ptr = hash_str.c_str();
TypedRef(ParamInfoImpl, m_impl.get()).TAG = 0;
for (size_t i=0; i<hash_str.size(); i++) {
TypedRef(ParamInfoImpl, m_impl.get()).TAG =
mgb::hash_pair_combine(TypedRef(ParamInfoImpl, m_impl.get()).TAG, mgb::hash(*(ptr++))) %
std::numeric_limits<uint32_t>::max();
}
}
void ParamInfo::set_meta(const std::vector<ParamSchema> &meta) {
TypedRef(ParamInfoImpl, m_impl.get()).m_meta = meta;
}
uint32_t ParamInfo::tag(void) const {
return TypedRef(ParamInfoImpl, m_impl.get()).TAG;
}
std::vector<ParamSchema> &ParamInfo::meta(void) {
return TypedRef(ParamInfoImpl, m_impl.get()).m_meta;
}
const std::vector<ParamSchema> &ParamInfo::meta(void) const {
return TypedRef(ParamInfoImpl, m_impl.get()).m_meta;
}
CUSTOM_PIMPL_CLS_DEFINE(Param)
Param::Param(const ParamInfo &info): m_impl(new ParamImpl(), impl_deleter<ParamImpl>) {
for (const auto &schema: info.meta()) {
TypedRef(ParamImpl, m_impl.get()).m_vals.emplace(schema.name(), schema.default_val());
}
}
ParamVal &Param::operator[](const std::string &name) {
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second;
}
const ParamVal &Param::operator[](const std::string &name) const {
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second;
}
const std::unordered_map<std::string, ParamVal> &Param::raw() const {
return TypedRef(ParamImpl, m_impl.get()).m_vals;
}
bool Param::exist(const std::string &name) const {
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name) !=
TypedRef(ParamImpl, m_impl.get()).m_vals.end();
}
std::string Param::to_bytes(void) const {
std::string res;
std::map<std::string, ParamVal> ordered_vals(
TypedRef(ParamImpl, m_impl.get()).m_vals.begin(),
TypedRef(ParamImpl, m_impl.get()).m_vals.end());
for (auto &&kv: ordered_vals) {
res += ParamVal::to_bytes(kv.second);
}
return res;
}
void Param::from_bytes(const std::string &bytes) {
std::map<std::string, ParamVal> ordered_vals(
TypedRef(ParamImpl, m_impl.get()).m_vals.begin(),
TypedRef(ParamImpl, m_impl.get()).m_vals.end());
size_t offset = 0;
for (auto &kv: ordered_vals) {
kv.second = ParamVal::from_bytes(bytes, offset);
}
TypedRef(ParamImpl, m_impl.get()).m_vals.clear();
TypedRef(ParamImpl, m_impl.get()).m_vals.insert(ordered_vals.begin(), ordered_vals.end());
mgb_assert(offset == bytes.size(), "wrong data loader");
}
bool operator==(const Param &lhs, const Param &rhs) {
if (lhs.raw().size() != rhs.raw().size())
return false;
for (const auto &kv: lhs.raw()) {
auto riter = rhs.raw().find(kv.first);
if (riter == rhs.raw().end() || !((kv.second) == riter->second)) {
return false;
}
}
return true;
}
}
/**
* \file src/custom/impl/param_val.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/param_val.h"
#include "megbrain/common.h"
#pragma GCC diagnostic ignored "-Wsign-compare"
using namespace mgb;
namespace custom {
/**
* Macro Callback for Case
*/
#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
std::unique_ptr<void, void_deleter> new_ptr( \
new static_type(TypedRef(static_type, rhs.m_ptr.get())), \
impl_deleter<static_type> \
); \
m_ptr.swap(new_ptr); \
break; \
}
#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
TypedRef(static_type, m_ptr.get()) = TypedRef(static_type, rhs.m_ptr.get());\
break; \
}
#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr) \
mgb_assert( \
operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \
"invalid %s of operator %s of ParamVal", #operand, #opr \
)
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \
mgb_assert( \
lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
type2name[lhs.m_type].c_str(), #op, \
type2name[rhs.m_type].c_str() \
)
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
case (ParamDynType::dyn_type): { \
const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \
return lval op rval; \
}
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op) \
case (ParamDynType::dyn_type): { \
const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \
switch (rhs.m_type) { \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY( \
CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op) \
default: \
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
} \
break; \
}
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op) \
case (ParamDynType::dyn_type): { \
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \
const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \
return lval op rval; \
}
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type) \
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \
\
switch (lhs.m_type) { \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \
default: \
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
} \
return {}; \
}
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type) \
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \
\
switch (lhs.m_type) { \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \
CUSTOM_FOR_STRING_PARAMTYPE( \
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
default: \
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
} \
return {}; \
}
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type) \
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \
\
switch (lhs.m_type) { \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \
CUSTOM_FOR_STRING_PARAMTYPE( \
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
CUSTOM_FOR_EACH_LIST_PARAMTYPE( \
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
default: \
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \
} \
return {}; \
}
#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
auto rval = TypedRef(static_type, m_ptr.get()); \
ss << rval; \
break; \
}
#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
auto rval = TypedRef(static_type, m_ptr.get()); \
ss << vec2str(rval); \
break; \
}
#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
return TypedRef(static_type, m_ptr.get()).size(); \
break; \
}
#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
res.resize(sizeof(ParamDynType) + sizeof(static_type)); \
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \
memcpy(&res[sizeof(ParamDynType)], value.m_ptr.get(), sizeof(static_type)); \
break; \
}
#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
auto &ref = TypedRef(static_type, value.m_ptr.get()); \
size_t len = ref.size(); \
size_t elem_size = len != 0 ? sizeof(ref[0]) : 0; \
res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); \
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); \
memcpy(&res[sizeof(ParamDynType)+sizeof(len)], ref.data(), len*elem_size); \
break; \
}
#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
static_type val; \
memcpy(&val, &bytes[offset], sizeof(val)); \
offset += sizeof(val); \
return val; \
break; \
}
#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type) \
case (ParamDynType::dyn_type): { \
size_t len = 0; \
memcpy(&len, &bytes[offset], sizeof(len)); \
offset += sizeof(len); \
static_type vals; \
vals.resize(len); \
size_t elem_size = len != 0 ? sizeof(vals[0]) : 0; \
memcpy(&vals[0], &bytes[offset], len*elem_size); \
offset += len*elem_size; \
return vals; \
break; \
}
ParamVal::ParamVal(): m_ptr(nullptr, [](void*) -> void {}) {
m_type = ParamDynType::Invalid;
}
ParamVal::ParamVal(const char *str): ParamVal(std::string(str)) {
}
ParamVal::ParamVal(const std::initializer_list<const char*> &strs): ParamVal(std::vector<const char*>(strs)) {
}
ParamVal::ParamVal(const std::vector<const char*> &strs)
: m_ptr(new std::vector<std::string>(), impl_deleter<std::vector<std::string>>) {
m_type = ParamDynType::StringList;
for (const auto &str: strs) {
TypedRef(std::vector<std::string>, m_ptr.get()).emplace_back(str);
}
}
ParamVal::ParamVal(const ParamVal &rhs): m_ptr(nullptr, [](void*) -> void {}) {
mgb_assert(
rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
"invalid rhs of copy constructor of ParamVal"
);
m_type = rhs.m_type;
switch(m_type) {
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS)
default: {
mgb_assert(false, "invalid rhs of copy constructor of ParamVal");
}
}
}
ParamVal &ParamVal::operator=(const char *str) {
this->operator=(std::string(str));
return *this;
}
ParamVal &ParamVal::operator=(const std::initializer_list<const char*> &strs) {
this->operator=(std::vector<const char*>(strs));
return *this;
}
ParamVal &ParamVal::operator=(const std::vector<const char*> &strs) {
std::vector<std::string> tmp_strs;
for (const auto &str: strs) {
tmp_strs.emplace_back(str);
}
this->operator=(tmp_strs);
return *this;
}
ParamVal &ParamVal::operator=(const ParamVal &rhs) {
if (&rhs == this)
return *this;
mgb_assert(
rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
"invalid rhs of assignment operator of ParamVal"
);
if (rhs.m_type == m_type) {
switch(m_type) {
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS);
default:
mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
}
}
else {
m_type = rhs.m_type;
switch(m_type) {
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS);
default:
mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
}
}
return *this;
}
const void *ParamVal::raw_ptr(void) const {
return m_ptr.get();
}
void *ParamVal::raw_ptr(void) {
return m_ptr.get();
}
ParamDynType ParamVal::type(void) const {
return m_type;
}
std::string ParamVal::str() const {
std::stringstream ss;
ss << "type: " << type2name[m_type] << "\n" << "value: ";
switch (m_type) {
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST)
default:
mgb_assert(false, "invalid data of assignment operator of ParamVal");
}
return ss.str();
}
size_t ParamVal::size(void) const {
switch (m_type) {
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
default:
mgb_assert(false, "there is no size() for basic data types");
}
}
std::string ParamVal::to_bytes(const ParamVal &value) {
std::string res;
// because the specialization of std::vector<bool>
if (value.type() == ParamDynType::BoolList) {
std::vector<bool> &ref = TypedRef(std::vector<bool>, value.m_ptr.get());
size_t len = ref.size();
size_t elem_size = sizeof(bool);
res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size);
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
size_t startpos = sizeof(ParamDynType)+sizeof(len);
for (size_t idx=0; idx<len; idx++) {
bool b = ref[idx];
memcpy(&res[startpos+idx*sizeof(b)], &b, sizeof(b));
}
return res;
}
else if (value.type() == ParamDynType::StringList) {
std::vector<std::string> &ref = TypedRef(std::vector<std::string>, value.m_ptr.get());
size_t len = ref.size();
res.resize(sizeof(ParamDynType) + sizeof(len));
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
for (size_t idx=0; idx<ref.size(); ++idx) {
size_t str_len = ref[idx].size();
std::string bytes(sizeof(str_len) + str_len, ' ');
memcpy(&bytes[0], &str_len, sizeof(str_len));
memcpy(&bytes[sizeof(str_len)], ref[idx].data(), str_len);
res += bytes;
}
return res;
}
switch(value.type()) {
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_DUMP_BASIC)
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
default:
mgb_assert(false, "invalid param type");
}
return res;
}
ParamVal ParamVal::from_bytes(const std::string &bytes, size_t &offset) {
ParamDynType data_type = ParamDynType::Invalid;
memcpy(&data_type, &bytes[offset], sizeof(ParamDynType));
offset += sizeof(ParamDynType);
if (data_type == ParamDynType::BoolList) {
std::vector<bool> ret;
size_t len = 0;
memcpy(&len, &bytes[offset], sizeof(len));
offset += sizeof(len);
for (size_t idx =0; idx<len; ++idx) {
bool b = true;
memcpy(&b, &bytes[offset], sizeof(bool));
offset += sizeof(bool);
ret.push_back(b);
}
return ret;
}
else if (data_type == ParamDynType::StringList) {
std::vector<std::string> ret;
size_t len = 0;
memcpy(&len, &bytes[offset], sizeof(len));
offset += sizeof(len);
for (size_t idx =0; idx<len; ++idx) {
size_t str_len = 0;
memcpy(&str_len, &bytes[offset], sizeof(str_len));
offset += sizeof(str_len);
std::string str(str_len, ' ');
memcpy(&str[0], &bytes[offset], str_len);
offset += str_len;
ret.push_back(str);
}
return ret;
}
switch (data_type) {
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_LOAD_BASIC)
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST)
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST);
default:
mgb_assert(false, "invalid param type");
}
return {};
}
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(+, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(-, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(*, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(/, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(==, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(!=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool)
}
/**
* \file src/custom/impl/tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/tensor.h"
#include "megbrain/comp_node.h"
#include "megbrain/common.h"
#include "megbrain/tensor.h"
#include <cctype>
#include <algorithm>
using namespace mgb;
namespace custom {
template<typename T>
SmallVector<T> to_builtin_vector(const std::vector<T> &custom_data) {
SmallVector<T> builtin_data(custom_data.size());
memcpy(builtin_data.data(), custom_data.data(), sizeof(T)*custom_data.size());
return builtin_data;
}
using DeviceImpl = CompNode;
using ShapeImpl = megdnn::TensorShape;
using DTypeImpl = megdnn::DType;
using FormatImpl = megdnn::TensorLayout::Format;
using TensorImpl = DeviceTensorND;
#define DeviceImplRef(rawptr) (*reinterpret_cast<DeviceImpl*>(rawptr))
#define ShapeImplRef(rawptr) (*reinterpret_cast<ShapeImpl*>(rawptr))
#define DTypeImplRef(rawptr) (*reinterpret_cast<DTypeImpl*>(rawptr))
#define FormatImplRef(rawptr) (*reinterpret_cast<FormatImpl*>(rawptr))
#define TensorImplRef(rawptr) (*reinterpret_cast<TensorImpl*>(rawptr))
#define DeviceImplConstRef(rawptr) static_cast<const DeviceImpl&>(*reinterpret_cast<const DeviceImpl*>(rawptr))
#define ShapeImplConstRef(rawptr) static_cast<const ShapeImpl&>(*reinterpret_cast<const ShapeImpl*>(rawptr))
#define DTypeImplConstRef(rawptr) static_cast<const DTypeImpl&>(*reinterpret_cast<const DTypeImpl*>(rawptr))
#define FormatImplConstRef(rawptr) static_cast<const FormatImpl&>(*reinterpret_cast<const FormatImpl*>(rawptr))
#define TensorImplConstRef(rawptr) static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr))
static std::unordered_map<DeviceImpl::DeviceType, std::string,
EnumHash<DeviceImpl::DeviceType>,
EnumCmp<DeviceImpl::DeviceType>> dev_benum2cstr;
static std::unordered_map<DeviceImpl::DeviceType, DeviceEnum,
EnumHash<DeviceImpl::DeviceType>,
EnumCmp<DeviceImpl::DeviceType>> dev_benum2cenum;
static std::unordered_map<std::string, std::string> dev_cstr2bstr;
static std::unordered_map<DeviceEnum, std::string,
EnumHash<DeviceEnum>,
EnumCmp<DeviceEnum>> dev_cenum2bstr;
#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \
auto be2cs##custom_impl = dev_benum2cstr.emplace( \
DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \
auto be2ce##custom_impl = dev_benum2cenum.emplace( \
DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \
auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \
std::string(#custom_impl), std::string(builtin_str)); \
auto ce2bs##custom_impl = dev_cenum2bstr.emplace( \
DeviceEnum::custom_impl, std::string(builtin_str));
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE)
#undef CUSTOM_BIND_DEVICE
CUSTOM_PIMPL_CLS_DEFINE(Device)
const void *Device::impl() const {
return m_impl.get();
}
Device::Device(const void *impl): m_impl(nullptr, impl_deleter<DeviceImpl>) {
mgb_assert(impl != nullptr, "invalid ptr");
if (!DeviceImplConstRef(impl).valid()) {
m_impl.reset(new DeviceImpl());
return;
}
auto builtin_device_enum = DeviceImplConstRef(impl).device_type();
mgb_assert(
dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(),
"unsupported compnode type: %s", DeviceImplConstRef(impl).to_string().c_str()
);
m_impl.reset(new DeviceImpl(DeviceImplConstRef(impl)));
}
Device::Device(const std::string &device): m_impl(nullptr, impl_deleter<DeviceImpl>) {
mgb_assert(is_legal(device), "invalid device type: %s", device.c_str());
std::string builtin_device = dev_cstr2bstr[device];
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device)));
}
// to avoid the ambiguous from Device(const void *impl)
Device::Device(const char *device): Device(std::string(device)) {
}
Device::Device(DeviceEnum device): m_impl(nullptr, impl_deleter<DeviceImpl>) {
mgb_assert(is_legal(device), "invalid device type");
std::string builtin_device = dev_cenum2bstr[device];
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device)));
}
std::string Device::str(void) const {
if (!DeviceImplRef(m_impl.get()).valid()) {
return "invalid";
}
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type();
auto iter = dev_benum2cstr.find(builtin_device_type);
mgb_assert(
iter != dev_benum2cstr.end(), "invalid device type %s\n",
DeviceImplRef(m_impl.get()).to_string().c_str()
);
return iter->second;
}
DeviceEnum Device::enumv(void) const {
mgb_assert(
DeviceImplRef(m_impl.get()).valid(),
"cannot get the enum value of invalid device"
);
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type();
auto iter = dev_benum2cenum.find(builtin_device_type);
mgb_assert(
iter != dev_benum2cenum.end(), "invalid device type %s\n",
DeviceImplRef(m_impl.get()).to_string().c_str()
);
return iter->second;
}
bool Device::is_legal(const std::string &device_type) {
return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end();
}
bool Device::is_legal(DeviceEnum device_type) {
return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end();
}
std::vector<std::string> Device::legal_devices(void) {
std::vector<std::string> ret;
for (const auto &kv: dev_cstr2bstr) {
ret.emplace_back(kv.first);
}
return ret;
}
bool operator==(const Device &lhs, const Device &rhs) {
return lhs.str() == rhs.str();
}
CUSTOM_PIMPL_CLS_DEFINE(Shape)
const void *Shape::impl() const {
return m_impl.get();
}
Shape::Shape(const void *impl): m_impl(nullptr, impl_deleter<ShapeImpl>) {
mgb_assert(impl != nullptr, "invalid ptr");
m_impl.reset(new ShapeImpl(ShapeImplConstRef(impl)));
}
Shape::Shape(const std::vector<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) {
m_impl.reset(new ShapeImpl(to_builtin_vector<size_t>(rhs)));
}
Shape::Shape(const std::initializer_list<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) {
m_impl.reset(new ShapeImpl(rhs));
}
size_t &Shape::operator[](size_t idx) {
mgb_assert(idx < ndim(), "wrong tensor dimension idx: %lu < %lu", static_cast<unsigned long>(idx), static_cast<unsigned long>(ndim()));
return ShapeImplRef(m_impl.get()).operator[](idx);
}
size_t Shape::operator[](size_t idx) const {
return const_cast<Shape*>(this)->operator[](idx);
}
void Shape::ndim(size_t dim) {
mgb_assert(dim < ShapeImpl::MAX_NDIM, "dimension must <= %lu", static_cast<unsigned long>(ShapeImpl::MAX_NDIM));
ShapeImplRef(m_impl.get()).ndim = dim;
}
size_t Shape::ndim(void) const {
return ShapeImplRef(m_impl.get()).ndim;
}
bool operator==(const Shape &lhs, const Shape &rhs) {
return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get()));
}
static std::unordered_map<std::string, megdnn::DTypeEnum> dtype_cstr2benum;
static std::unordered_map<DTypeEnum, megdnn::DTypeEnum,
EnumHash<DTypeEnum>,
EnumCmp<DTypeEnum>> dtype_cenum2benum;
static std::unordered_map<megdnn::DTypeEnum, std::string,
EnumHash<megdnn::DTypeEnum>,
EnumCmp<megdnn::DTypeEnum>> dtype_benum2cstr;
static std::unordered_map<megdnn::DTypeEnum, DTypeEnum,
EnumHash<megdnn::DTypeEnum>,
EnumCmp<megdnn::DTypeEnum>> dtype_benum2cenum;
static std::unordered_map<DTypeEnum, std::string,
EnumHash<DTypeEnum>,
EnumCmp<DTypeEnum>> dtype_cenum2cstr;
#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \
auto cs2be##custom_impl = dtype_cstr2benum.emplace( \
std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \
auto ce2be##custom_impl = dtype_cenum2benum.emplace( \
DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \
auto be2cs##custom_impl = dtype_benum2cstr.emplace( \
megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \
auto be2ce##custom_impl = dtype_benum2cenum.emplace( \
megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \
auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \
DTypeEnum::custom_impl, std::string(#custom_impl));
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE)
#undef CUSTOM_BIND_DTYPE
CUSTOM_PIMPL_CLS_DEFINE(DType)
const void *DType::impl() const {
return m_impl.get();
}
DType::DType(const void *impl): m_impl(nullptr, impl_deleter<DTypeImpl>) {
mgb_assert(impl != nullptr, "invalid ptr");
m_impl.reset(new DTypeImpl(DTypeImplConstRef(impl)));
}
DType::DType(const std::string &dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) {
auto iter = dtype_cstr2benum.find(dtype);
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str());
mgb_assert(
dtype[0] != 'q', "can not construct quantized dtype "
"%s without scale and zero_point", dtype.c_str()
);
m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second)));
}
DType::DType(const char *dtype): DType(std::string(dtype)) {
}
DType::DType(const std::string &dtype, float scale, uint8_t zero_point)
: m_impl(nullptr, impl_deleter<DTypeImpl>) {
auto iter = dtype_cstr2benum.find(dtype);
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str());
mgb_assert(
dtype[0] == 'q', "given scale/zero_point to construct "
"non-quantized dtype: %s is not allowed", dtype.c_str()
);
if (dtype == "quint8") {
m_impl.reset(new megdnn::ParameterizedDType<
megdnn::DTypeEnum::Quantized8Asymm>(scale, zero_point));
}
else {
mgb_assert(
zero_point == 0, "invalid zero point %d for dtype %s",
zero_point, dtype.c_str()
);
if (dtype == "qint8") {
m_impl.reset(new megdnn::ParameterizedDType<
megdnn::DTypeEnum::QuantizedS8>(scale));
}
else if (dtype == "qint16") {
m_impl.reset(new megdnn::ParameterizedDType<
megdnn::DTypeEnum::QuantizedS16>(scale));
}
else if (dtype == "qint32") {
m_impl.reset(new megdnn::ParameterizedDType<
megdnn::DTypeEnum::QuantizedS32>(scale));
}
else {
mgb_assert(false, "invalid dtype %s", dtype.c_str());
}
}
}
DType::DType(const char *dtype, float scale, uint8_t zero_point)
: DType(std::string(dtype), scale, zero_point) {
}
DType::DType(DTypeEnum dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) {
auto iter = dtype_cenum2benum.find(dtype);
mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype");
mgb_assert(dtype < DTypeEnum::quint8,
"can not construct quantized dtype without scale and zero_point");
m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second)));
}
DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point)
: DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) {
}
std::string DType::str(void) const {
if (!DTypeImplRef(m_impl.get()).valid())
return "invalid";
auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv());
if (iter == dtype_benum2cstr.end())
return "invalid";
return iter->second;
}
DTypeEnum DType::enumv(void) const {
auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv());
mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype");
return iter->second;
}
float DType::scale() const {
if (enumv() == DTypeEnum::qint8) {
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS8>().scale;
}
else if (enumv() == DTypeEnum::qint16) {
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS16>().scale;
}
else if (enumv() == DTypeEnum::qint32) {
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS32>().scale;
}
else if (enumv() == DTypeEnum::quint8) {
return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().scale;
}
else {
mgb_assert(false, "dtype %s has no scale", str().c_str());
return 0.f;
}
}
uint8_t DType::zero_point() const {
mgb_assert(enumv()==DTypeEnum::quint8, "dtype %s has no zero point", str().c_str());
return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().zero_point;
}
bool DType::is_legal(const std::string &dtype) {
return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end();
}
bool DType::is_legal(const DTypeEnum &dtype) {
return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end();
}
std::vector<std::string> DType::legal_dtypes(void) {
std::vector<std::string> ret;
for (const auto &kv: dtype_cstr2benum)
ret.emplace_back(kv.first);
return ret;
}
bool operator==(const DType &lhs, const DType &rhs) {
return DTypeImplRef(lhs.m_impl.get()) == DTypeImplRef(rhs.m_impl.get());
}
bool operator==(const DType &lhs, const std::string &rhs) {
return lhs.str() == rhs;
}
bool operator==(const DType &lhs, const char *rhs) {
return operator==(lhs, std::string(rhs));
}
bool operator==(const std::string &lhs, const DType &rhs) {
return operator==(rhs, lhs);
}
bool operator==(const char *lhs, const DType &rhs) {
return operator==(rhs, std::string(lhs));
}
CUSTOM_PIMPL_CLS_DEFINE(Format)
const void *Format::impl() const {
return m_impl.get();
}
Format::Format(const void *impl): m_impl(nullptr, impl_deleter<FormatImpl>) {
mgb_assert(impl != nullptr, "invalid ptr");
mgb_assert(FormatImplConstRef(impl).is_default(), "only default format is supported now");
m_impl.reset(new FormatImpl(FormatImplConstRef(impl)));
}
Format::Format(const std::string &format): m_impl(nullptr, impl_deleter<FormatImpl>) {
mgb_assert(format == "default", "only default format is supported now");
m_impl.reset(new FormatImpl());
}
Format::Format(const char *format): Format(std::string(format)) {
}
std::string Format::str(void) const {
return FormatImplRef(m_impl.get()).to_string();
}
bool Format::is_default(void) const {
return FormatImplRef(m_impl.get()).is_default();
}
const void *Tensor::impl(void) const {
return m_tensor;
}
Tensor::Tensor(const void *impl) {
mgb_assert(impl != nullptr, "invalid ptr");
m_tensor = const_cast<void*>(impl);
}
const size_t *Tensor::shapes_raw(void) const {
return TensorImplRef(m_tensor).shape().shape;
}
const ptrdiff_t *Tensor::strides_raw(void) const {
return TensorImplRef(m_tensor).layout().stride;
}
Tensor::Tensor(const Tensor &rhs) {
mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for copy constructor\n");
m_tensor = rhs.m_tensor;
}
Tensor &Tensor::operator=(const Tensor &rhs) {
mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for assignment operator");
if (&rhs == this || rhs.m_tensor == m_tensor)
return *this;
m_tensor = rhs.m_tensor;
return *this;
}
Shape Tensor::shape(void) const {
auto builtin = TensorImplRef(m_tensor).shape();
return Shape(&builtin);
}
DType Tensor::dtype(void) const {
auto builtin = TensorImplRef(m_tensor).dtype();
return DType(&builtin);
}
Format Tensor::format(void) const {
auto builtin = TensorImplRef(m_tensor).format();
return Format(&builtin);
}
Device Tensor::device(void) const {
auto builtin = TensorImplRef(m_tensor).comp_node();
return Device(&builtin);
}
size_t Tensor::size(void) const {
return TensorImplRef(m_tensor).shape().total_nr_elems();
}
std::vector<ptrdiff_t> Tensor::stride(void) const {
std::vector<ptrdiff_t> ret(TensorImplRef(m_tensor).shape().ndim);
for (size_t i=0; i<ret.size(); i++)
ret[i] = TensorImplRef(m_tensor).layout().stride[i];
return ret;
}
float Tensor::scale(void) const {
return dtype().scale();
}
uint8_t Tensor::zero_point(void) const {
return dtype().zero_point();
}
void *Tensor::data(void) {
return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr());
}
const void *Tensor::data(void) const {
return static_cast<const void*>(TensorImplRef(m_tensor).raw_ptr());
}
} // namespace custom
/**
* \file src/custom/impl/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/utils.h"
#include "megbrain/common.h"
#include <sstream>
using namespace mgb;
namespace custom {
void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...) {
std::string msg = ssprintf("`%s' is true at %s:%d: %s", expr, file, line, func);
if (msg_fmt) {
msg_fmt = convert_fmt_str(msg_fmt);
va_list ap;
va_start(ap, msg_fmt);
msg.append("\nextra message: ");
msg.append(svsprintf(msg_fmt, ap));
va_end(ap);
}
printf("%s\n", msg.c_str());
}
UnImpleWarnLog::UnImpleWarnLog(const std::string &func, const std::string &attr,
const std::string &val) {
mgb_log_warn("you are using the default custom %s function, the `%s` attribute "
"of all the outputs tensor will be the same with inputs tensor[0]. "
"If there is no input tensor, it will be `%s`",
func.c_str(), attr.c_str(), val.c_str());
}
}
/**
* \file src/custom/include/megbrain/custom/accessor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cstddef>
#include <cstdint>
namespace custom {
#ifdef __CUDACC__
#define CUSTOM_HOST __host__
#define CUSTOM_DEVICE __device__
#else
#define CUSTOM_HOST
#define CUSTOM_DEVICE
#endif
#define CUSTOM_HOST_DEVICE CUSTOM_HOST CUSTOM_DEVICE
template <typename T>
struct DefaultPtrTraits {
using PtrType = T*;
};
#ifdef __CUDACC__
template <typename T>
struct RestrictPtrTraits {
using PtrType = T* __restrict__;
};
#endif
template <typename T, size_t N,
template <typename U> class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
class TensorAccessorProxyBase {
public:
using PtrType = typename PtrTraits<T>::PtrType;
protected:
PtrType m_data;
const index_t* m_sizes;
const index_t* m_strides;
public:
CUSTOM_HOST_DEVICE TensorAccessorProxyBase(PtrType data, const index_t *sizes, const index_t *strides) {
m_data = data;
m_sizes = sizes;
m_strides = strides;
}
CUSTOM_HOST_DEVICE index_t stride(index_t i) const {
return m_strides[i];
}
CUSTOM_HOST_DEVICE index_t size(index_t i) const {
return m_sizes[i];
}
CUSTOM_HOST_DEVICE PtrType data() const {
return m_data;
}
};
template<typename T, size_t N,
template <typename U> class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
class TensorAccessorProxy: public TensorAccessorProxyBase<T, N, PtrTraits, index_t> {
public:
using PtrType = typename PtrTraits<T>::PtrType;
CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides)
: TensorAccessorProxyBase<T, N, PtrTraits, index_t>(data, sizes, strides) {
}
CUSTOM_HOST_DEVICE TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) {
return TensorAccessorProxy<T, N-1, PtrTraits, index_t>(
this->m_data + this->m_strides[0] * i,
this->m_sizes + 1,
this->m_strides + 1
);
}
CUSTOM_HOST_DEVICE const TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
return TensorAccessorProxy<T, N-1, PtrTraits, index_t>(
this->m_data + this->m_strides[0] * i,
this->m_sizes + 1,
this->m_strides + 1
);
}
};
template<typename T, template <typename U> class PtrTraits, typename index_t>
class TensorAccessorProxy<T, 1, PtrTraits, index_t>
: public TensorAccessorProxyBase<T, 1, PtrTraits, index_t> {
public:
using PtrType = typename PtrTraits<T>::PtrType;
CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides)
: TensorAccessorProxyBase<T, 1, PtrTraits, index_t>(data, sizes, strides ) {
}
CUSTOM_HOST_DEVICE T &operator[](index_t i) {
return this->m_data[this->m_strides[0]*i];
}
CUSTOM_HOST_DEVICE const T &operator[](index_t i) const {
return this->m_data[this->m_strides[0]*i];
}
};
template<typename T, size_t N,
template <typename U> class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
class TensorAccessorBase {
public:
using PtrType = typename PtrTraits<T>::PtrType;
protected:
PtrType m_data;
index_t m_sizes[N];
index_t m_strides[N];
public:
CUSTOM_HOST_DEVICE TensorAccessorBase(PtrType data, const size_t *sizes, const ptrdiff_t *strides) {
m_data = data;
for (size_t i=0; i<N; ++i) {
m_sizes[i] = sizes[i];
m_strides[i] = strides[i];
}
}
CUSTOM_HOST_DEVICE index_t stride(index_t i) const {
return m_strides[i];
}
CUSTOM_HOST_DEVICE index_t size(index_t i) const {
return m_sizes[i];
}
CUSTOM_HOST_DEVICE PtrType data() const {
return m_data;
}
};
template<typename T, size_t N,
template <typename U> class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
class TensorAccessor: public TensorAccessorBase<T, N, PtrTraits, index_t> {
public:
using PtrType = typename PtrTraits<T>::PtrType;
CUSTOM_HOST_DEVICE TensorAccessor(PtrType data, const size_t *sizes, const ptrdiff_t *strides)
: TensorAccessorBase<T, N, PtrTraits, index_t>(data, sizes, strides) {
}
CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) {
return TensorAccessorProxy<T, N, PtrTraits, index_t>(
this->m_data,
this->m_sizes,
this->m_strides
)[i];
}
CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) const {
return TensorAccessorProxy<T, N, PtrTraits, index_t>(
this->m_data,
this->m_sizes,
this->m_strides
)[i];
}
};
}
/**
* \file src/custom/include/megbrain/custom/custom.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "op.h"
#include "tensor.h"
#include "param.h"
namespace custom {
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version);
}
#define CUSTOM_OP_REG(OpName) CustomOp &_##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION)))
#define CUSTOM_OP_REG_BEGIN(OpName) \
namespace custom { \
namespace OpName {
#define CUSTOM_OP_REG_END(OpName) \
} \
}
#define CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, hint, ...) \
case (case_type): { \
using hint = real_type; \
return __VA_ARGS__(); \
}
#define CASE_TO_PERFORM_ON_SCALAR(name, case_type, real_type, ...) \
CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, scalar_t, __VA_ARGS__)
#define DISPATCH_FLOAT_TYPES(tensor_dtype, name, ...) \
[&]() { \
const auto &dtype = tensor_dtype; \
switch (dtype.enumv()) { \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \
default: \
custom_assert(false, "no implemented %s kernel for dtype %s\n", \
name, dtype.str().c_str()); \
} \
}()
#define DISPATCH_INT_TYPES(tensor_dtype, name, ...) \
[&]() { \
const auto &dtype = tensor_dtype; \
switch (dtype.enumv()) { \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \
default: \
custom_assert(false, "no implemented %s kernel for dtype %s\n", \
name, dtype.str().c_str()); \
} \
}()
#define DISPATCH_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \
[&]() { \
const auto &dtype = tensor_dtype; \
switch (dtype.enumv()) { \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \
default: \
custom_assert(false, "no implemented %s kernel for dtype %s\n", \
name, dtype.str().c_str()); \
} \
}()
#define DISPATCH_SIGN_INT_TYPES(tensor_dtype, name, ...) \
[&]() { \
const auto &dtype = tensor_dtype; \
switch (dtype.enumv()) { \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \
default: \
custom_assert(false, "no implemented %s kernel for dtype %s\n", \
name, dtype.str().c_str()); \
} \
}()
#define DISPATCH_SIGN_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \
[&]() { \
const auto &dtype = tensor_dtype; \
switch (dtype.enumv()) { \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \
default: \
custom_assert(false, "no implemented %s kernel for dtype %s\n", \
name, dtype.str().c_str()); \
} \
}()
/**
* \file src/custom/include/megbrain/custom/data_adaptor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/thin/small_vector.h"
namespace custom {
template <typename BuiltinT, typename CustomT>
BuiltinT to_builtin(const CustomT &custom) {
return *reinterpret_cast<const BuiltinT*>(custom.impl());
}
template <typename BuiltinT, typename CustomT>
CustomT to_custom(const BuiltinT &builtin) {
return std::move(CustomT(&builtin));
}
template <typename BuiltinT, typename CustomT>
megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT> &customs) {
megdnn::SmallVector<BuiltinT> builtins;
for (size_t i=0; i<customs.size(); ++i) {
builtins.push_back(std::move(to_builtin<BuiltinT, CustomT>(customs[i])));
}
return std::move(builtins);
}
template <typename BuiltinT, typename CustomT>
std::vector<CustomT> to_custom(
const megdnn::SmallVector<BuiltinT> &builtins) {
std::vector<CustomT> customs;
for (size_t i=0; i<builtins.size(); ++i) {
customs.push_back(std::move(to_custom<BuiltinT, CustomT>(builtins[i])));
}
return std::move(customs);
}
}
#define to_custom_device(expr) custom::to_custom<CompNode, custom::Device>(expr)
#define to_builtin_device(expr) custom::to_builtin<CompNode, custom::Device>(expr)
#define to_custom_shape(expr) custom::to_custom<megdnn::TensorShape, custom::Shape>(expr)
#define to_builtin_shape(expr) custom::to_builtin<megdnn::TensorShape, custom::Shape>(expr)
#define to_custom_dtype(expr) custom::to_custom<megdnn::DType, custom::DType>(expr)
#define to_builtin_dtype(expr) custom::to_builtin<megdnn::DType, custom::DType>(expr)
#define to_custom_format(expr) custom::to_custom<megdnn::TensorLayout::Format, custom::Format>(expr)
#define to_builtin_format(expr) custom::to_builtin<megdnn::TensorLayout::Format, custom::Format>(expr)
#define to_custom_tensor(expr) custom::to_custom<DeviceTensorND, custom::Tensor>(expr)
#define to_builtin_tensor(expr) custom::to_builtin<DeviceTensorND, custom::Tensor>(expr)
/**
* \file src/custom/include/megbrain/custom/manager.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "custom.h"
#include "megbrain/common.h"
namespace custom {
class CustomOpManager {
std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op;
std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op;
MGB_MUTEX m_mtx;
CustomOpManager() = default;
public:
PREVENT_COPY_AND_ASSIGN(CustomOpManager);
static CustomOpManager *inst(void);
~CustomOpManager();
std::shared_ptr<CustomOp> insert(const std::string &name, uint32_t version);
bool erase(const std::string &name);
bool erase(const RunTimeId &id);
std::shared_ptr<CustomOp> find_or_reg(const std::string &name, uint32_t version);
RunTimeId to_id(const std::string &name) const;
std::string to_name(const RunTimeId &id) const;
std::shared_ptr<const CustomOp> find(const std::string &name) const;
std::shared_ptr<const CustomOp> find(const RunTimeId &id) const;
std::vector<std::string> op_name_list(void);
std::vector<RunTimeId> op_id_list(void);
};
class CustomLib {
std::unique_ptr<void, void_deleter> m_handle;
std::vector<std::string> m_ops;
public:
PREVENT_COPY_AND_ASSIGN(CustomLib);
CustomLib(const std::string &path, int mode);
const std::vector<std::string> &ops_in_lib(void) const;
~CustomLib();
bool valid(void) const;
};
using LibHandle = std::shared_ptr<CustomLib>;
class LibManager {
std::unordered_map<std::string, LibHandle> m_custom_libs;
MGB_MUTEX m_mtx;
LibManager() = default;
public:
PREVENT_COPY_AND_ASSIGN(LibManager);
static LibManager *inst(void);
const std::vector<std::string> &install(const std::string &name, const std::string &path);
bool uninstall(const std::string &name);
friend class CustomOpManager;
};
}
/**
* \file src/custom/include/megbrain/custom/op.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "tensor.h"
#include "param.h"
#include <unordered_set>
#define PREVENT_COPY_AND_ASSIGN(Cls) \
Cls(const Cls&) = delete; \
Cls(const Cls&&) = delete; \
Cls &operator=(const Cls&) = delete; \
Cls &operator=(const Cls&&) = delete
#define CUSTOM_OP_MAJOR 0
#define CUSTOM_OP_MINOR 1
#define CUSTOM_OP_PATCH 0
#define CUSTOM_OP_VERSION CUSTOM_OP_MAJOR*10000 + CUSTOM_OP_MINOR*100 + CUSTOM_OP_PATCH
namespace custom {
using RunTimeId = uint64_t;
class ArgInfo {
CUSTOM_PIMPL_CLS_DECL(ArgInfo);
ArgInfo(const std::string &name,
const std::string &desc,
const std::unordered_set<std::string> &dtypes,
const int &ndim,
const std::string &mem_stgy);
const std::string &name(void) const;
const std::string &desc(void) const;
const std::unordered_set<std::string> &dtypes(void) const;
int ndim(void) const;
const std::string &mem_strategy(void) const;
std::string str() const;
};
class CustomOp {
std::unique_ptr<void, void_deleter> m_impl;
public:
CustomOp(const std::string &op_type, uint32_t version);
PREVENT_COPY_AND_ASSIGN(CustomOp);
using DeviceInferFuncPtr = void(*)(const std::vector<Device>&, const Param&, std::vector<Device>&);
using ShapeInferFuncPtr = void(*)(const std::vector<Shape>&, const Param&, std::vector<Shape>&);
using DTypeInferFuncPtr = void(*)(const std::vector<DType>&, const Param&, std::vector<DType>&);
using FormatInferFuncPtr = void(*)(const std::vector<Format>&, const Param&, std::vector<Format>&);
using PreprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
using PostprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
using ComputeFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
// write for forward
CustomOp &set_device_infer(DeviceInferFuncPtr func);
CustomOp &set_shape_infer(ShapeInferFuncPtr func);
CustomOp &set_dtype_infer(DTypeInferFuncPtr func);
CustomOp &set_format_infer(FormatInferFuncPtr func);
CustomOp &set_preprocess(PreprocessFuncPtr func);
CustomOp &set_preprocess(const std::string &device, PreprocessFuncPtr func);
CustomOp &set_postprocess(PostprocessFuncPtr func);
CustomOp &set_postprocess(const std::string &device, PostprocessFuncPtr func);
CustomOp &set_compute(ComputeFuncPtr func);
CustomOp &set_compute(const std::string &device, ComputeFuncPtr func);
CustomOp &set_description(const std::string &op_desc);
CustomOp &add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default");
CustomOp &add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default");
CustomOp &add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default");
CustomOp &add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default");
CustomOp &add_inputs(const size_t &input_num);
CustomOp &add_outputs(const size_t &output_num);
CustomOp &add_param(const std::string &name, const ParamVal &default_val);
CustomOp &add_param(const std::string &name, const std::string &desc, const ParamVal &default_val);
// read
std::string op_type(void) const;
std::string op_desc(void) const;
RunTimeId runtime_id(void) const;
size_t input_num(void) const;
size_t output_num(void) const;
std::string str(void) const;
const ParamInfo &param_info(void) const;
ArgInfo input_info(size_t idx) const;
ArgInfo output_info(size_t idx) const;
const std::vector<ArgInfo> &inputs_info(void) const;
const std::vector<ArgInfo> &outputs_info(void) const;
// use
std::vector<Device> infer_output_device(const std::vector<Device>&, const Param&) const;
std::vector<Shape> infer_output_shape (const std::vector<Shape>&, const Param&) const;
std::vector<DType> infer_output_dtype (const std::vector<DType>&, const Param&) const;
std::vector<Format> infer_output_format(const std::vector<Format>&, const Param&) const;
void compute(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) const;
};
}
/**
* \file src/custom/include/megbrain/custom/param.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>
#include <string>
#include <unordered_map>
#include "param_val.h"
namespace custom {
class ParamSchemaImpl;
class ParamInfoImpl;
class ParamImpl;
// Schema of a param element
class ParamSchema {
CUSTOM_PIMPL_CLS_DECL(ParamSchema);
ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc="");
const std::string &name(void) const;
const std::string &desc(void) const;
const ParamVal &default_val(void) const;
ParamDynType type(void) const;
std::string str(void) const;
};
class ParamInfo {
CUSTOM_PIMPL_CLS_DECL(ParamInfo);
void set_tag(const std::string&);
void set_meta(const std::vector<ParamSchema> &meta);
uint32_t tag(void) const;
std::vector<ParamSchema> &meta(void);
const std::vector<ParamSchema> &meta(void) const;
};
class Param {
CUSTOM_PIMPL_CLS_DECL(Param);
Param(const ParamInfo&);
ParamVal &operator[](const std::string&);
const ParamVal &operator[](const std::string&) const;
const std::unordered_map<std::string, ParamVal> &raw() const;
bool exist(const std::string &name) const;
std::string to_bytes(void) const;
void from_bytes(const std::string&);
};
bool operator==(const Param&, const Param&);
} // custom
/**
* \file src/custom/include/megbrain/custom/param_val.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <string>
#include <vector>
#include <cassert>
#include <sstream>
#include <memory>
#include <unordered_map>
#include "utils.h"
namespace custom {
/**
* we can add a new basic data type here, basic means we can perform binary
* op such as: +, -, *, /, ==, != between any two of them
*/
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \
cb(Int32, int32_t, ##__VA_ARGS__) \
cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Uint32, uint32_t, ##__VA_ARGS__) \
cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
#define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) \
cb(String, std::string, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \
cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \
cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \
cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \
cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \
cb(Float32List, std::vector<float>, ##__VA_ARGS__) \
cb(Float64List, std::vector<double>, ##__VA_ARGS__)
#define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \
cb(BoolList, std::vector<bool>, ##__VA_ARGS__)
#define CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ...) \
cb(StringList, std::vector<std::string>, ##__VA_ARGS__)
/**
* to avoid the recursive of MACRO
*/
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \
cb(Int32, int32_t, ##__VA_ARGS__) \
cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Uint32, uint32_t, ##__VA_ARGS__) \
cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__)
/**
* Macro Callback for Register
*/
#define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type,
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) {ParamDynType::dyn_type, #dyn_type},
#define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \
template <> \
struct get_dyn_type<static_type> { \
static constexpr ParamDynType type = ParamDynType::dyn_type;\
};
#define CUSTOM_REG_STATIC_PARAMTYPE_GETTER(dyn_type, static_type) \
template <> \
struct get_static_type<ParamDynType::dyn_type> { \
using type = static_type; \
};
enum class ParamDynType: uint32_t {
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE)
Invalid=255
};
static std::unordered_map<ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>> type2name = {
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME)
{ParamDynType::Invalid, "Invalid"}
};
/**
* get the dynamic data type according to the builtin static data type
* we can use it like:
* ParamDynType dyn_type = get_dyn_type<int32_t>::type;
* assert(dyn_type == ParamDynType::Int32)
*/
template <typename T>
struct get_dyn_type {
static constexpr ParamDynType type = ParamDynType::Invalid;
};
/**
* get the static data type according to the dynamic data type
* we can use it like:
* get_static_type<ParamDynType::Int32>::type int_32_value;
* assert(std::is_same<decltype(int_32_value), int>::value)
*/
template <ParamDynType>
struct get_static_type;
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER)
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER)
#undef CUSTOM_REG_DYN_PARAMTYPE
#undef CUSTOM_REG_DYN_PARAMTYPE_NAME
#undef CUSTOM_REG_DYN_PARAMTYPE_GETTER
#undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER
template <typename T>
struct get_vector_template_arg_type;
template <typename T>
struct get_vector_template_arg_type<std::vector<T>> {
using type = std::decay_t<T>;
};
template <typename T>
struct is_vector {
static constexpr bool value = false;
};
template <typename T>
struct is_vector <std::vector<T>> {
static constexpr bool value = true;
};
template <typename T>
std::string vec2str(const std::vector<T> &vec) {
std::stringstream ss;
ss << "{";
for (const auto &val: vec) {
ss << val << ", ";
}
if (vec.size() != 0) {
ss.seekp(ss.tellp()-std::streampos(2));
}
ss << "}";
return ss.str();
}
/**
* we use void* rather than template to help us realise a complete dynamic type
* if we use template such as:
* template <typename T>
* class ParamVal {
* T m_data;
* }
* Con1: user need to set the type explicitly when class template instantiation
* Con2: ParamVal<int> can not be assigned to ParamVal<double>
*/
class ParamVal {
std::unique_ptr<void, void_deleter> m_ptr;
ParamDynType m_type;
public:
template <typename T>
ParamVal(const T &val);
template <typename T>
ParamVal(const std::initializer_list<T> &val);
ParamVal();
ParamVal(const char *str);
ParamVal(const std::initializer_list<const char*> &strs);
ParamVal(const std::vector<const char*> &strs);
ParamVal(const ParamVal &rhs);
template <typename T>
ParamVal &operator=(const T &rhs);
template <typename T>
ParamVal &operator=(const std::initializer_list<T> &val);
ParamVal &operator=(const char *str);
ParamVal &operator=(const std::initializer_list<const char*> &strs);
ParamVal &operator=(const std::vector<const char*> &strs);
ParamVal &operator=(const ParamVal &rhs);
template <typename T>
const T &as(void) const;
template <typename T>
T &as(void);
const void *raw_ptr(void) const;
void *raw_ptr(void);
ParamDynType type(void) const;
std::string str(void) const;
size_t size(void) const;
static std::string to_bytes(const ParamVal &value);
static ParamVal from_bytes(const std::string &bytes, size_t &offset);
friend ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs);
friend ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs);
friend ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs);
friend ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs);
friend bool operator==(const ParamVal &lhs, const ParamVal &rhs);
friend bool operator!=(const ParamVal &lhs, const ParamVal &rhs);
friend bool operator> (const ParamVal &lhs, const ParamVal &rhs);
friend bool operator< (const ParamVal &lhs, const ParamVal &rhs);
friend bool operator>=(const ParamVal &lhs, const ParamVal &rhs);
friend bool operator<=(const ParamVal &lhs, const ParamVal &rhs);
};
ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs);
ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs);
ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs);
ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs);
bool operator==(const ParamVal &lhs, const ParamVal &rhs);
bool operator!=(const ParamVal &lhs, const ParamVal &rhs);
bool operator> (const ParamVal &lhs, const ParamVal &rhs);
bool operator< (const ParamVal &lhs, const ParamVal &rhs);
bool operator>=(const ParamVal &lhs, const ParamVal &rhs);
bool operator<=(const ParamVal &lhs, const ParamVal &rhs);
template <typename T>
ParamVal::ParamVal(const T &val): m_ptr(nullptr, impl_deleter<std::decay_t<T>>) {
using DecayType = std::decay_t<T>;
m_type = get_dyn_type<DecayType>::type;
custom_assert(m_type != ParamDynType::Invalid, "param construct error! unsupported builtin type");
m_ptr.reset(new DecayType(val));
}
template <typename T>
ParamVal::ParamVal(const std::initializer_list<T> &val): ParamVal(std::vector<std::decay_t<T>>(val)) {
}
template <typename T>
ParamVal &ParamVal::operator=(const T &rhs) {
using DecayType = std::decay_t<T>;
ParamDynType rhs_dyn_type = get_dyn_type<DecayType>::type;
custom_assert(rhs_dyn_type != ParamDynType::Invalid, "unsupported builtin dtype");
if (rhs_dyn_type == m_type) {
TypedRef(DecayType, m_ptr.get()) = rhs;
}
else {
m_type = rhs_dyn_type;
std::unique_ptr<void, void_deleter> new_ptr(new DecayType(rhs), impl_deleter<DecayType>);
m_ptr.swap(new_ptr);
}
return *this;
}
template <typename T>
ParamVal &ParamVal::operator=(const std::initializer_list<T> &val) {
return this->operator=(std::vector<std::decay_t<T>>(val));
}
template <typename T>
const T &ParamVal::as(void) const {
return const_cast<ParamVal*>(this)->as<T>();
}
template <typename T>
T &ParamVal::as(void) {
using DecayType = std::decay_t<T>;
ParamDynType t_dyn_type = get_dyn_type<DecayType>::type;
custom_assert(
t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n",
type2name[m_type].c_str(), type2name[t_dyn_type].c_str()
);
return TypedRef(T, m_ptr.get());
}
}
/**
* \file src/custom/include/megbrain/custom/tensor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <vector>
#include <string>
#include "utils.h"
#include "accessor.h"
namespace custom {
#define CUSTOM_DATA_ADAPTOR_FRIEND_DECL \
template <typename BuiltinT, typename CustomT> \
friend BuiltinT to_builtin(const CustomT &custom); \
template <typename BuiltinT, typename CustomT> \
friend CustomT to_custom(const BuiltinT &builtin)
#define CUSTOM_FOR_EACH_DEVICE_TYPE(cb) \
cb(x86, CPU, "cpux") \
cb(cuda, CUDA, "gpux")
#define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) custom_type,
class Device {
const void *impl() const;
Device(const void *impl);
CUSTOM_PIMPL_CLS_DECL(Device);
public:
enum class DeviceEnum: uint32_t {
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL)
};
Device(const std::string &device);
Device(const char *device);
Device(DeviceEnum device);
std::string str(void) const;
DeviceEnum enumv(void) const;
static bool is_legal(const std::string &device);
static bool is_legal(DeviceEnum device);
static std::vector<std::string> legal_devices(void);
friend class Tensor;
friend bool operator==(const Device &lhs, const Device &rhs);
CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
};
using DeviceEnum = Device::DeviceEnum;
bool operator==(const Device &lhs, const Device &rhs);
class Shape {
const void *impl() const;
Shape(const void *impl);
CUSTOM_PIMPL_CLS_DECL(Shape);
public:
Shape(const std::vector<size_t> &rhs);
Shape(const std::initializer_list<size_t> &rhs);
size_t &operator[](size_t idx);
size_t operator[](size_t idx) const;
void ndim(size_t dim);
size_t ndim(void) const;
friend class Tensor;
friend bool operator==(const Shape &lhs, const Shape &rhs);
CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
};
bool operator==(const Shape &lhs, const Shape &rhs);
using float16_t = uint16_t;
using bfloat16_t = uint16_t;
#if MEGDNN_DISABLE_FLOAT16
#define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype)
#else
#define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) cb(custom_dtype, dnn_dtype, c_dtype)
#endif
#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \
cb(float32, Float32, float) \
cb(uint8, Uint8, uint8_t) \
cb(int8, Int8, int8_t) \
cb(int16, Int16, int16_t) \
cb(int32, Int32, int32_t) \
fp16_wrap(cb, float16, Float16, float16_t) \
fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \
cb(uint16, Uint16, uint16_t) \
cb(quint8, Quantized8Asymm, uint8_t) \
cb(qint32, QuantizedS32, int32_t) \
cb(qint8, QuantizedS8, int8_t) \
cb(qint16, QuantizedS16, int16_t)
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type,
class DType {
const void *impl() const;
DType(const void *impl);
CUSTOM_PIMPL_CLS_DECL(DType);
public:
enum class DTypeEnum: uint32_t {
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL)
};
DType(const std::string &dtype);
DType(const char *dtype);
DType(const std::string &dtype, float scale, uint8_t zero_point = 0);
DType(const char *dtype, float scale, uint8_t zero_point = 0);
DType(DTypeEnum dtype);
DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0);
std::string str(void) const;
DTypeEnum enumv() const;
float scale(void) const;
uint8_t zero_point(void) const;
template<typename T>
bool is_compatible(void) const;
static bool is_legal(const std::string &dtype);
static bool is_legal(const DTypeEnum &dtype);
static std::vector<std::string> legal_dtypes(void);
friend class Tensor;
friend bool operator==(const DType &lhs, const DType &rhs);
CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
};
using DTypeEnum = DType::DTypeEnum;
template <DTypeEnum>
struct DTypeTrait;
#define CUSTOM_DEFINE_DTYPE_TRAIT(custom_type, builtin_type, ctype) \
template <> \
struct DTypeTrait<DTypeEnum::custom_type> { \
using type = ctype; \
};
#define CUSTOM_CASE_TO_COMPARE_DTYPE(custom_type, builtin_type, ctype) \
case (DTypeEnum::custom_type): { \
return std::is_same<DecayT, ctype>::value; \
}
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DEFINE_DTYPE_TRAIT)
template<typename T>
bool DType::is_compatible(void) const {
using DecayT = typename std::decay<T>::type;
auto dtype_enum = enumv();
#if !MEGDNN_DISABLE_FLOAT16
if (dtype_enum == DTypeEnum::float16) {
return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::float16>::type);
}
else if (dtype_enum == DTypeEnum::bfloat16) {
return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::bfloat16>::type);
}
#endif
switch (dtype_enum) {
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_CASE_TO_COMPARE_DTYPE)
default:
return false;
}
}
bool operator==(const DType &lhs, const DType &rhs);
bool operator==(const DType &lhs, const std::string &rhs);
bool operator==(const DType &lhs, const char *rhs);
bool operator==(const std::string &lhs, const DType &rhs);
bool operator==(const char *lhs, const DType &rhs);
class Format {
const void *impl() const;
Format(const void *impl);
CUSTOM_PIMPL_CLS_DECL(Format);
public:
Format(const std::string &format);
Format(const char *format);
std::string str(void) const;
bool is_default(void) const;
friend class Tensor;
CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
};
class Tensor {
void *m_tensor;
const void *impl(void) const;
Tensor(const void *impl);
const size_t *shapes_raw(void) const;
const ptrdiff_t *strides_raw(void) const;
public:
Tensor() = delete;
Tensor(const Tensor &rhs);
Tensor &operator=(const Tensor &rhs);
Shape shape(void) const;
DType dtype(void) const;
Format format(void) const;
Device device(void) const;
size_t size(void) const;
std::vector<ptrdiff_t> stride(void) const;
float scale(void) const;
uint8_t zero_point(void) const;
void *data(void);
const void *data(void) const;
template <typename T>
T *data(void);
template <typename T>
const T *data(void) const;
template <typename T, size_t N,
template <typename U> class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
const TensorAccessor<T, N, PtrTraits, index_t> accessor() const;
template <typename T, size_t N,
template <typename U> class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
TensorAccessor<T, N, PtrTraits, index_t> accessor();
CUSTOM_DATA_ADAPTOR_FRIEND_DECL;
};
template <typename T>
T *Tensor::data(void) {
custom_assert(dtype().is_compatible<T>(),
"invalid convert, tensor data type is %s", dtype().str().c_str());
return reinterpret_cast<T*>(data());
}
template <typename T>
const T *Tensor::data(void) const {
return const_cast<Tensor*>(this)->data<T>();
}
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t>
const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const {
return const_cast<Tensor*>(this)->accessor<T, N, PtrTraits, index_t>();
}
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t>
TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() {
custom_assert(N == shape().ndim(),
"cannot get a %lu-d accessor for a tensor with dim %lu", static_cast<unsigned long>(N), static_cast<unsigned long>(shape().ndim()));
custom_assert(N > 0, "cannot get 0-d accessor");
T *ptr = data<T>();
return TensorAccessor<T, N, PtrTraits, index_t>(ptr, shapes_raw(), strides_raw());
}
#undef CUSTOM_DATA_ADAPTOR_FRIEND_DECL
#undef CUSTOM_DEVICE_TYPE_ENUM_DECL
#undef CUSTOM_DTYPE_ENUM_DECL
#undef CUSTOM_DEFINE_DTYPE_TRAIT
#undef CUSTOM_CASE_TO_COMPARE_DTYPE
} // custom
/**
* \file src/custom/include/megbrain/custom/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <vector>
#include <string>
#include <memory>
#include <cassert>
namespace custom {
void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...);
#define custom_expect(expr, msg...) \
if (!(expr)) { \
assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \
); \
}
#define custom_assert(expr, msg...) \
if (!(expr)) { \
assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \
); \
} \
assert((expr))
class UnImpleWarnLog {
public:
UnImpleWarnLog(const std::string &func, const std::string &attr,
const std::string &val);
};
using void_deleter = void(*)(void*);
template<typename Impl>
void impl_deleter(void *ptr) {
delete reinterpret_cast<Impl*>(ptr);
}
#define TypedPtr(type, raw_ptr) reinterpret_cast<type*>(raw_ptr)
#define TypedRef(type, raw_ptr) (*reinterpret_cast<type*>(raw_ptr))
#define CUSTOM_PIMPL_CLS_DECL(Cls) \
std::unique_ptr<void, void_deleter> m_impl; \
public: \
Cls(); \
Cls(const Cls &rhs); \
Cls &operator=(const Cls &rhs)
#define CUSTOM_PIMPL_CLS_DEFINE(Cls) \
Cls::Cls(): m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {} \
\
Cls::Cls(const Cls &rhs): m_impl(nullptr, impl_deleter<Cls##Impl>) { \
custom_assert( \
rhs.m_impl != nullptr, \
"invalid rhs for the copy constructor of %s", #Cls \
); \
m_impl.reset(new Cls##Impl(TypedRef(Cls##Impl, rhs.m_impl.get()))); \
} \
\
Cls &Cls::operator=(const Cls &rhs) { \
custom_assert( \
m_impl != nullptr && rhs.m_impl != nullptr, \
"invalid assignment of %s, lhs or rhs is invalid", #Cls \
); \
if (&rhs == this) \
return *this; \
\
TypedRef(Cls##Impl, m_impl.get()) = TypedRef(Cls##Impl, rhs.m_impl.get()); \
return *this; \
}
/**
* we define this two function explicitly used for std::unordered_map
* to improve the compatibility with different compiler versions
*/
template <typename T>
struct EnumHash {
size_t operator()(const T &rhs) const {
return static_cast<size_t>(rhs);
}
};
template <typename T>
struct EnumCmp {
bool operator()(const T &lhs, const T &rhs) const {
return static_cast<size_t>(lhs) == static_cast<size_t>(rhs);
}
};
} // custom
/**
* \file src/custom/test/manager.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/manager.h"
#include "megbrain/custom/custom.h"
#include "gtest/gtest.h"
#define MANAGER_TEST_LOG 0
namespace custom {
TEST(TestOpManager, TestOpManager) {
CustomOpManager *com = CustomOpManager::inst();
com->insert("Op1", CUSTOM_OP_VERSION);
com->insert("Op2", CUSTOM_OP_VERSION);
std::shared_ptr<CustomOp> ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION);
ASSERT_TRUE(ptr != nullptr);
std::vector<std::string> op_names = com->op_name_list();
std::vector<RunTimeId> op_ids = com->op_id_list();
ASSERT_TRUE(op_names.size() == 3);
ASSERT_TRUE(op_ids.size() == 3);
#if MANAGER_TEST_LOG
for (std::string &name: op_names) {
std::cout << name << std::endl;
}
#endif
for (std::string &name: op_names) {
std::shared_ptr<const CustomOp> op = com->find(name);
ASSERT_TRUE(op != nullptr);
ASSERT_TRUE(op->op_type() == name);
RunTimeId id = com->to_id(name);
ASSERT_TRUE(com->find(id) == op);
}
for (RunTimeId &id: op_ids) {
std::shared_ptr<const CustomOp> op = com->find(id);
ASSERT_TRUE(op != nullptr);
ASSERT_TRUE(op->runtime_id() == id);
std::string name = com->to_name(id);
ASSERT_TRUE(com->find(name) == op);
}
ASSERT_FALSE(com->erase("Op0"));
#if MANAGER_TEST_LOG
for (auto &name: com->op_name_list()) {
std::cout << name << std::endl;
}
#endif
ASSERT_TRUE(com->erase("Op1"));
ASSERT_TRUE(com->erase(com->to_id("Op2")));
ASSERT_TRUE(com->op_id_list().size() == 1);
ASSERT_TRUE(com->op_name_list().size() == 1);
ASSERT_TRUE(com->op_name_list()[0] == "Op3");
ptr.reset();
ASSERT_TRUE(com->erase("Op3"));
}
TEST(TestOpManager, TestOpReg) {
CUSTOM_OP_REG(Op1)
.add_inputs(2)
.add_outputs(3)
.add_input("lhs")
.add_param("param1", 1)
.add_param("param2", 3.45);
CUSTOM_OP_REG(Op2)
.add_input("lhs")
.add_input("rhs")
.add_output("out")
.add_param("param1", "test")
.add_param("param2", true)
.add_param("", "no name");
(void)_Op1;
(void)_Op2;
#if MANAGER_TEST_LOG
for (const auto &name: CustomOpManager::inst()->op_name_list()) {
std::cout << CustomOpManager::inst()->find(name)->str() << std::endl;
}
#endif
}
}
/**
* \file src/custom/test/op.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/op.h"
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "megbrain/custom/data_adaptor.h"
#include "gtest/gtest.h"
#include "megbrain_build_config.h"
#define OP_TEST_LOG 0
using namespace mgb;
namespace custom {
TEST(TestCustomOp, TestCustomOpInfoSetter) {
CustomOp test("TestOp", CUSTOM_OP_VERSION);
test.set_description("Test Op")
.add_input("lhs", "lhs of test op", {"float32", "int32"}, 2)
.add_inputs(2)
.add_input("rhs", "rhs of test op", {"float32", "int32"}, 2)
.add_outputs(1)
.add_output("out", "out of test op", {"float32", "int32"}, 2)
.add_outputs(3);
ASSERT_TRUE(test.op_type() == "TestOp");
ASSERT_TRUE(test.op_desc() == "Test Op");
ASSERT_TRUE(test.input_num() == 4);
ASSERT_TRUE(test.output_num() == 5);
#if OP_TEST_LOG
for (auto input: test.inputs_info()) {
std::cout << input.str() << std::endl;
}
for (auto output: test.outputs_info()) {
std::cout << output.str() << std::endl;
}
#endif
test.add_param("param1", "param1 - float", 1.23f)
.add_param("param2", "param2 - float list", {2.34f, 3.45f})
.add_param("param3", "param3 - string", "test-string")
.add_param("param4", {"test", "string", "list"})
.add_param("param5", 1);
#if OP_TEST_LOG
ParamInfo pinfo = test.param_info();
for (auto kv: pinfo.meta()) {
std::cout << kv.str() << std::endl;
}
#endif
}
void device_infer(const std::vector<Device> &inputs, const Param &params,
std::vector<Device> &outputs) {
(void)inputs;
(void)params;
(void)outputs;
outputs[0] = inputs[1];
outputs[1] = inputs[0];
}
void shape_infer(const std::vector<Shape> &inputs, const Param &params,
std::vector<Shape> &outputs) {
(void)inputs;
(void)params;
(void)outputs;
outputs[0] = inputs[1];
outputs[1] = inputs[0];
}
void dtype_infer(const std::vector<DType> &inputs, const Param &params,
std::vector<DType> &outputs) {
(void)inputs;
(void)params;
(void)outputs;
outputs[0] = inputs[1];
outputs[1] = inputs[0];
}
void format_infer(const std::vector<Format> &inputs, const Param &params,
std::vector<Format> &outputs) {
(void)inputs;
(void)params;
(void)outputs;
outputs[0] = inputs[1];
outputs[1] = inputs[0];
}
void cpu_kernel(const std::vector<Tensor> &inputs, const Param &params,
std::vector<Tensor> &outputs) {
(void)inputs;
(void)params;
(void)outputs;
#if OP_TEST_LOG
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>() << std::endl;
#endif
ASSERT_TRUE(params["device"] == "x86");
}
void gpu_kernel(const std::vector<Tensor> &inputs, const Param &params,
std::vector<Tensor> &outputs) {
(void)inputs;
(void)params;
(void)outputs;
#if OP_TEST_LOG
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>() << std::endl;
#endif
ASSERT_TRUE(params["device"] == "cuda");
}
TEST(TestCustomOp, TestCustomOpFuncSetter) {
#if MGB_CUDA
CustomOp test("TestOp", CUSTOM_OP_VERSION);
test.set_description("Test Op Forward Backward Union")
.add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2)
.add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2)
.add_output("outl", "outl of Test op", {"float32", "int32"}, 2)
.add_output("outr", "outr of Test op", {"float32", "int32"}, 2)
.add_param("smooth", "smooth", 0.f)
.add_param("device", "using for judge device", "x86");
std::vector<Device> idevices = {"x86", "cuda"};
std::vector<Shape> ishapes = {{2, 3}, {3, 4}};
std::vector<DType> idtypes = {"int32", "float32"};
std::vector<Format> iformats = {"default", "default"};
Param param(test.param_info());
std::vector<Device> odevices = test.infer_output_device(idevices, param);
std::vector<Shape> oshapes = test.infer_output_shape (ishapes, param);
std::vector<DType> odtypes = test.infer_output_dtype (idtypes, param);
std::vector<Format> oformats = test.infer_output_format(iformats, param);
ASSERT_TRUE(odevices.size() == 2);
ASSERT_TRUE(oshapes.size() == 2);
ASSERT_TRUE(odtypes.size() == 2);
ASSERT_TRUE(oformats.size() == 2);
ASSERT_TRUE(odevices[0] == "x86");
ASSERT_TRUE(odevices[1] == "x86");
ASSERT_TRUE(oshapes[0] == Shape({2,3}));
ASSERT_TRUE(oshapes[1] == Shape({2,3}));
ASSERT_TRUE(odtypes[0] == "int32");
ASSERT_TRUE(odtypes[1] == "int32");
ASSERT_TRUE(iformats[0].is_default());
ASSERT_TRUE(iformats[1].is_default());
test.set_device_infer(device_infer)
.set_shape_infer(shape_infer)
.set_dtype_infer(dtype_infer)
.set_format_infer(format_infer);
odevices = test.infer_output_device(idevices, param);
oshapes = test.infer_output_shape (ishapes, param);
odtypes = test.infer_output_dtype (idtypes, param);
oformats = test.infer_output_format(iformats, param);
ASSERT_TRUE(odevices.size() == 2);
ASSERT_TRUE(oshapes.size() == 2);
ASSERT_TRUE(odtypes.size() == 2);
ASSERT_TRUE(oformats.size() == 2);
ASSERT_TRUE(odevices[0] == "cuda");
ASSERT_TRUE(odevices[1] == "x86");
ASSERT_TRUE(oshapes[0] == Shape({3,4}));
ASSERT_TRUE(oshapes[1] == Shape({2,3}));
ASSERT_TRUE(odtypes[0] == "float32");
ASSERT_TRUE(odtypes[1] == "int32");
ASSERT_TRUE(iformats[0].is_default());
ASSERT_TRUE(iformats[1].is_default());
test.set_compute(cpu_kernel);
DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
std::vector<Tensor> cinputs = {to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)};
std::vector<Tensor> coutputs ={to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)};
param["device"] = "x86";
test.compute(cinputs, param, coutputs);
test.set_compute("cuda", gpu_kernel);
DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
std::vector<Tensor> ginputs = {to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)};
std::vector<Tensor> goutputs ={to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)};
param["device"] = "cuda";
test.compute(ginputs, param, goutputs);
#endif
}
}
/**
* \file src/custom/test/param.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/param.h"
#include "gtest/gtest.h"
#include <iostream>
#define PARAM_TEST_LOG 0
namespace custom {
#define SchemaDef \
ParamSchema schema_bool("param_bool", true, "bool"); \
ParamSchema schema_flt("param_flt", 2.3f, "float"); \
ParamSchema schema_int("param_int", 4, "int"); \
ParamSchema schema_str("param_str", "test", "string"); \
ParamSchema schema_bool_list("param_bl", {true, false, true}, "bool list"); \
ParamSchema schema_flt_list("param_fl", {1.1f, 2.2f, 3.3f}, "float list"); \
ParamSchema schema_int_list("param_il", {1, 2, 3}, "int list"); \
ParamSchema schema_str_list("param_sl", {"test1", "test2", "test3"}, "string list")
#define InfoDef \
info.meta().emplace_back(schema_bool); \
info.meta().emplace_back(schema_flt); \
info.meta().emplace_back(schema_int); \
info.meta().emplace_back(schema_str); \
info.meta().emplace_back(schema_bool_list); \
info.meta().emplace_back(schema_flt_list); \
info.meta().emplace_back(schema_int_list); \
info.meta().emplace_back(schema_str_list)
TEST(TestParam, TestParamScheme) {
#if PARAM_TEST_LOG
SchemaDef;
ParamSchema new_schema = schema_int;
std::cout << schema_bool.str() << std::endl;
std::cout << schema_flt.str() << std::endl;
std::cout << schema_int.str() << std::endl;
std::cout << schema_str.str() << std::endl;
std::cout << schema_bool_list.str() << "len: "<< schema_bool_list.default_val().size() << std::endl;
std::cout << schema_flt_list.str() << "len: "<< schema_flt_list.default_val().size() << std::endl;
std::cout << schema_int_list.str() << "len: "<< schema_int_list.default_val().size() << std::endl;
std::cout << schema_str_list.str() << "len: "<< schema_str_list.default_val().size() << std::endl;
std::cout << new_schema.str() << std::endl;
#endif
}
TEST(TestParam, TestParamVal) {
ParamVal pv1 = 1.2f, pv2 = true, pv3 = "test", pv4 = {0, 1, 2},
pv5 = {true, false, true};
#if PARAM_TEST_LOG
ParamVal pv6 = {"test1", "test2", "test3"};
std::cout << pv1.str() << std::endl;
std::cout << pv2.str() << std::endl;
std::cout << pv3.str() << std::endl;
std::cout << pv4.str() << std::endl;
std::cout << pv5.str() << std::endl;
std::cout << pv6.str() << std::endl;
#endif
ParamVal pv_manip = pv1;
ASSERT_TRUE(pv_manip.type() == pv1.type());
ASSERT_TRUE(pv_manip == pv1);
pv_manip = 1.3;
ASSERT_TRUE(pv_manip.type() != pv1.type());
ASSERT_TRUE(pv_manip != pv1);
ASSERT_TRUE(pv_manip > pv1);
pv_manip = pv_manip + pv1;
ASSERT_TRUE(pv_manip.type() == ParamDynType::Float64);
ASSERT_TRUE(pv_manip == 1.3 + 1.2f);
pv_manip = 1.3f + 1.2f;
ASSERT_TRUE(pv_manip.type() == pv1.type());
pv_manip = false;
ASSERT_TRUE(pv_manip.type() == pv2.type());
ASSERT_TRUE(pv_manip.type() == ParamDynType::Bool);
ASSERT_TRUE(pv_manip != pv2);
pv_manip = "test";
ASSERT_TRUE(pv_manip.type() == pv3.type());
ASSERT_TRUE(pv_manip.type() == ParamDynType::String);
ASSERT_TRUE(pv_manip == pv3);
pv_manip = "test1";
ASSERT_TRUE(pv_manip > pv3);
pv_manip = pv_manip + pv3;
ASSERT_TRUE(pv_manip == "test1test");
pv_manip = {0, 1, 2};
ASSERT_TRUE(pv_manip.type() == pv4.type());
ASSERT_TRUE(pv_manip.type() == ParamDynType::Int32List);
ASSERT_TRUE(pv_manip == pv4);
pv_manip = {3, 2, 1};
ASSERT_TRUE(pv_manip != pv4);
ASSERT_TRUE(pv_manip > pv4);
pv_manip = {true, false, true};
ASSERT_TRUE(pv_manip.type() == pv5.type());
ASSERT_TRUE(pv_manip.type() == ParamDynType::BoolList);
ASSERT_TRUE(pv_manip == pv5);
pv_manip = {false, true, false};
ASSERT_TRUE(pv_manip != pv5);
}
TEST(TestParam, TestParamInfo) {
ParamInfo info;
info.set_tag("Test");
#if PARAM_TEST_LOG
uint32_t tag = info.tag();
std::cout << tag << std::endl;
#endif
SchemaDef;
InfoDef;
ParamInfo new_info1, new_info2;
new_info1.set_meta(info.meta());
new_info2.meta() = info.meta();
#if PARAM_TEST_LOG
for (auto ele: new_info1.meta()) {
std::cout << ele.str() << std::endl;
}
for (auto ele: new_info2.meta()) {
std::cout << ele.str() << std::endl;
}
#endif
}
TEST(TestParam, TestParam) {
ParamInfo info;
SchemaDef;
InfoDef;
Param param(info);
#if PARAM_TEST_LOG
std::vector<std::string> names = {"param_bool", "param_flt", "param_int", "param_str", "param_bl", "param_fl", "param_il", "param_sl"};
for (auto &name: names) {
std::cout << param[name].str() << std::endl;;
}
#endif
ASSERT_TRUE(param["param_bool"] == true);
ASSERT_TRUE(param["param_flt"] == 2.3f);
ASSERT_TRUE(param["param_int"] == 4);
ASSERT_TRUE(param["param_str"] == "test");
ASSERT_TRUE(param["param_bl"] == ParamVal({true, false, true}));
ASSERT_TRUE(param["param_fl"] == ParamVal({1.1f, 2.2f, 3.3f}));
ASSERT_TRUE(param["param_il"] == ParamVal({1, 2, 3}));
ASSERT_TRUE(param["param_sl"] == ParamVal({"test1", "test2", "test3"}));
param["param_bool"] = false;
param["param_flt"] = 3.4f;
param["param_int"] = 5;
param["param_str"] = "tset";
param["param_bl"] = {false, true, false, true};
param["param_fl"] = {7.6f, 6.5f};
param["param_il"] = {5, 4, 3, 2, 1};
param["param_sl"] = {"1tset", "2tset", "3tset", "4tset", "5tset"};
ASSERT_TRUE(param["param_bool"] != true);
ASSERT_TRUE(param["param_flt"] != 2.3f);
ASSERT_TRUE(param["param_int"] != 4);
ASSERT_TRUE(param["param_str"] != "test");
ASSERT_TRUE(param["param_bl"] != ParamVal({true, false, true}));
ASSERT_TRUE(param["param_fl"] != ParamVal({1.1f, 2.2f, 3.3f}));
ASSERT_TRUE(param["param_il"] != ParamVal({1, 2, 3}));
ASSERT_TRUE(param["param_sl"] != ParamVal({"test1", "test2", "test3"}));
ASSERT_TRUE(param["param_bool"] == false);
ASSERT_TRUE(param["param_flt"] == 3.4f);
ASSERT_TRUE(param["param_int"] == 5);
ASSERT_TRUE(param["param_str"] == "tset");
ASSERT_TRUE(param["param_bl"] == ParamVal({false, true, false, true}));
ASSERT_TRUE(param["param_fl"] == ParamVal({7.6f, 6.5f}));
ASSERT_TRUE(param["param_il"] == ParamVal({5, 4, 3, 2, 1}));
ASSERT_TRUE(param["param_sl"] == ParamVal({"1tset", "2tset", "3tset", "4tset", "5tset"}));
#if PARAM_TEST_LOG
Param copy_param = param;
for (auto &name: names) {
std::cout << copy_param[name].str() << std::endl;
}
#endif
Param loaded_param(info);
std::string bytes = param.to_bytes();
loaded_param.from_bytes(bytes);
#if PARAM_TEST_LOG
for (auto &kv: loaded_param.raw()) {
std::cout << kv.first << ":\n" << kv.second.str() << std::endl;
}
#endif
}
}
/**
* \file src/custom/test/tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/custom/tensor.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "gtest/gtest.h"
#include "megbrain_build_config.h"
#define TENSOR_TEST_LOG 0
using namespace mgb;
namespace custom {
TEST(TestDevice, TestDevice) {
#if MGB_CUDA
ASSERT_TRUE(Device::is_legal("x86"));
ASSERT_TRUE(Device::is_legal(DeviceEnum::cuda));
ASSERT_FALSE(Device::is_legal("cpu"));
Device dev1;
ASSERT_TRUE(dev1.str() == "invalid");
dev1 = "x86";
ASSERT_TRUE("x86" == dev1);
Device dev2 = "cuda";
ASSERT_TRUE(dev2 == "cuda");
ASSERT_FALSE(dev2 == dev1);
Device dev3 = dev2;
ASSERT_TRUE(dev3 == dev2);
ASSERT_FALSE(dev3 == dev1);
Device dev4 = DeviceEnum::cuda;
ASSERT_TRUE(dev4.enumv() == DeviceEnum::cuda);
#if TENSOR_TEST_LOG
std::cout << dev1.str() << "\n" << dev2.str() << "\n"
<< dev3.str() << "\n" << dev4.str() << std::endl;
#endif
CompNode compnode = to_builtin<CompNode, Device>(dev3);
ASSERT_TRUE(compnode.to_string_logical() == "gpux:0");
compnode = CompNode::load("cpu0:0");
Device dev5 = to_custom<CompNode, Device>(compnode);
ASSERT_TRUE(dev5.str() == "x86");
std::vector<Device> devs1 = {"x86", "cuda", "x86"};
megdnn::SmallVector<CompNode> compnodes = to_builtin<CompNode, Device>(devs1);
ASSERT_TRUE(compnodes[0].to_string_logical() == "cpux:0");
ASSERT_TRUE(compnodes[1].to_string_logical() == "gpux:0");
ASSERT_TRUE(compnodes[2].to_string_logical() == "cpux:0");
std::vector<Device> devs2 = to_custom<CompNode, Device>(compnodes);
ASSERT_TRUE(devs2[0] == "x86");
ASSERT_TRUE(devs2[1].str() == "cuda");
ASSERT_TRUE(devs2[2] == "x86");
#endif
}
TEST(TestShape, TestShape) {
Shape shape1, shape2;
ASSERT_TRUE(shape1.ndim() == 0);
shape1 = {16, 32, 8, 8};
shape2 = shape1;
ASSERT_TRUE(shape2.ndim() == 4);
ASSERT_TRUE(shape2[0] == 16);
ASSERT_TRUE(shape2[1] == 32);
ASSERT_TRUE(shape2[2] == 8);
ASSERT_TRUE(shape2[3] == 8);
Shape shape3 = {16, 32, 8, 8};
const Shape shape4 = shape1;
ASSERT_TRUE(shape3 == shape4);
shape3[0] = 32;
ASSERT_FALSE(shape3 == shape4);
ASSERT_TRUE(shape3[0] == 32);
ASSERT_TRUE(shape4[0] == 16);
Shape shape5 = {2, 3, 4};
TensorShape bshape1 = to_builtin<TensorShape, Shape>(shape5);
ASSERT_TRUE(bshape1.ndim == 3);
ASSERT_TRUE(bshape1[0] == 2);
ASSERT_TRUE(bshape1[1] == 3);
ASSERT_TRUE(bshape1[2] == 4);
bshape1 = {4, 2, 3};
Shape shape6 = to_custom<TensorShape, Shape>(bshape1);
ASSERT_TRUE(shape6.ndim() == 3);
ASSERT_TRUE(shape6[0] == 4);
ASSERT_TRUE(shape6[1] == 2);
ASSERT_TRUE(shape6[2] == 3);
Shape shape7;
shape7.ndim(3);
shape7[1] = 4;
ASSERT_TRUE(shape7 == Shape({0, 4, 0}));
std::vector<Shape> shapes1 = {{2, 3, 4}, {6}, {5, 7}};
megdnn::SmallVector<TensorShape> bshapes = to_builtin<TensorShape, Shape>(shapes1);
ASSERT_TRUE(bshapes[0].total_nr_elems() == 2*3*4);
ASSERT_TRUE(bshapes[1].total_nr_elems() == 6);
ASSERT_TRUE(bshapes[2].total_nr_elems() == 35);
std::vector<Shape> shapes2 = to_custom<TensorShape, Shape>(bshapes);
ASSERT_TRUE(shapes2[0] == Shape({2, 3, 4}));
ASSERT_TRUE(shapes2[1] == Shape({6}));
ASSERT_TRUE(shapes2[2] == Shape({5, 7}));
}
TEST(TestDType, TestDType) {
#if !MEGDNN_DISABLE_FLOAT16
ASSERT_TRUE(DType::is_legal("uint8"));
ASSERT_TRUE(DType::is_legal(DTypeEnum::bfloat16));
DType dtype1, dtype2;
ASSERT_TRUE(dtype1.str() == "invalid");
dtype1 = "float32";
ASSERT_TRUE(dtype1.str() == "float32");
dtype2 = dtype1;
DType dtype3 = dtype2;
ASSERT_TRUE(dtype3 == dtype1);
ASSERT_TRUE(dtype3 == "float32");
dtype3 = "int8";
ASSERT_FALSE("float32" == dtype3.str());
ASSERT_FALSE(dtype3 == dtype2);
DType dtype4 = DTypeEnum::int8, dtype5 = dtype3;
ASSERT_TRUE(dtype4 == dtype5);
ASSERT_TRUE(dtype4.is_compatible<int8_t>());
ASSERT_FALSE(dtype4.is_compatible<uint8_t>());
DType dtype6 = "int32";
megdnn::DType bdtype1 = to_builtin<megdnn::DType, DType>(dtype6);
ASSERT_TRUE(bdtype1.name() == std::string("Int32"));
bdtype1 = megdnn::DType::from_enum(megdnn::DTypeEnum::BFloat16);
DType dtype7 = to_custom<megdnn::DType, DType>(bdtype1);
ASSERT_TRUE(dtype7.enumv() == DTypeEnum::bfloat16);
std::vector<DType> dtypes1 = {"int8", "uint8", "float16"};
megdnn::SmallVector<megdnn::DType> bdtypes
= to_builtin<megdnn::DType, DType>(dtypes1);
ASSERT_TRUE(bdtypes[0].name() == std::string("Int8"));
ASSERT_TRUE(bdtypes[1].name() == std::string("Uint8"));
ASSERT_TRUE(bdtypes[2].name() == std::string("Float16"));
std::vector<DType> dtypes2 = to_custom<megdnn::DType, DType>(bdtypes);
ASSERT_TRUE(dtypes2[0] == "int8");
ASSERT_TRUE(dtypes2[1] == "uint8");
ASSERT_TRUE(dtypes2[2] == "float16");
#endif
}
TEST(TestDType, TestDTypeQuantized) {
DType quint8_1("quint8", 3.2, 15);
DType quint8_2("quint8", 3.2, 15);
DType quint8_3("quint8", 3.2, 16);
DType quint8_4("quint8", 3.1, 15);
ASSERT_TRUE(quint8_1 == quint8_2);
ASSERT_FALSE(quint8_1 == quint8_3);
ASSERT_FALSE(quint8_1 == quint8_4);
ASSERT_TRUE(quint8_1.scale() == 3.2f);
ASSERT_TRUE(quint8_1.zero_point() == 15);
DType qint8("qint8", 3.3f);
DType qint16("qint16", 3.4f);
DType qint32("qint32", 3.5f);
ASSERT_TRUE(qint8.scale() == 3.3f);
ASSERT_TRUE(qint16.scale() == 3.4f);
ASSERT_TRUE(qint32.scale() == 3.5f);
ASSERT_TRUE(qint8.enumv() == DTypeEnum::qint8);
ASSERT_TRUE(qint8.str() == "qint8");
}
TEST(TestFormat, TestFormat) {
Format format1, format2("default");
ASSERT_TRUE(format1.is_default());
ASSERT_TRUE(format2.is_default());
Format format3 = format1;
ASSERT_TRUE(format3.is_default());
}
TEST(TestTensor, TestTensor) {
CompNode builtin_device = CompNode::load("cpux:0");
TensorShape builtin_shape = {3, 2, 4};
megdnn::DType builtin_dtype = dtype::Int32{};
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor);
Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor);
Device device = tensor1.device();
Shape shape = tensor1.shape();
DType dtype = tensor1.dtype();
ASSERT_TRUE(device == "x86");
ASSERT_TRUE(shape.ndim() == 3);
ASSERT_TRUE(shape[0] == 3);
ASSERT_TRUE(shape[1] == 2);
ASSERT_TRUE(shape[2] == 4);
ASSERT_TRUE(shape == std::vector<size_t>({3, 2, 4}));
ASSERT_TRUE(dtype == "int32");
int *raw_ptr1 = tensor1.data<int>();
for (size_t i=0; i<tensor1.size(); i++)
raw_ptr1[i] = i;
int *raw_ptr2 = tensor2.data<int>();
for (size_t i=0; i<tensor2.size(); i++)
ASSERT_TRUE(raw_ptr2[i] == static_cast<int>(i));
Tensor tensor3 = tensor2;
int *raw_ptr3 = tensor3.data<int>();
for (size_t i=0; i<tensor3.size(); i++)
ASSERT_TRUE(raw_ptr3[i] == static_cast<int>(i));
ASSERT_TRUE(raw_ptr1 == raw_ptr2);
ASSERT_TRUE(raw_ptr1 == raw_ptr3);
for (size_t i=0; i<tensor3.size(); i++) {
raw_ptr3[i] = -static_cast<int>(i);
}
for (size_t i=0; i<tensor1.size(); i++) {
ASSERT_TRUE(raw_ptr1[i] == -static_cast<int>(i));
}
DeviceTensorND new_dev_tensor = to_builtin<DeviceTensorND, Tensor>(tensor3);
int *builtin_ptr = new_dev_tensor.ptr<int>();
for (size_t i=0; i<new_dev_tensor.shape().total_nr_elems(); i++) {
ASSERT_TRUE(builtin_ptr[i] == -static_cast<int>(i));
}
}
TEST(TestTensor, TestTensorQuantized) {
#if MGB_CUDA
CompNode builtin_device = CompNode::load("gpux:0");
TensorShape builtin_shape = {3, 2, 4};
megdnn::DType builtin_dtype = dtype::Quantized8Asymm{3.2f, uint8_t(15)};
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor);
Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor);
Device device1 = tensor1.device(), device2 = tensor2.device();
Shape shape1 = tensor1.shape(), shape2 = tensor2.shape();
DType dtype1 = tensor1.dtype(), dtype2 = tensor2.dtype();
ASSERT_TRUE(device1 == "cuda");
ASSERT_TRUE(shape1.ndim() == 3);
ASSERT_TRUE(shape1[0] == 3);
ASSERT_TRUE(shape1[1] == 2);
ASSERT_TRUE(shape1[2] == 4);
ASSERT_TRUE(shape1 == std::vector<size_t>({3, 2, 4}));
ASSERT_TRUE(dtype1 == "quint8");
ASSERT_TRUE(dtype1.scale() == 3.2f);
ASSERT_TRUE(dtype1.zero_point() == 15);
ASSERT_TRUE(device1 == device2);
ASSERT_TRUE(shape1 == shape2);
ASSERT_TRUE(dtype1 == dtype2);
#endif
}
TEST(TestTensor, TestTensorAccessorND) {
size_t N = 2, C = 4, H = 6, W = 8;
CompNode builtin_device = CompNode::load("cpux");
TensorShape builtin_shape = {N, C, H, W};
megdnn::DType builtin_dtype = dtype::Int32{};
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
int *builtin_ptr = dev_tensor.ptr<int>();
for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) {
builtin_ptr[i] = i;
}
Tensor tensor = to_custom_tensor(dev_tensor);
auto accessor = tensor.accessor<int32_t, 4>();
for (size_t n=0; n<N; ++n) {
for (size_t c=0; c<C; ++c) {
for (size_t h=0; h<H; ++h) {
for (size_t w=0; w<W; ++w) {
int32_t idx = n*C*H*W + c*H*W + h*W + w;
ASSERT_TRUE(accessor[n][c][h][w] == idx);
}
}
}
}
}
TEST(TestTensor, TestTensorAccessor1D) {
CompNode builtin_device = CompNode::load("cpux");
TensorShape builtin_shape = {32};
megdnn::DType builtin_dtype = dtype::Float32{};
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
float *builtin_ptr = dev_tensor.ptr<float>();
for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) {
builtin_ptr[i] = i;
}
Tensor tensor = to_custom_tensor(dev_tensor);
auto accessor = tensor.accessor<float, 1>();
for (size_t n=0; n<32; ++n) {
ASSERT_TRUE(accessor[n] == n);
}
}
}
......@@ -18,7 +18,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode);
void CustomOpNode::infer_output_comp_node(void) {
SmallVector<CompNode> input_comp_nodes(input_num());
for (int i=0; i<input_num(); ++i) {
for (size_t i=0; i<input_num(); ++i) {
input_comp_nodes[i] = input(i)->comp_node();
}
......@@ -28,7 +28,7 @@ void CustomOpNode::infer_output_comp_node(void) {
)
);
for (int i=0; i<output_num(); ++i) {
for (size_t i=0; i<output_num(); ++i) {
mgb_assert(output_comp_nodes[i] == output_comp_nodes[0],
"only single comp node operator is supported");
output(i)->comp_node(output_comp_nodes[i]);
......@@ -39,7 +39,7 @@ void CustomOpNode::infer_output_comp_node(void) {
void CustomOpNode::infer_output_dtype(void) {
SmallVector<DType> input_dtypes(input_num());
for (int i=0; i<input_num(); ++i) {
for (size_t i=0; i<input_num(); ++i) {
input_dtypes[i] = input(i)->dtype();
}
......@@ -49,14 +49,14 @@ void CustomOpNode::infer_output_dtype(void) {
)
);
for (int i=0; i<output_num(); ++i) {
for (size_t i=0; i<output_num(); ++i) {
output(i)->dtype(output_dtypes[i]);
}
}
void CustomOpNode::infer_output_format(void) {
SmallVector<TensorFormat> input_formats(input_num());
for (int i=0; i<input_num(); ++i) {
for (size_t i=0; i<input_num(); ++i) {
input_formats[i] = input(i)->format();
}
......@@ -66,14 +66,14 @@ void CustomOpNode::infer_output_format(void) {
)
);
for (int i=0; i<output_num(); ++i) {
for (size_t i=0; i<output_num(); ++i) {
output(i)->format(output_formats[i]);
}
}
void CustomOpNode::infer_output_shape(void) {
SmallVector<TensorShape> input_shapes(input_num());
for (int i=0; i<input_num(); ++i) {
for (size_t i=0; i<input_num(); ++i) {
input_shapes[i] = input(i)->shape();
}
......@@ -83,7 +83,7 @@ void CustomOpNode::infer_output_shape(void) {
)
);
for (int i=0; i<output_num(); ++i) {
for (size_t i=0; i<output_num(); ++i) {
output(i)->shape(output_shapes[i]);
}
}
......@@ -235,10 +235,10 @@ CustomOpNode::CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op,
const OperatorNodeConfig &config):
OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) {
mgb_assert(input_num() == inputs.size(), "wrong input tensors list length");
for (int i=0; i < input_num(); ++i)
for (size_t i=0; i < input_num(); ++i)
add_input({inputs[i]});
for (int i=0; i<output_num(); ++i)
for (size_t i=0; i<output_num(); ++i)
add_output(output_info(i).name());
if (!std::is_empty<custom::Param>::value) {
......@@ -306,11 +306,11 @@ std::string CustomOpNode::op_desc(void) const {
return m_op->op_desc();
}
int CustomOpNode::input_num(void) const {
size_t CustomOpNode::input_num(void) const {
return m_op->input_num();
}
int CustomOpNode::output_num(void) const {
size_t CustomOpNode::output_num(void) const {
return m_op->output_num();
}
......
......@@ -93,8 +93,8 @@ public:
custom::Param param(void) const;
std::string op_type(void) const;
std::string op_desc(void) const;
int input_num(void) const;
int output_num(void) const;
size_t input_num(void) const;
size_t output_num(void) const;
custom::ArgInfo input_info(size_t idx) const;
custom::ArgInfo output_info(size_t idx) const;
};
......
include_directories("./src/include")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp)
file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp ../src/custom/test/*.cpp)
if(MGE_WITH_JIT)
file(GLOB_RECURSE SOURCES_ ../src/jit/test/*.cpp)
list(APPEND SOURCES ${SOURCES_})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册