提交 601ba23b 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] add 3 data layouts for opencl image2d (#2561)

* add 3 layout for opencl image. test=develop
上级 e6582cb6
......@@ -179,8 +179,10 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared)
add_dependencies(tiny_publish_cxx_lib bundle_light_api)
add_dependencies(publish_inference tiny_publish_cxx_lib)
add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD
COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so)
if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD
COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so)
endif()
endif()
endif()
endif()
......
......@@ -77,7 +77,8 @@ const std::string& PrecisionToStr(PrecisionType precision) {
}
const std::string& DataLayoutToStr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"unk", "NCHW", "any", "NHWC"};
static const std::string datalayout2string[] = {
"unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
......@@ -115,8 +116,13 @@ const std::string& PrecisionRepr(PrecisionType precision) {
}
const std::string& DataLayoutRepr(DataLayoutType layout) {
static const std::string datalayout2string[] = {
"kUnk", "kNCHW", "kAny", "kNHWC"};
static const std::string datalayout2string[] = {"kUnk",
"kNCHW",
"kAny",
"kNHWC",
"kImageDefault",
"kImageFolder",
"kImageNW"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
......@@ -146,8 +152,12 @@ std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) {
}
std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
static const std::set<DataLayoutType> valid_set(
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)});
static const std::set<DataLayoutType> valid_set({DATALAYOUT(kNCHW),
DATALAYOUT(kAny),
DATALAYOUT(kNHWC),
DATALAYOUT(kImageDefault),
DATALAYOUT(kImageFolder),
DATALAYOUT(kImageNW)});
if (layout == DATALAYOUT(kAny)) {
return valid_set;
}
......
......@@ -71,8 +71,11 @@ enum class DataLayoutType : int {
kUnk = 0,
kNCHW = 1,
kNHWC = 3,
kAny = 2, // any data layout
NUM = 4, // number of fields.
kImageDefault = 4, // for opencl image2d
kImageFolder = 5, // for opencl image2d
kImageNW = 6, // for opencl image2d
kAny = 2, // any data layout
NUM = 7, // number of fields.
};
typedef enum {
......
......@@ -165,6 +165,9 @@ void BindLitePlace(py::module *m) {
py::enum_<DataLayoutType>(*m, "DataLayoutType")
.value("NCHW", DataLayoutType::kNCHW)
.value("NHWC", DataLayoutType::kNHWC)
.value("ImageDefault", DataLayoutType::kImageDefault)
.value("ImageFolder", DataLayoutType::kImageFolder)
.value("ImageNW", DataLayoutType::kImageNW)
.value("Any", DataLayoutType::kAny);
// Place
......
......@@ -40,6 +40,18 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kNHWC)>(op_type); \
case DATALAYOUT(kImageDefault): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageDefault)>(op_type); \
case DATALAYOUT(kImageFolder): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageFolder)>(op_type); \
case DATALAYOUT(kImageNW): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageNW)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
......@@ -147,6 +159,17 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kOpenCL, kFloat, kAny);
INIT_FOR(kOpenCL, kInt8, kNCHW);
INIT_FOR(kOpenCL, kAny, kAny);
INIT_FOR(kOpenCL, kFP16, kNCHW);
INIT_FOR(kOpenCL, kFP16, kNHWC);
INIT_FOR(kOpenCL, kFP16, kImageDefault);
INIT_FOR(kOpenCL, kFP16, kImageFolder);
INIT_FOR(kOpenCL, kFP16, kImageNW);
INIT_FOR(kOpenCL, kFloat, kImageDefault);
INIT_FOR(kOpenCL, kFloat, kImageFolder);
INIT_FOR(kOpenCL, kFloat, kImageNW);
INIT_FOR(kOpenCL, kAny, kImageDefault);
INIT_FOR(kOpenCL, kAny, kImageFolder);
INIT_FOR(kOpenCL, kAny, kImageNW);
INIT_FOR(kNPU, kFloat, kNCHW);
INIT_FOR(kNPU, kInt8, kNCHW);
......
......@@ -176,6 +176,39 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kNPU),
PRECISION(kAny),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册