cinn_launch_context.cc 10.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
16
#include <algorithm>
17
#include <functional>
18
#include <utility>
19
#include <vector>
20 21 22 23
#include "cinn/hlir/framework/scope.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/string/printf.h"
24
#include "paddle/pten/core/ddim.h"
25

26
namespace paddle {
27 28 29
namespace operators::details {

using LoDTensor = framework::LoDTensor;
30

31 32 33
// Constructs the launch context: collects the names of all execution
// arguments from the compiled CINN scope, then builds the bidirectional
// paddle<->cinn variable-name maps used for argument binding.
CinnLaunchContext::CinnLaunchContext(
    const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
    const std::shared_ptr<CinnScope>& cinn_scope)
    : cinn_scope_(cinn_scope) {
  // generate all names of the cinn execution arguments
  const auto& scope_var_names = cinn_scope_->var_names();
  cinn_argument_names_.reserve(scope_var_names.size());
  for (const auto& name_view : scope_var_names) {
    cinn_argument_names_.emplace(std::string(name_view.data()));
  }
  // build name map between the original variables and compiled ones
  BuildVarNameMap(paddle2cinn_varmap, cinn_argument_names_);
}

void CinnLaunchContext::BuildVarNameMap(
    const std::unordered_map<std::string, std::string>& compiled_varmap,
    const std::unordered_set<std::string>& argument_names) {
  for (const auto& x : compiled_varmap) {
    if (!argument_names.count(x.second)) {
      // exclude variables not used
      continue;
    }
    // copy to local paddle2cinn map
    paddle2cinn_varmap_.emplace(x.first, x.second);
    // add an entry to local cinn2paddle map reversely
57 58 59 60 61 62 63
    auto res = cinn2paddle_varmap_.emplace(x.second, x.first);
    PADDLE_ENFORCE_EQ(
        res.second, true,
        platform::errors::InvalidArgument(
            "Cinn variable(%s) maps to more than one paddle variable(%s,%s)",
            x.second, res.first->second, x.first));
  }
64 65 66 67
  // supplement the relations of the remain variables
  // not appearing in above map, which are internal variables
  // and here we use the names from cinn compiled.
  for (const auto& var_name : argument_names) {
68 69 70 71 72
    if (!cinn2paddle_varmap_.count(var_name)) {
      cinn2paddle_varmap_.emplace(var_name, var_name);
      paddle2cinn_varmap_.emplace(var_name, var_name);
    }
  }
73 74 75 76 77 78

  PADDLE_ENFORCE_EQ(
      paddle2cinn_varmap_.size(), cinn2paddle_varmap_.size(),
      platform::errors::PreconditionNotMet(
          "Size of variables is not euqal, paddle[%ld] vs cinn[%ld]",
          paddle2cinn_varmap_.size(), cinn2paddle_varmap_.size()));
79 80
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
// Caches the execution scope and place for subsequent variable binding,
// and creates a fresh temporary scope for internal variables. No-op when
// both scope and place are unchanged since the last call.
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
                                          const platform::Place& place) {
  if (std::addressof(scope) == cached_scope_ &&
      std::addressof(place) == cached_place_) {
    VLOG(4) << "Captured scope:" << cached_scope_ << ", place:" << cached_place_
            << " are not changed";
    return;
  }
  // Log the transition BEFORE overwriting the cached pointers; the
  // original code logged afterwards, so both sides of "->" always
  // printed the same (new) address.
  VLOG(4) << "Captured env is update, scope:" << cached_scope_ << "->"
          << std::addressof(scope) << ", place:" << cached_place_ << "->"
          << std::addressof(place);
  cached_scope_ = std::addressof(scope);
  cached_place_ = std::addressof(place);
  cached_temp_scope_ = scope.NewTmpScope();
}

// Returns true once at least one buffer has been appended, i.e. both the
// ownership list and the name->argument map are non-empty.
bool CinnLaunchContext::IsArgumentsInitialized() const {
  return !hold_buffers_.empty() && !name2argument_.empty();
}

104 105
// Whether the paddle variable participates in the compiled cinn program.
bool CinnLaunchContext::IsVariableUsed(const std::string& var_name) const {
  return paddle2cinn_varmap_.find(var_name) != paddle2cinn_varmap_.end();
}

108 109 110 111 112
// Looks up the cinn tensor registered under arg_name; the name must be
// one of the execution arguments collected at construction time.
CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& arg_name) {
  PADDLE_ENFORCE_GT(cinn_argument_names_.count(arg_name), 0,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not found in cinn scope.", arg_name));
  auto tensor = cinn_scope_->GetTensor(arg_name);
  return tensor;
}

115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
// Returns the paddle variable names that are neither inputs nor outputs
// of the subgraph: start from every mapped variable, then remove the
// explicitly listed inputs and outputs.
std::unordered_set<std::string> CinnLaunchContext::ExtractInternalVarNames(
    const std::vector<std::string>& input_var_names,
    const std::vector<std::string>& output_var_names) {
  std::unordered_set<std::string> remain_var_names;
  remain_var_names.reserve(paddle2cinn_varmap_.size());
  for (const auto& name_pair : paddle2cinn_varmap_) {
    remain_var_names.insert(name_pair.first);
  }

  // exclude the input variables and output variables
  for (const auto& in_name : input_var_names) {
    remain_var_names.erase(in_name);
  }
  for (const auto& out_name : output_var_names) {
    remain_var_names.erase(out_name);
  }
  return remain_var_names;
}

135 136 137
// Verifies that an initialized paddle tensor agrees with its compiled
// cinn counterpart. Currently only the shape is compared; dtype checking
// is deferred (see TODO below).
void CinnLaunchContext::CheckTensorEquivalent(const std::string& var_name,
                                              const LoDTensor& paddle_tensor,
                                              const CinnTensor& cinn_tensor) {
  // check dimension
  auto cinn_dims = pten::make_ddim(cinn_tensor->shape().data());
  PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
                    platform::errors::PreconditionNotMet(
                        "Tensors' shape in variable(%s) are not equivalent, "
                        "paddle's shape = [%s], but cinn's shape = [%s].",
                        var_name, paddle_tensor.dims(), cinn_dims));

  // TODO(CtfGo): check the underlying data type after CINN ready
}

149 150 151 152 153 154 155 156
// Binds a paddle variable living in the captured execution scope (an
// input/output of the subgraph) to its cinn execution argument. The
// cinn_buffer_t carries alloc/free callbacks rather than memory: the
// buffer's storage is requested from the paddle tensor lazily, when the
// cinn runtime invokes external_malloc during execution.
void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  const auto& cinn_arg_name = paddle2cinn_varmap_.at(var_name);

  const auto& paddle_tensor = cached_scope_->GetVar(var_name)->Get<LoDTensor>();
  CinnTensor cinn_tensor = GetCinnTensor(cinn_arg_name);
  // Only a tensor that already holds data can be shape-checked here; an
  // uninitialized (e.g. output) tensor is resized by the callback below.
  if (paddle_tensor.IsInitialized()) {
    CheckTensorEquivalent(var_name, paddle_tensor, cinn_tensor);
  }

  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
  // NOTE: the raw `new` is intentional — the std::function object is
  // handed over to the cinn runtime through the cinn_buffer_t, which is
  // responsible for invoking and releasing it (verify against the cinn
  // runtime contract). `var_name` is captured by copy so the callback
  // stays valid after this function returns.
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor = cached_scope_->GetVar(var_name)->GetMutable<LoDTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        // NOTE(review): allocates float storage — assumes all external
        // variables are float tensors; confirm before extending dtypes.
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // external variables will be recycled by global gc, so do nothing here
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [](void* ctx, cinn_buffer_t* buffer) {
        // Do nothing
        return 0;
      });

  return AppendArgument(cinn_arg_name, std::move(cinn_buffer));
}
183

184 185 186 187 188 189 190
// Binds an internal (neither input nor output) variable to its cinn
// execution argument. Internal variables live in the temporary scope
// created by UpdateCapturedEnv; their memory is allocated on demand by
// external_malloc and released eagerly by external_free once the cinn
// runtime no longer needs the buffer.
void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  const auto& cinn_arg_name = paddle2cinn_varmap_.at(var_name);

  CinnTensor cinn_tensor = GetCinnTensor(cinn_arg_name);
  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());

  // NOTE: raw `new` is intentional — ownership of the callback object is
  // transferred to the cinn runtime via cinn_buffer_t. `var_name` is
  // captured by copy so the callback outlives this call.
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, var_name](void* ctx, cinn_buffer_t* buffer) {
        // Var() creates the variable in the temp scope if absent.
        auto* tensor =
            cached_temp_scope_->Var(var_name)->GetMutable<LoDTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        // NOTE(review): allocates float storage — assumes internal
        // variables are float tensors; confirm before extending dtypes.
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // internal variables should release its buffer immediately
  // if no instruction use it
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [this, var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor =
            cached_temp_scope_->GetVar(var_name)->GetMutable<LoDTensor>();
        tensor->clear();
        return 0;
      });
  return AppendArgument(cinn_arg_name, std::move(cinn_buffer));
}

218 219 220
void CinnLaunchContext::AppendArgument(
    const std::string& arg_name, std::unique_ptr<cinn_buffer_t>&& buffer) {
  name2argument_.emplace(arg_name, buffer.get());
221
  hold_buffers_.emplace_back(std::move(buffer));
222 223 224 225
  VLOG(4) << string::Sprintf(
      "Append an argument:name(%s),dims(%s),argument size:(%lu)", arg_name,
      framework::DDim(buffer->dims, buffer->dimensions).to_str(),
      name2argument_.size());
226 227 228 229 230
}

const std::map<std::string, cinn_pod_value_t>&
CinnLaunchContext::FinalizeArguments() const {
  // Check all execution parameters are assigned valued.
231 232 233 234 235 236
  std::for_each(cinn_argument_names_.begin(), cinn_argument_names_.end(),
                [this](const auto& arg_name) {
                  PADDLE_ENFORCE_GT(
                      name2argument_.count(arg_name), 0,
                      platform::errors::NotFound(
                          "Argument(%s) is missed for execution", arg_name));
237
                });
238
  return name2argument_;
239 240
}

241
// Resolves a paddle variable name to the cinn_buffer_t previously
// appended for its cinn argument. Fails if the variable is unknown to
// the compilation result or its argument has not been initialized yet.
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
    const std::string& var_name) {
  const auto arg_it = paddle2cinn_varmap_.find(var_name);
  PADDLE_ENFORCE_NE(
      arg_it, paddle2cinn_varmap_.end(),
      platform::errors::InvalidArgument(
          "Variable(%s) not found in compilation result", var_name));
  const auto buffer_it = name2argument_.find(arg_it->second);
  PADDLE_ENFORCE_NE(buffer_it, name2argument_.end(),
                    platform::errors::NotFound(
                        "Argument(%s) not be initialized", arg_it->second));
  return static_cast<cinn_buffer_t*>(buffer_it->second);
}

255
}  // namespace operators::details
256
}  // namespace paddle