cinn_launch_context.cc 7.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
16
#include <functional>
17
#include <vector>
18

19 20
namespace paddle {
namespace operators {
21 22
namespace details {

23 24 25 26
// Builds a launch context from the paddle->cinn variable name mapping and the
// scope produced by CINN compilation.
//
// Besides storing both arguments, the constructor snapshots the names of all
// variables reported by `cinn_scope` into `cinn_variable_names_` (as owned
// std::string objects) so later membership checks do not touch the scope.
CinnLaunchContext::CinnLaunchContext(
    const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
    const std::shared_ptr<CinnScope>& cinn_scope)
    : paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) {
  const auto scope_var_names = cinn_scope_->var_names();
  cinn_variable_names_.reserve(scope_var_names.size());
  for (const auto& name_view : scope_var_names) {
    // The scope hands out name views; convert each into an owned string.
    cinn_variable_names_.insert(std::string(name_view.data()));
  }
}

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
// Records the scope and place that subsequent buffer callbacks will use.
//
// The environment is identified by the *addresses* of `scope` and `place`:
// when both match the cached pointers nothing is done. Otherwise the cached
// pointers are replaced and `cached_temp_scope_` is reset to a fresh
// temporary scope obtained from `scope.NewTmpScope()`.
//
// @param scope  scope whose address is cached in `cached_scope_`.
// @param place  place whose address is cached in `cached_place_`.
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
                                          const platform::Place& place) {
  const auto* new_scope = std::addressof(scope);
  const auto* new_place = std::addressof(place);
  if (new_scope == cached_scope_ && new_place == cached_place_) {
    VLOG(4) << "Captured scope:" << cached_scope_ << ", place:" << cached_place_
            << " are not changed";
    return;
  }

  // Log before overwriting the cached pointers; otherwise both sides of the
  // "old->new" arrows would print the new address.
  VLOG(4) << "Captured env is update, scope:" << cached_scope_ << "->"
          << new_scope << ", place:" << cached_place_ << "->" << new_place;

  cached_scope_ = new_scope;
  cached_place_ = new_place;
  cached_temp_scope_ = scope.NewTmpScope();
}

// Reports whether arguments have been assigned: true only when both the
// owned buffer list and the name->argument map are non-empty.
bool CinnLaunchContext::IsArgumentsInitialized() const {
  return !hold_buffers_.empty() && !name2argument_.empty();
}

58 59 60 61 62
// Returns true when `paddle_var_name` has an entry in the paddle->cinn name
// mapping AND the mapped cinn name is one of the names collected from the
// cinn scope.
bool CinnLaunchContext::IsVariableUsed(
    const std::string& paddle_var_name) const {
  const auto mapping = paddle2cinn_varmap_.find(paddle_var_name);
  if (mapping == paddle2cinn_varmap_.end()) {
    return false;
  }
  return cinn_variable_names_.count(mapping->second) > 0;
}

// Returns the tensor registered under `var_name` in the held cinn scope.
//
// The name is first required to be present in `cinn_variable_names_`; the
// enforce below reports a NotFound error when it is not, so the scope lookup
// is only reached for known names.
CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
  PADDLE_ENFORCE_GT(cinn_variable_names_.count(var_name), 0,
                    platform::errors::NotFound(
                        "Variable(%s) not found in cinn scope.", var_name));
  return cinn_scope_->GetTensor(var_name);
}

72
std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
73 74 75 76 77
  std::unordered_set<std::string> all_parameters(cinn_variable_names_);
  std::for_each(name2argument_.begin(), name2argument_.end(),
                [&all_parameters](const auto& name2arg) {
                  all_parameters.erase(name2arg.first);
                });
78
  return all_parameters;
79 80
}

81 82 83
void CinnLaunchContext::CheckTensorEquivalent(
    const std::string& paddle_var_name, const LoDTensor& paddle_tensor,
    const CinnTensor& cinn_tensor) {
84 85
  // check dimension
  auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data());
86 87 88 89
  PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
                    platform::errors::PreconditionNotMet(
                        "Tensors' shape in variable(%s) are not equivalent, "
                        "paddle's shape = [%s], but cinn's shape = [%s].",
90
                        paddle_var_name, paddle_tensor.dims(), cinn_dims));
91 92 93 94

  // TODO(CtfGo): check the underlying data type after CINN ready
}

95 96 97 98 99 100
void CinnLaunchContext::AssignExternalVariable(
    const std::string& paddle_var_name) {
  PADDLE_ENFORCE_EQ(
      IsVariableUsed(paddle_var_name), true,
      platform::errors::InvalidArgument("Paddle variable(%s) not used by cinn",
                                        paddle_var_name));
101

102
  const auto& cinn_var_name = paddle2cinn_varmap_.at(paddle_var_name);
103
  const auto& paddle_tensor =
104 105
      cached_scope_->GetVar(paddle_var_name)->Get<LoDTensor>();
  CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
106
  if (paddle_tensor.IsInitialized()) {
107
    CheckTensorEquivalent(paddle_var_name, paddle_tensor, cinn_tensor);
108
  }
109 110 111 112 113 114

  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
115
      [this, paddle_var_name](void* ctx, cinn_buffer_t* buffer) {
116
        auto* tensor =
117
            cached_scope_->GetVar(paddle_var_name)->GetMutable<LoDTensor>();
118 119 120 121 122 123 124 125 126 127 128 129 130
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // external variables will be recycled by global gc, so do nothing here
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [](void* ctx, cinn_buffer_t* buffer) {
        // Do nothing
        return 0;
      });

131
  return SetArgument(cinn_var_name, std::move(cinn_buffer));
132
}
133

134 135 136 137 138 139 140
void CinnLaunchContext::AssignInternalVariable(
    const std::string& cinn_var_name) {
  PADDLE_ENFORCE_GT(
      cinn_variable_names_.count(cinn_var_name), 0,
      platform::errors::InvalidArgument("Variable(%s) not found in cinn socpe.",
                                        cinn_var_name));
  CinnTensor cinn_tensor = GetCinnTensor(cinn_var_name);
141
  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
142 143 144
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
145 146

  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
147
      [this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
148
        auto* tensor =
149
            cached_temp_scope_->Var(cinn_var_name)->GetMutable<LoDTensor>();
150 151 152
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
153 154 155
        return 0;
      });

156 157
  // internal variables should release its buffer immediately
  // if no instruction use it
158
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
159
      [this, cinn_var_name](void* ctx, cinn_buffer_t* buffer) {
160
        auto* tensor =
161
            cached_temp_scope_->GetVar(cinn_var_name)->GetMutable<LoDTensor>();
162
        tensor->clear();
163 164
        return 0;
      });
165
  return SetArgument(cinn_var_name, std::move(cinn_buffer));
166 167
}

168
void CinnLaunchContext::SetArgument(const std::string& cinn_var_name,
169
                                    std::unique_ptr<cinn_buffer_t>&& buffer) {
170 171 172
  VLOG(4) << "SetArgument-" << name2argument_.size() << ": name("
          << cinn_var_name << "), dims("
          << framework::DDim(buffer->dims, buffer->dimensions) << ").";
173

174
  name2argument_.emplace(cinn_var_name, buffer.get());
175 176 177 178 179 180 181 182 183 184 185 186 187
  hold_buffers_.emplace_back(std::move(buffer));
}

const std::map<std::string, cinn_pod_value_t>&
CinnLaunchContext::FinalizeArguments() const {
  // Check all execution parameters are assigned valued.
  std::for_each(cinn_variable_names_.begin(), cinn_variable_names_.end(),
                [this](const auto& var_name) {
                  PADDLE_ENFORCE_GT(name2argument_.count(var_name), 0,
                                    platform::errors::InvalidArgument(
                                        "Variable(%s) is missed for launching "
                                        "compiled program execution",
                                        var_name));
188
                });
189
  return name2argument_;
190 191 192
}

}  // namespace details
193 194
}  // namespace operators
}  // namespace paddle