Commit 0ad5eeae authored by: Megvii Engine Team

feat(mgb/gopt): global layout transform support opencl

GitOrigin-RevId: 132605c7d946d403dc2164a71cd3769b29ccfb31
Parent 26146e5a
......@@ -246,6 +246,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}};
case Format::NCHW44_DOT:
return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}};
case Format::NHWCD4:
return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
default:
megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format))
.c_str());
......
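The new NHWCD4 case above maps an NCHW tensor onto the named shape {N, H, C//4, W, C%4}: channels are split into groups of 4 and the group-of-4 axis is stored innermost. A minimal standalone sketch of that shape mapping (the helper name and use of std::array are illustrative, not part of the commit):

#include <array>
#include <cstddef>

// NCHW logical dims -> NHWCD4 storage dims {N, H, ceil(C/4), W, 4},
// matching the named shape registered above.
std::array<std::size_t, 5> nchw_to_nhwcd4_shape(
        std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
    return {n, h, (c + 3) / 4, w, 4};
}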
......@@ -229,6 +229,30 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
}
};
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWCD4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWCD4;
config.config_id = OprFormatConfigID::NHWCD4;
bool available =
opr->input(0)->dtype().enumv() == DTypeEnum::Float32 ||
DNN_FLOAT16_SELECT(
(opr->input(0)->dtype().enumv() == DTypeEnum::Float16), true) ||
opr->input(0)->dtype().enumv() == DTypeEnum::Int8 ||
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE};
config.output_dtypes = {opr->output(0)->dtype().enumv()};
config.input_tensor_formats = {TensorFormats::NHCWc4};
config.output_tensor_formats = {TensorFormats::NHCWc4};
if (available)
return config;
return None;
}
};
template <typename Opr, OprFormatConfigID config_id>
struct ConvTensorFormatsDispatcherImpl;
......@@ -814,6 +838,55 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID
}
};
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWCD4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWCD4;
config.config_id = OprFormatConfigID::NHWCD4;
for (size_t i = 0; i < opr->input().size(); ++i) {
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
config.input_tensor_formats = {
TensorFormats::NHCWc4, TensorFormats::KRSCk4c4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
} else {
config.input_tensor_formats = {
TensorFormats::NHCWc4, TensorFormats::KRSCk4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
}
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
if (is_channel_wise_conv<Opr>(opr)) {
config.input_tensor_formats = {
TensorFormats::NHCWc4, TensorFormats::C1RSc4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
} else {
if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
config.input_tensor_formats = {
TensorFormats::NHCWc4, TensorFormats::GKRSCk4c4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
} else {
config.input_tensor_formats = {
TensorFormats::NHCWc4, TensorFormats::GKRSCk4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
}
}
}
config.output_tensor_formats = {TensorFormats::NHCWc4};
return config;
}
};
template <>
struct ConvTensorFormatsDispatcherImpl<
opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> {
......@@ -919,6 +992,57 @@ struct ConvTensorFormatsDispatcherImpl<
}
};
template <>
struct ConvTensorFormatsDispatcherImpl<
opr::ConvolutionBackwardData, OprFormatConfigID::NHWCD4> {
using Opr = opr::ConvolutionBackwardData;
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWCD4;
config.config_id = OprFormatConfigID::NHWCD4;
for (size_t i = 0; i < opr->input().size(); ++i) {
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
config.input_tensor_formats = {
TensorFormats::KRSCk4c4, TensorFormats::NHCWc4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
} else {
config.input_tensor_formats = {
TensorFormats::KRSCk4, TensorFormats::NHCWc4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
}
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
if (is_channel_wise_conv<Opr>(opr)) {
config.input_tensor_formats = {
TensorFormats::C1RSc4, TensorFormats::NHCWc4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
} else {
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) {
config.input_tensor_formats = {
TensorFormats::GKRSCk4c4, TensorFormats::NHCWc4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
} else {
config.input_tensor_formats = {
TensorFormats::GKRSCk4, TensorFormats::NHCWc4,
TensorFormats::NHCWc4, TensorFormats::NHCWc4};
}
}
}
config.output_tensor_formats = {TensorFormats::NHCWc4};
return config;
}
};
struct StaticData {
struct KeyHash {
size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const {
......@@ -969,6 +1093,7 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWCD4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC);
......@@ -979,15 +1104,18 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWCD4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWCD4);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWCD4);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC);
......@@ -997,10 +1125,12 @@ StaticData::StaticData() {
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWCD4);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NHWCD4);
#undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
......
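The NHWCD4 conv dispatchers registered above pick the weight layout from the conv's sparsity and dtype: dense quantized 8-bit weights get KRSCk4c4, dense float weights get KRSCk4, channel-wise group convs get C1RSc4, and other group convs get GKRSCk4c4 / GKRSCk4. A hedged, standalone restatement of that selection (string return values stand in for the real TensorFormats enum):

#include <string>

// Mirrors the branch structure of ConvTensorFormatsDispatcherImpl<..., NHWCD4>
// above; purely illustrative, not the code that ships in this commit.
std::string nhwcd4_weight_format(
        bool is_group, bool is_channel_wise, bool is_quantized_8bit) {
    if (!is_group)                        // Sparse::DENSE
        return is_quantized_8bit ? "KRSCk4c4" : "KRSCk4";
    if (is_channel_wise)                  // channel-wise group conv
        return "C1RSc4";
    return is_quantized_8bit ? "GKRSCk4c4" : "GKRSCk4";  // generic group conv
}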
......@@ -22,6 +22,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/tensor_format.h"
using namespace mgb;
using namespace cg;
......@@ -281,9 +282,6 @@ float ProfilerImpl::profile_operator(
std::min(config.input_tensor_formats.size(), opr->input().size());
for (; i < nr_input_tensor; ++i) {
auto&& var = opr->input(i);
auto&& cn = var->comp_node();
auto&& dtype = var->dtype();
auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
TensorShape aligned_shape;
if (config.input_tensor_types[i] == TensorType::WEIGHT) {
mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT);
......@@ -299,9 +297,12 @@ float ProfilerImpl::profile_operator(
var, base_config.input_tensor_formats[i],
config.input_tensor_formats[i], extra_attribute);
}
dval->resize(aligned_shape);
std::shared_ptr<DeviceTensorND> dval = create_device_tensor_helper(
config, i, var, aligned_shape, extra_attribute);
if (config.input_tensor_types[i] == TensorType::WEIGHT) {
new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node();
new_inps[i] =
opr::SharedDeviceTensorWithFormat::make_const(*graph, dval).node();
} else {
new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
}
......@@ -368,10 +369,27 @@ float ProfilerImpl::profile_var_node(
const VarNode* var, TensorFormats base_format, const ReformatKey& key) const {
auto&& cn = var->comp_node();
auto&& dtype = var->dtype();
auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
var, base_format, key.input_format, key.attribute);
dval->resize(aligned_tensor_shape);
std::shared_ptr<DeviceTensorND> dval;
if (key.input_format == TensorFormats::NHCWc4 &&
key.attribute & ReformatAttribute::IMAGE2D) {
size_t align_axis = 2;
auto named_tensor = tensor_formats_to_named_tensor_shape(key.input_format);
for (size_t n = 0; n < named_tensor.ndim; n++) {
if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
align_axis = n;
break;
}
}
dval = std::make_shared<DeviceTensorND>(
cn, aligned_tensor_shape, dtype,
megdnn::Image2DPack4TensorFormat::make(
align_axis, opr::intl::get_megdnn_handle(cn)));
} else
dval = std::make_shared<DeviceTensorND>(cn, aligned_tensor_shape, dtype);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
graph->options().var_sanity_check_first_run = false;
......@@ -516,6 +534,8 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
return OprFormatConfigID::NHWC;
case TensorFormats::CHWNc4:
return OprFormatConfigID::CHWN4;
case TensorFormats::NHCWc4:
return OprFormatConfigID::NHWCD4;
default:
mgb_throw(
MegBrainError, "tensor format(%u) is not supported",
......@@ -523,6 +543,39 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
}
}
std::shared_ptr<DeviceTensorND> ProfilerImpl::create_device_tensor_helper(
const OprTensorFormatsConfiguration& config, const size_t inp_idx,
const VarNode* var, const TensorShape aligned_shape,
ReformatAttribute extra_attribute) const {
auto&& cn = var->comp_node();
auto&& dtype = var->dtype();
std::shared_ptr<DeviceTensorND> dval;
if (config.config_id == OprFormatConfigID::NHWCD4 &&
extra_attribute & ReformatAttribute::IMAGE2D) {
size_t align_axis = 2;
auto named_tensor = tensor_formats_to_named_tensor_shape(
config.input_tensor_formats[inp_idx]);
for (size_t n = 0; n < named_tensor.ndim; n++) {
if (named_tensor[n].name() == megdnn::Dimension::Name::C) {
align_axis = n;
break;
}
}
// channel wise weight
bool is_channel_wise =
config.input_tensor_formats[inp_idx] == TensorFormats::C1RSc4;
if (is_channel_wise)
align_axis = 1;
dval = std::make_shared<DeviceTensorND>(
cn, aligned_shape, dtype,
megdnn::Image2DPack4TensorFormat::make(
align_axis, opr::intl::get_megdnn_handle(cn)));
} else {
dval = std::make_shared<DeviceTensorND>(cn, aligned_shape, dtype);
}
return dval;
}
/* ================== ProfilerBase =================*/
std::string ProfilerBase::OperatorNodeRecord::to_string() const {
auto str = ssprintf(
......
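Both new profiler code paths above derive the Image2D alignment axis the same way: walk the named shape of the target format and take the index of the first C-group dimension (axis 2 for NHCWc4, i.e. {N, H, C//4, W, C%4}), except for channel-wise weights in C1RSc4, which are pinned to axis 1. A standalone sketch of that rule, with dimension names passed as plain strings instead of megdnn::Dimension::Name:

#include <cstddef>
#include <string>
#include <vector>

// Illustrative only: returns the axis handed to Image2DPack4TensorFormat::make.
std::size_t image2d_align_axis(
        const std::vector<std::string>& named_dims, bool is_channel_wise_weight) {
    if (is_channel_wise_weight)  // C1RSc4 weight layout
        return 1;
    for (std::size_t i = 0; i < named_dims.size(); ++i)
        if (!named_dims[i].empty() && named_dims[i][0] == 'C')
            return i;            // first channel-group axis, e.g. "C//4"
    return 2;                    // default used by the profiler for NHCWc4
}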
......@@ -249,7 +249,7 @@ ReformatManager::ReformatManager() {
m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
}
{
auto i = TensorFormats::KCRS, o = TensorFormats::GKRSCk4;
auto i = TensorFormats::GKCRS, o = TensorFormats::GKRSCk4;
auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
......@@ -259,7 +259,7 @@ ReformatManager::ReformatManager() {
m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
}
{
auto i = TensorFormats::KCRS, o = TensorFormats::C1RSc4;
auto i = TensorFormats::C11RS, o = TensorFormats::C1RSc4;
auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
......@@ -268,6 +268,21 @@ ReformatManager::ReformatManager() {
};
m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
}
{
auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
auto&& impl1 = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4)
.node();
};
m_cache.emplace(ReformatKey{i, o}, impl1);
auto&& impl2 = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4_NCHW)
.node();
};
m_cache.emplace(ReformatKey{o, i}, impl2);
}
{
auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4;
auto&& impl = [](const VarNodeArray& vars) {
......@@ -281,7 +296,7 @@ ReformatManager::ReformatManager() {
auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW;
auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I)
vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW)
.node();
};
m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl);
......@@ -346,6 +361,15 @@ ReformatManager::ReformatImpl ReformatManager::get(const ReformatKey& key) const
return rst;
}
}
if (key.attribute == Attribute::IMAGE2D) {
auto key_ = key;
key_.input_dtype = DTypeEnum::Float32;
key_.output_dtype = DTypeEnum::Float32;
auto find = m_cache.find(key_);
if (find != m_cache.end()) {
return find->second;
}
}
mgb_assert(
!(key.attribute & Attribute::IMAGE2D) &&
!(key.attribute & Attribute::IC_SMALL));
......@@ -682,7 +706,8 @@ TensorShape ReformatManager::make_aligned_weight_shape(
auto target_shape = tensor_formats_to_named_tensor_shape(target_formats);
for (size_t i = 0; i < target_shape.ndim; ++i) {
auto name = target_shape[i].name();
if ((name == Dimension::Name::K || name == Dimension::Name::N) &&
if ((name == Dimension::Name::K || name == Dimension::Name::N ||
(extra_formats == TensorFormats::NHCWc4 && name == Dimension::Name::C)) &&
target_shape[i].extent() == UNDETERMINED_EXTENT) {
size_t out_channels = tshp[i] * target_shape[i].stride();
tshp[i] = divup(out_channels, out_channel_alignment) *
......
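The new lookup in ReformatManager::get retries IMAGE2D keys with both dtypes forced to Float32, since the NHWCD4/Image2D reformat entries registered earlier in the constructor are cached without dtype specialization. A minimal standalone model of that two-step lookup, where Key and Impl are stand-ins for ReformatKey and ReformatImpl:

#include <functional>
#include <map>
#include <optional>
#include <tuple>

enum class DType { Float32, QuantizedS8 };
enum class Attr { DEFAULT, IMAGE2D };

struct Key {
    int input_format = 0, output_format = 0;  // stand-ins for TensorFormats
    Attr attribute = Attr::DEFAULT;
    DType input_dtype = DType::Float32, output_dtype = DType::Float32;
    bool operator<(const Key& rhs) const {
        return std::tie(input_format, output_format, attribute, input_dtype,
                        output_dtype) <
               std::tie(rhs.input_format, rhs.output_format, rhs.attribute,
                        rhs.input_dtype, rhs.output_dtype);
    }
};
using Impl = std::function<int()>;  // stand-in for ReformatImpl

std::optional<Impl> get(const std::map<Key, Impl>& cache, Key key) {
    if (auto it = cache.find(key); it != cache.end())
        return it->second;          // exact hit: dtype-specialized entry
    if (key.attribute == Attr::IMAGE2D) {
        key.input_dtype = key.output_dtype = DType::Float32;
        if (auto it = cache.find(key); it != cache.end())
            return it->second;      // fallback: dtype-agnostic IMAGE2D entry
    }
    return std::nullopt;            // caller falls through to the asserts
}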
......@@ -32,6 +32,7 @@ static inline const char* opr_format_to_string(
cb(NCHW44);
cb(NCHW88);
cb(NCHW44_DOT);
cb(NHWCD4);
default:
mgb_assert(
false, "Invalid opr format(got:%u)",
......@@ -63,6 +64,7 @@ static inline const char* config_id_to_string(
cb(NCHW88_HYBRID);
cb(NCHW44_DOT);
cb(NCHW44_DOT_HYBRID);
cb(NHWCD4);
default:
mgb_assert(
false, "Invalid config id(got:%u)",
......@@ -95,6 +97,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
return TensorFormats::NCHWc8;
case OprFormat::NCHW44_DOT:
return TensorFormats::NCHWc4;
case OprFormat::NHWCD4:
return TensorFormats::NHCWc4;
default:
mgb_throw(
AssertionError, "format(%s) is not supported",
......
......@@ -202,6 +202,11 @@ protected:
const ReformatKey& key) const;
OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const;
std::shared_ptr<DeviceTensorND> create_device_tensor_helper(
const OprTensorFormatsConfiguration& config, const size_t inp_idx,
const VarNode* var, const TensorShape aligned_shape,
ReformatAttribute extra_attribute) const;
OprFootprint m_opr_footprint;
float m_opr_threshold; /// a threshold, when the computation of the newly
/// created operator that is built in some opr
......
This diff is suppressed by .gitattributes.
......@@ -336,6 +336,10 @@ cg::OperatorNodeBase::NodeProp* VolatileSharedDeviceTensor::do_make_node_prop()
return ret;
}
void VolatileSharedDeviceTensor::init_output_format() {
output(0)->format(get_dev_tensor().format());
}
SymbolVar VolatileSharedDeviceTensor::make(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config) {
......
......@@ -337,6 +337,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
public:
using Super::Super;
void init_output_format() override;
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
const OperatorNodeConfig& config = {});
......
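The init_output_format() override above exists so that a VolatileSharedDeviceTensor whose backing DeviceTensorND carries a non-default tensor format (here, Image2DPack4 for NHWCD4 profiling) also reports that format on its output var. A usage fragment built from calls that appear in this diff; it assumes the MegBrain headers are included, and the device locator, shape, and dtype below are illustrative assumptions:

auto graph = ComputingGraph::make();
auto cn = CompNode::load("xpux");              // assumption: any available device
TensorShape aligned_shape{1, 8, 4, 8, 4};      // N, H, C//4, W, 4 (illustrative)
auto dval = std::make_shared<DeviceTensorND>(
        cn, aligned_shape, dtype::Float32(),
        megdnn::Image2DPack4TensorFormat::make(
                2, opr::intl::get_megdnn_handle(cn)));  // align axis 2 for NHCWc4
auto inp = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
// with init_output_format() overridden, the output var's format is expected
// to match dval->format() instead of the default tensor format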