add_layout.md 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
# 如何增加Layout

Paddle-Lite中的Place包含Target、Layout、Precision信息,用于注册和选择模型中的具体Kernel。下面以在Place中增加`ImageDefault`、`ImageFolder`、`ImageNW`三种Layout为例,讲解如何增加新的Layout。

以`NHWC`为关键词在`lite/core/`和`lite/api/`目录下检索代码,可以发现需要在以下文件中加入新Layout的相关内容:

1. lite/api/paddle_place.h
2. lite/api/paddle_place.cc
3. lite/api/python/pybind/pybind.cc
4. lite/core/op_registry.h
5. lite/core/op_registry.cc

## 1. lite/api/paddle_place.h

在`enum class DataLayoutType`中加入对应的Layout。注意已有Layout的值不能改变;新增Layout的值在当前最大值的基础上依次递增,同时需要把`NUM`更新为枚举项的总数:

```cpp
// Data layouts used when registering and selecting kernels.
// Every enumerator carries an explicit value. The string tables in
// paddle_place.cc are indexed by these values, so a published value must never
// be renumbered. The enumerators are listed here in numeric order.
enum class DataLayoutType : int {
  kUnk = 0,
  kNCHW = 1,
  kAny = 2,           // any data layout
  kNHWC = 3,
  kImageDefault = 4,  // for opencl image2d
  kImageFolder = 5,   // for opencl image2d
  kImageNW = 6,       // for opencl image2d
  NUM = 7,            // number of fields.
};
```

## 2. lite/api/paddle_place.cc

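本文件有3处修改。在`DataLayoutToStr`和`DataLayoutRepr`函数中加入对应Layout的字符串名时,字符串数组必须按`lite/api/paddle_place.h`中枚举的**数值**排列,而不是按声明顺序排列。这是因为两个函数直接以枚举值作为数组下标。例如`kAny = 2`虽然声明在最后几项,其字符串`"any"`仍位于数组第3个位置。此外,还需要把新Layout加入`ExpandValidLayouts`的`valid_set`中: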

```cpp
// 该文件第1处
// Returns the short name of `layout` (for example "NCHW").
// The table is indexed directly by the enum's integer value. Entries must
// therefore follow the numeric order of DataLayoutType, not its declaration
// order (kAny == 2 maps to "any").
const std::string& DataLayoutToStr(DataLayoutType layout) {
  static const std::string datalayout2string[] = {
      "unk", "NCHW", "any", "NHWC", "ImageDefault", "ImageFolder", "ImageNW"};
  // When a layout is added to DataLayoutType, a missing (or extra) name here
  // would otherwise cause an out-of-bounds read at runtime. Fail the build instead.
  static_assert(static_cast<int>(sizeof(datalayout2string) /
                                 sizeof(datalayout2string[0])) ==
                    static_cast<int>(DATALAYOUT(NUM)),
                "datalayout2string needs exactly one entry per DataLayoutType");
  auto x = static_cast<int>(layout);
  CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
  return datalayout2string[x];
}

// 该文件第2处
// Returns the enumerator spelling of `layout` (for example "kNCHW").
// The table is indexed directly by the enum's integer value. Entries must
// therefore follow the numeric order of DataLayoutType, not its declaration
// order (kAny == 2 maps to "kAny").
const std::string& DataLayoutRepr(DataLayoutType layout) {
  static const std::string datalayout2string[] = {"kUnk",
                                                  "kNCHW",
                                                  "kAny",
                                                  "kNHWC",
                                                  "kImageDefault",
                                                  "kImageFolder",
                                                  "kImageNW"};
  // When a layout is added to DataLayoutType, a missing (or extra) name here
  // would otherwise cause an out-of-bounds read at runtime. Fail the build instead.
  static_assert(static_cast<int>(sizeof(datalayout2string) /
                                 sizeof(datalayout2string[0])) ==
                    static_cast<int>(DATALAYOUT(NUM)),
                "datalayout2string needs exactly one entry per DataLayoutType");
  auto x = static_cast<int>(layout);
  CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
  return datalayout2string[x];
}

// 该文件第3处
// Expands `layout` into the set of layouts it may match.
// Any layout other than kAny expands to a set holding only itself.
// kAny expands to the fixed list below, which includes kAny itself but not kUnk.
std::set<DataLayoutType> ExpandValidLayouts(DataLayoutType layout) {
  if (layout != DATALAYOUT(kAny)) {
    return {layout};
  }
  static const std::set<DataLayoutType> all_layouts{DATALAYOUT(kNCHW),
                                                    DATALAYOUT(kAny),
                                                    DATALAYOUT(kNHWC),
                                                    DATALAYOUT(kImageDefault),
                                                    DATALAYOUT(kImageFolder),
                                                    DATALAYOUT(kImageNW)};
  return all_layouts;
}
```

## 3. lite/api/python/pybind/pybind.cc

```cpp
  // DataLayoutType
  py::enum_<DataLayoutType>(*m, "DataLayoutType")
      .value("NCHW", DataLayoutType::kNCHW)
      .value("NHWC", DataLayoutType::kNHWC)
      .value("ImageDefault", DataLayoutType::kImageDefault)
      .value("ImageFolder", DataLayoutType::kImageFolder)
      .value("ImageNW", DataLayoutType::kImageNW)
      .value("Any", DataLayoutType::kAny);
```

## 4. lite/core/op_registry.h

在`class KernelRegistry final`中找到`using any_kernel_registor_t =`,并在其类型列表中加入下面的内容:

```cpp
// 找到KernelRegister final中的`using any_kernel_registor_t =`
// 加入如下内容:
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kNCHW)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kNHWC)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageDefault)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageFolder)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFP16),
                                      DATALAYOUT(kImageNW)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kImageDefault)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kImageFolder)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kFloat),
                                      DATALAYOUT(kImageNW)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kAny),
                                      DATALAYOUT(kImageDefault)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kAny),
                                      DATALAYOUT(kImageFolder)> *,  //
              KernelRegistryForTarget<TARGET(kOpenCL),
                                      PRECISION(kAny),
                                      DATALAYOUT(kImageNW)> *,  //
```


## 5. lite/core/op_registry.cc

该文件有2处修改:

```cpp
// 该文件第1处
// Emits one `case` of the CREATE_KERNEL1 switch. The case returns the kernel
// created by Create<> for (target__, precision__, layout__).
#define CREATE_KERNEL1_CASE(target__, precision__, layout__) \
  case DATALAYOUT(layout__):                                 \
    return Create<TARGET(target__),                          \
                  PRECISION(precision__),                    \
                  DATALAYOUT(layout__)>(op_type);

// Switches on the runtime `layout` and calls the matching Create<> instantiation
// for the given target and precision. `layout` and `op_type` must be in scope at
// the expansion site. Supporting a new layout takes one more CREATE_KERNEL1_CASE line.
#define CREATE_KERNEL1(target__, precision__)                                \
  switch (layout) {                                                          \
    CREATE_KERNEL1_CASE(target__, precision__, kNCHW)                        \
    CREATE_KERNEL1_CASE(target__, precision__, kAny)                         \
    CREATE_KERNEL1_CASE(target__, precision__, kNHWC)                        \
    CREATE_KERNEL1_CASE(target__, precision__, kImageDefault)                \
    CREATE_KERNEL1_CASE(target__, precision__, kImageFolder)                 \
    CREATE_KERNEL1_CASE(target__, precision__, kImageNW)                     \
    default:                                                                 \
      LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
  }

// 该文件第2处
// 找到文件中的下面的函数
KernelRegistry::KernelRegistry()
    : registries_(static_cast<int>(TARGET(NUM)) *
                  static_cast<int>(PRECISION(NUM)) *
                  static_cast<int>(DATALAYOUT(NUM)))

// 在该函数中加入新增Layout的下面内容
  INIT_FOR(kOpenCL, kFP16, kNCHW);
  INIT_FOR(kOpenCL, kFP16, kNHWC);
  INIT_FOR(kOpenCL, kFP16, kImageDefault);
  INIT_FOR(kOpenCL, kFP16, kImageFolder);
  INIT_FOR(kOpenCL, kFP16, kImageNW);
  INIT_FOR(kOpenCL, kFloat, kImageDefault);
  INIT_FOR(kOpenCL, kFloat, kImageFolder);
  INIT_FOR(kOpenCL, kFloat, kImageNW);
  INIT_FOR(kOpenCL, kAny, kImageDefault);
  INIT_FOR(kOpenCL, kAny, kImageFolder);
  INIT_FOR(kOpenCL, kAny, kImageNW);
```