// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cerrno>                                 // errno
#include <cstdlib>                                // std::getenv, std::strtol
#include <limits>                                 // std::numeric_limits
#include <memory>                                 // std::unique_ptr
#include <string>                                 // std::string
#include "lite/backends/xpu/xpu_header_sitter.h"  // xpu_free
#include "lite/core/target_wrapper.h"             // TargetWrapper
#include "lite/utils/cp_logging.h"                // CHECK_EQ

// Evaluates `func` exactly once and CHECK-fails, printing the stringified call
// and its result, when that result is not equal to 0 (i.e. 0 is treated as
// the success status).
//
// The do { ... } while (0) wrapper makes `XPU_CALL(...);` a single statement,
// so it composes correctly with unbraced if/else. The local has a deliberately
// unusual name so that an argument mentioning a variable called `e` (e.g.
// XPU_CALL(foo(e))) does not end up initializing the local from itself.
#define XPU_CALL(func)                                         \
  do {                                                         \
    auto xpu_call_ret_ = (func);                               \
    CHECK_EQ(xpu_call_ret_, 0)                                 \
        << "XPU: (" << #func << ") returns " << xpu_call_ret_; \
  } while (0)

namespace paddle {
namespace lite {

// Upper bound on the number of entries in a lod: MAX(lod.size()) = 64.
constexpr int XPU_MAX_LOD_SIZE = 64;
// Upper bound on one sequence's length: MAX(lod[i + 1] - lod[i]) = 512.
constexpr int XPU_MAX_LOD_SEQ_LEN = 512;

// Shorthand for the kXPU specialization of TargetWrapper, which is declared
// further down in this header.
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;

// Record describing one scratch buffer, as owned by XPUScratchPadGuard.
//   addr_  - start address of the buffer.
//   is_l3_ - when true, XPUScratchPadDeleter does not xpu_free() addr_
//            (named for L3 memory; presumably released elsewhere - confirm).
struct XPUScratchPad {
  void* addr_{nullptr};
  bool is_l3_{false};

  XPUScratchPad(void* addr, bool is_l3) : addr_{addr}, is_l3_{is_l3} {}
};

struct XPUScratchPadDeleter {
  void operator()(XPUScratchPad* sp) const {
    if (!sp->is_l3_) {
48
      XPU_CALL(xpu_free(sp->addr_));
49 50 51 52 53 54 55
    }
    delete sp;
  }
};

// Owning handle for an XPUScratchPad: when the guard is destroyed or reset,
// XPUScratchPadDeleter releases the pad.
using XPUScratchPadGuard = std::unique_ptr<XPUScratchPad, XPUScratchPadDeleter>;

template <>
class TargetWrapper<TARGET(kXPU)> {
 public:
  static size_t num_devices() { return 1; }
  static size_t maximum_stream() { return 0; }

  static void* Malloc(size_t size);
  static void Free(void* ptr);

  static void MemcpySync(void* dst,
                         const void* src,
                         size_t size,
                         IoDirection dir);
69

70
  static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = false);
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

  static xdnn::Context* GetRawContext() {
    if (tls_raw_ctx_ == nullptr) {
      tls_raw_ctx_ = xdnn::create_context();
      CHECK(tls_raw_ctx_);
      int r = xdnn::set_workspace_l3_size(tls_raw_ctx_,
                                          workspace_l3_size_per_thread);
      if (r != 0) {
        LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
                     << ", workspace_l3_size_per_thread = "
                     << workspace_l3_size_per_thread;
      }
    }
    return tls_raw_ctx_;
  }

  // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
  // thread
  static void SetDev(int dev_no = 0) {
    const char* dev_env = getenv("LITE_XPU_DEV");
    if (dev_env) {
92
      dev_no = atoi(dev_env);
93 94
    }

95
    XPU_CALL(xpu_set_device(dev_no));
96 97 98 99 100 101 102
  }

  static std::string multi_encoder_precision;  // NOLINT
  static int workspace_l3_size_per_thread;

 private:
  static thread_local xdnn::Context* tls_raw_ctx_;
103 104 105 106
};

}  // namespace lite
}  // namespace paddle