提交 601ba23b 编写于 作者: Y Yuan Shuai 提交者: GitHub

[LITE][OPENCL] add 3 data layouts for opencl image2d (#2561)

* add 3 layout for opencl image. test=develop
上级 e6582cb6
...@@ -179,11 +179,13 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) ...@@ -179,11 +179,13 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared) add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared)
add_dependencies(tiny_publish_cxx_lib bundle_light_api) add_dependencies(tiny_publish_cxx_lib bundle_light_api)
add_dependencies(publish_inference tiny_publish_cxx_lib) add_dependencies(publish_inference tiny_publish_cxx_lib)
if(NOT "${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD add_custom_command(TARGET tiny_publish_cxx_lib POST_BUILD
COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so) COMMAND ${CMAKE_STRIP} "-s" ${INFER_LITE_PUBLISH_ROOT}/cxx/lib/libpaddle_light_api_shared.so)
endif() endif()
endif() endif()
endif() endif()
endif()
if (LITE_WITH_JAVA) if (LITE_WITH_JAVA)
......
...@@ -77,7 +77,8 @@ const std::string& PrecisionToStr(PrecisionType precision) { ...@@ -77,7 +77,8 @@ const std::string& PrecisionToStr(PrecisionType precision) {
} }
const std::string& DataLayoutToStr(DataLayoutType layout) { const std::string& DataLayoutToStr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"unk", "NCHW", "any", "NHWC"}; static const std::string datalayout2string[] = {
"unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
auto x = static_cast<int>(layout); auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM))); CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x]; return datalayout2string[x];
...@@ -115,8 +116,13 @@ const std::string& PrecisionRepr(PrecisionType precision) { ...@@ -115,8 +116,13 @@ const std::string& PrecisionRepr(PrecisionType precision) {
} }
const std::string& DataLayoutRepr(DataLayoutType layout) { const std::string& DataLayoutRepr(DataLayoutType layout) {
static const std::string datalayout2string[] = { static const std::string datalayout2string[] = {"kUnk",
"kUnk", "kNCHW", "kAny", "kNHWC"}; "kNCHW",
"kAny",
"kNHWC",
"kImageDefault",
"kImageFolder",
"kImageNW"};
auto x = static_cast<int>(layout); auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM))); CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x]; return datalayout2string[x];
...@@ -146,8 +152,12 @@ std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) { ...@@ -146,8 +152,12 @@ std::set<PrecisionType> ExpandValidPrecisions(PrecisionType precision) {
} }
std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) { std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
static const std::set<DataLayoutType> valid_set( static const std::set<DataLayoutType> valid_set({DATALAYOUT(kNCHW),
{DATALAYOUT(kNCHW), DATALAYOUT(kAny), DATALAYOUT(kNHWC)}); DATALAYOUT(kAny),
DATALAYOUT(kNHWC),
DATALAYOUT(kImageDefault),
DATALAYOUT(kImageFolder),
DATALAYOUT(kImageNW)});
if (layout == DATALAYOUT(kAny)) { if (layout == DATALAYOUT(kAny)) {
return valid_set; return valid_set;
} }
......
...@@ -71,8 +71,11 @@ enum class DataLayoutType : int { ...@@ -71,8 +71,11 @@ enum class DataLayoutType : int {
kUnk = 0, kUnk = 0,
kNCHW = 1, kNCHW = 1,
kNHWC = 3, kNHWC = 3,
kImageDefault = 4, // for opencl image2d
kImageFolder = 5, // for opencl image2d
kImageNW = 6, // for opencl image2d
kAny = 2, // any data layout kAny = 2, // any data layout
NUM = 4, // number of fields. NUM = 7, // number of fields.
}; };
typedef enum { typedef enum {
......
...@@ -165,6 +165,9 @@ void BindLitePlace(py::module *m) { ...@@ -165,6 +165,9 @@ void BindLitePlace(py::module *m) {
py::enum_<DataLayoutType>(*m, "DataLayoutType") py::enum_<DataLayoutType>(*m, "DataLayoutType")
.value("NCHW", DataLayoutType::kNCHW) .value("NCHW", DataLayoutType::kNCHW)
.value("NHWC", DataLayoutType::kNHWC) .value("NHWC", DataLayoutType::kNHWC)
.value("ImageDefault", DataLayoutType::kImageDefault)
.value("ImageFolder", DataLayoutType::kImageFolder)
.value("ImageNW", DataLayoutType::kImageNW)
.value("Any", DataLayoutType::kAny); .value("Any", DataLayoutType::kAny);
// Place // Place
......
...@@ -40,6 +40,18 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( ...@@ -40,6 +40,18 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
return Create<TARGET(target__), \ return Create<TARGET(target__), \
PRECISION(precision__), \ PRECISION(precision__), \
DATALAYOUT(kNHWC)>(op_type); \ DATALAYOUT(kNHWC)>(op_type); \
case DATALAYOUT(kImageDefault): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageDefault)>(op_type); \
case DATALAYOUT(kImageFolder): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageFolder)>(op_type); \
case DATALAYOUT(kImageNW): \
return Create<TARGET(target__), \
PRECISION(precision__), \
DATALAYOUT(kImageNW)>(op_type); \
default: \ default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \ LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
} }
...@@ -147,6 +159,17 @@ KernelRegistry::KernelRegistry() ...@@ -147,6 +159,17 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kOpenCL, kFloat, kAny); INIT_FOR(kOpenCL, kFloat, kAny);
INIT_FOR(kOpenCL, kInt8, kNCHW); INIT_FOR(kOpenCL, kInt8, kNCHW);
INIT_FOR(kOpenCL, kAny, kAny); INIT_FOR(kOpenCL, kAny, kAny);
INIT_FOR(kOpenCL, kFP16, kNCHW);
INIT_FOR(kOpenCL, kFP16, kNHWC);
INIT_FOR(kOpenCL, kFP16, kImageDefault);
INIT_FOR(kOpenCL, kFP16, kImageFolder);
INIT_FOR(kOpenCL, kFP16, kImageNW);
INIT_FOR(kOpenCL, kFloat, kImageDefault);
INIT_FOR(kOpenCL, kFloat, kImageFolder);
INIT_FOR(kOpenCL, kFloat, kImageNW);
INIT_FOR(kOpenCL, kAny, kImageDefault);
INIT_FOR(kOpenCL, kAny, kImageFolder);
INIT_FOR(kOpenCL, kAny, kImageNW);
INIT_FOR(kNPU, kFloat, kNCHW); INIT_FOR(kNPU, kFloat, kNCHW);
INIT_FOR(kNPU, kInt8, kNCHW); INIT_FOR(kNPU, kInt8, kNCHW);
......
...@@ -176,6 +176,39 @@ class KernelRegistry final { ...@@ -176,6 +176,39 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kOpenCL), KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny), PRECISION(kAny),
DATALAYOUT(kAny)> *, // DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageDefault)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageFolder)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kAny),
DATALAYOUT(kImageNW)> *, //
KernelRegistryForTarget<TARGET(kNPU), KernelRegistryForTarget<TARGET(kNPU),
PRECISION(kAny), PRECISION(kAny),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册