cl_scope.h 3.7 KB
Newer Older
L
liuruilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "CL/cl.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_tool.h"
L
liuruilong 已提交
27 28

namespace paddle_mobile {
29 30 31 32

// Embedded OpenCL kernel sources, keyed by kernel file name (for example
// "conv_kernel.inc.cl"); each value holds the raw source bytes. Defined in
// another translation unit — presumably generated at build time; confirm.
extern const std::map<std::string, std::vector<unsigned char>> opencl_kernels;
// Kernel file names whose embedded source must be prefixed with the embedded
// "conv_kernel.inc.cl" header before it is compiled (see CLScope::Program).
extern const std::vector<std::string> need_conv_header_kernels;

L
liuruilong 已提交
33 34 35 36 37
namespace framework {

class CLScope {
 public:
  CLScope() {
xiebaiyuan's avatar
xiebaiyuan 已提交
38 39 40
    CLEngine *engine = CLEngine::Instance();
    context_ = engine->getContext();
    command_queue_ = engine->getClCommandQueue();
L
liuruilong 已提交
41 42
  }

xiebaiyuan's avatar
xiebaiyuan 已提交
43
  cl_command_queue CommandQueue() { return command_queue_; }
L
liuruilong 已提交
44

45
  std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel(
46 47
      const std::string &kernel_name, const std::string &file_name,
      const std::string &options) {
L
liuruilong 已提交
48
    DLOG << " to get program " << file_name;
49
    auto program = Program(file_name, options);
L
liuruilong 已提交
50 51
    DLOG << " end get program ~ ";
    DLOG << " to create kernel: " << kernel_name;
52
    std::unique_ptr<_cl_kernel, CLKernelDeleter> kernel(
L
liuruilong 已提交
53 54
        clCreateKernel(program, kernel_name.c_str(), &status_));
    CL_CHECK_ERRORS(status_);
L
liuruilong 已提交
55
    DLOG << " end create kernel ~ ";
L
liuruilong 已提交
56 57 58
    return std::move(kernel);
  }

xiebaiyuan's avatar
xiebaiyuan 已提交
59
  cl_context Context() { return context_; }
L
liuruilong 已提交
60

61 62 63 64 65 66
  cl_program Program(const std::string &file_name, const std::string &options) {
    std::string program_key = file_name;
    if (!options.empty()) {
      program_key += options;
    }
    auto it = programs_.find(program_key);
L
liuruilong 已提交
67 68 69 70
    if (it != programs_.end()) {
      return it->second.get();
    }

71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
    if (opencl_kernels.find(file_name) != opencl_kernels.end()) {
      auto it = opencl_kernels.find(file_name);
      std::string source(it->second.begin(), it->second.end());
      if (std::find(need_conv_header_kernels.begin(),
                    need_conv_header_kernels.end(),
                    file_name) != need_conv_header_kernels.end()) {
        auto it = opencl_kernels.find("conv_kernel.inc.cl");
        std::string header(it->second.begin(), it->second.end());
        source = header + source;
      }
      auto program = CLEngine::Instance()->CreateProgramWithSource(
          context_, source.c_str());

      DLOG << " --- begin build program -> " << program_key << " --- ";
      CLEngine::Instance()->BuildProgram(program.get(), options);
      DLOG << " --- end build program -> " << program_key << " --- ";

      programs_[program_key] = std::move(program);
    } else {
      auto program = CLEngine::Instance()->CreateProgramWith(
          context_,
          CLEngine::Instance()->GetCLPath() + "/cl_kernel/" + file_name);

      DLOG << " --- begin build program -> " << program_key << " --- ";
      CLEngine::Instance()->BuildProgram(program.get(), options);
      DLOG << " --- end build program -> " << program_key << " --- ";

      programs_[program_key] = std::move(program);
    }
L
liuruilong 已提交
100

101
    return programs_[program_key].get();
L
liuruilong 已提交
102 103 104
  }

 private:
105
  cl_int status_;
xiebaiyuan's avatar
xiebaiyuan 已提交
106 107
  cl_context context_;
  cl_command_queue command_queue_;
108 109 110
  std::unordered_map<std::string,
                     std::unique_ptr<_cl_program, CLProgramDeleter>>
      programs_;
L
liuruilong 已提交
111 112
};

113 114
}  // namespace framework
}  // namespace paddle_mobile