// cinn_launch_context.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cinn/cinn_launch_context.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <vector>

19 20
namespace paddle {
namespace operators {
21 22
namespace details {

23 24 25 26
CinnLaunchContext::CinnLaunchContext(
    const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
    const std::shared_ptr<CinnScope>& cinn_scope)
    : paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) {
27
  // generate all names of cinn used variables
28 29
  auto var_names = cinn_scope_->var_names();
  cinn_variable_names_.reserve(var_names.size());
30
  std::transform(
31 32 33
      var_names.begin(), var_names.end(),
      std::inserter(cinn_variable_names_, cinn_variable_names_.end()),
      [](const auto& name_view) { return std::string(name_view.data()); });
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
  // build the variable name map of cinn2paddle
  for (const auto& x : paddle2cinn_varmap_) {
    auto res = cinn2paddle_varmap_.emplace(x.second, x.first);
    PADDLE_ENFORCE_EQ(
        res.second, true,
        platform::errors::InvalidArgument(
            "Cinn variable(%s) maps to more than one paddle variable(%s,%s)",
            x.second, res.first->second, x.first));
  }
  // supplement the relations of the remain variables not appearing in above
  // map,
  // they are internal variables and here we use the name from cinn compiled.
  for (const auto& var_name : cinn_variable_names_) {
    if (!cinn2paddle_varmap_.count(var_name)) {
      cinn2paddle_varmap_.emplace(var_name, var_name);
      paddle2cinn_varmap_.emplace(var_name, var_name);
    }
  }
52 53
}

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
                                          const platform::Place& place) {
  if (std::addressof(scope) == cached_scope_ &&
      std::addressof(place) == cached_place_) {
    VLOG(4) << "Captured scope:" << cached_scope_ << ", place:" << cached_place_
            << " are not changed";
    return;
  }
  cached_scope_ = std::addressof(scope);
  cached_place_ = std::addressof(place);
  cached_temp_scope_ = scope.NewTmpScope();
  VLOG(4) << "Captured env is update, scope:" << cached_scope_ << "->"
          << std::addressof(scope) << ", place:" << cached_place_ << "->"
          << std::addressof(place);
}

bool CinnLaunchContext::IsArgumentsInitialized() const {
  if (hold_buffers_.empty() || name2argument_.empty()) {
    return false;
  }
  return true;
}

77 78 79 80 81
bool CinnLaunchContext::IsVariableUsed(
    const std::string& paddle_var_name) const {
  return paddle2cinn_varmap_.count(paddle_var_name) > 0 &&
         cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_var_name)) >
             0;
82 83 84 85 86 87 88 89 90
}

// Looks up `var_name` in the CINN scope and returns its tensor. The name must
// be one of the variables collected from the compiled program; otherwise a
// NotFound error is raised before touching the scope.
CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
  PADDLE_ENFORCE_GT(
      cinn_variable_names_.count(var_name), 0,
      platform::errors::NotFound("Variable(%s) not found in cinn scope.",
                                 var_name));
  auto tensor = cinn_scope_->GetTensor(var_name);
  return tensor;
}

91
std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
92 93 94 95 96
  std::unordered_set<std::string> all_parameters(cinn_variable_names_);
  std::for_each(name2argument_.begin(), name2argument_.end(),
                [&all_parameters](const auto& name2arg) {
                  all_parameters.erase(name2arg.first);
                });
97
  return all_parameters;
98 99
}

100 101 102
void CinnLaunchContext::CheckTensorEquivalent(
    const std::string& paddle_var_name, const LoDTensor& paddle_tensor,
    const CinnTensor& cinn_tensor) {
103 104
  // check dimension
  auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data());
105 106 107 108
  PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
                    platform::errors::PreconditionNotMet(
                        "Tensors' shape in variable(%s) are not equivalent, "
                        "paddle's shape = [%s], but cinn's shape = [%s].",
109
                        paddle_var_name, paddle_tensor.dims(), cinn_dims));
110 111 112 113

  // TODO(CtfGo): check the underlying data type after CINN ready
}

114 115 116 117 118 119
void CinnLaunchContext::AssignExternalVariable(
    const std::string& paddle_var_name) {
  PADDLE_ENFORCE_EQ(
      IsVariableUsed(paddle_var_name), true,
      platform::errors::InvalidArgument("Paddle variable(%s) not used by cinn",
                                        paddle_var_name));
120

121
  const auto& cinn_var_name = paddle2cinn_varmap_.at(paddle_var_name);
122
  const auto& paddle_tensor =
123 124
      cached_scope_->GetVar(paddle_var_name)->Get<LoDTensor>();
  CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
125
  if (paddle_tensor.IsInitialized()) {
126
    CheckTensorEquivalent(paddle_var_name, paddle_tensor, cinn_tensor);
127
  }
128 129 130 131 132 133

  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
134
      [this, paddle_var_name](void* ctx, cinn_buffer_t* buffer) {
135
        auto* tensor =
136
            cached_scope_->GetVar(paddle_var_name)->GetMutable<LoDTensor>();
137 138 139 140 141 142 143 144 145 146 147 148 149
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // external variables will be recycled by global gc, so do nothing here
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [](void* ctx, cinn_buffer_t* buffer) {
        // Do nothing
        return 0;
      });

150
  return SetArgument(cinn_var_name, std::move(cinn_buffer));
151
}
152

153 154 155 156 157 158 159
void CinnLaunchContext::AssignInternalVariable(
    const std::string& cinn_var_name) {
  PADDLE_ENFORCE_GT(
      cinn_variable_names_.count(cinn_var_name), 0,
      platform::errors::InvalidArgument("Variable(%s) not found in cinn socpe.",
                                        cinn_var_name));
  CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
160
  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
161 162 163
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
164 165

  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
166
      [this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
167
        auto* tensor =
168
            cached_temp_scope_->Var(cinn_var_name)->GetMutable<LoDTensor>();
169 170 171
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
172 173 174
        return 0;
      });

175 176
  // internal variables should release its buffer immediately
  // if no instruction use it
177
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
178
      [this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
179
        auto* tensor =
180
            cached_temp_scope_->GetVar(cinn_var_name)->GetMutable<LoDTensor>();
181
        tensor->clear();
182 183
        return 0;
      });
184
  return SetArgument(cinn_var_name, std::move(cinn_buffer));
185 186
}

187
void CinnLaunchContext::SetArgument(const std::string& cinn_var_name,
188
                                    std::unique_ptr<cinn_buffer_t>&& buffer) {
189 190 191
  VLOG(4) << "SetArgument-" << name2argument_.size() << ": name("
          << cinn_var_name << "), dims("
          << framework::DDim(buffer->dims, buffer->dimensions) << ").";
192

193
  name2argument_.emplace(cinn_var_name, buffer.get());
194 195 196 197 198 199 200 201 202 203 204 205 206
  hold_buffers_.emplace_back(std::move(buffer));
}

const std::map<std::string, cinn_pod_value_t>&
CinnLaunchContext::FinalizeArguments() const {
  // Check all execution parameters are assigned valued.
  std::for_each(cinn_variable_names_.begin(), cinn_variable_names_.end(),
                [this](const auto& var_name) {
                  PADDLE_ENFORCE_GT(name2argument_.count(var_name), 0,
                                    platform::errors::InvalidArgument(
                                        "Variable(%s) is missed for launching "
                                        "compiled program execution",
                                        var_name));
207
                });
208
  return name2argument_;
209 210
}

211 212 213 214 215 216 217 218 219 220 221 222 223 224
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
    const std::string& paddle_var_name) {
  auto res = paddle2cinn_varmap_.find(paddle_var_name);
  PADDLE_ENFORCE_NE(
      res, paddle2cinn_varmap_.end(),
      platform::errors::InvalidArgument(
          "Variable(%s) not found in compilation result", paddle_var_name));
  auto it = name2argument_.find(res->second);
  PADDLE_ENFORCE_NE(it, name2argument_.end(),
                    platform::errors::InvalidArgument(
                        "Argument(%s) not be initialized", res->second));
  return static_cast<cinn_buffer_t*>(it->second);
}

225
}  // namespace details
226 227
}  // namespace operators
}  // namespace paddle