cinn_launch_context.cc 7.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
16
#include <functional>
17
#include <vector>
18

19 20
namespace paddle {
namespace operators {
21 22
namespace details {

23 24 25 26
CinnLaunchContext::CinnLaunchContext(
    const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
    const std::shared_ptr<CinnScope>& cinn_scope)
    : paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) {
  // Copy every variable name held by the cinn scope into the owning
  // std::string set cinn_variable_names_, which the other members use to
  // validate names.
  auto var_names = cinn_scope_->var_names();
  cinn_variable_names_.reserve(var_names.size());
  std::transform(
      var_names.begin(), var_names.end(),
      std::inserter(cinn_variable_names_, cinn_variable_names_.end()),
      [](const auto& name_view) {
        // NOTE(review): elements are presumably string_view-like; only
        // data()/size() are relied on. A view is not guaranteed to be
        // null-terminated, so build the string from (pointer, length)
        // instead of from data() alone, which could read past the name.
        return std::string(name_view.data(), name_view.size());
      });
}

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
                                          const platform::Place& place) {
  // The captured env is identified by object address: when the caller hands
  // in the very same scope and place objects as last time, the cached state
  // (including the temporary scope) is kept as is.
  if (std::addressof(scope) == cached_scope_ &&
      std::addressof(place) == cached_place_) {
    VLOG(4) << "Captured scope:" << cached_scope_ << ", place:" << cached_place_
            << " are not changed";
    return;
  }
  // Log the old -> new transition BEFORE overwriting the cached pointers;
  // logging afterwards would print the new address on both sides.
  VLOG(4) << "Captured env is update, scope:" << cached_scope_ << "->"
          << std::addressof(scope) << ", place:" << cached_place_ << "->"
          << std::addressof(place);
  // NOTE(review): only addresses are stored, and the buffer callbacks later
  // dereference cached_scope_/cached_place_, so `scope` and `place` must
  // outlive their use by this context — confirm with callers.
  cached_scope_ = std::addressof(scope);
  cached_place_ = std::addressof(place);
  // A fresh temporary scope (created from the captured scope) replaces the
  // previous one whenever the env changes.
  cached_temp_scope_ = scope.NewTmpScope();
}

bool CinnLaunchContext::IsArgumentsInitialized() const {
  if (hold_buffers_.empty() || name2argument_.empty()) {
    return false;
  }
  return true;
}

bool CinnLaunchContext::IsVariableUsed(const std::string& paddle_name) const {
59 60 61 62 63 64 65 66 67 68 69
  return paddle2cinn_varmap_.count(paddle_name) > 0 &&
         cinn_variable_names_.count(paddle2cinn_varmap_.at(paddle_name)) > 0;
}

CinnTensor CinnLaunchContext::GetCinnTensor(const std::string& var_name) {
  // Only names recorded from the cinn scope at construction are valid
  // lookups; anything else raises NotFound.
  PADDLE_ENFORCE_GT(
      cinn_variable_names_.count(var_name), 0,
      platform::errors::NotFound("Variable(%s) not found in cinn scope.",
                                 var_name));
  CinnTensor found = cinn_scope_->GetTensor(var_name);
  return found;
}

70
std::unordered_set<std::string> CinnLaunchContext::GetInternalVariableNames() {
71 72 73 74 75
  std::unordered_set<std::string> all_parameters(cinn_variable_names_);
  std::for_each(name2argument_.begin(), name2argument_.end(),
                [&all_parameters](const auto& name2arg) {
                  all_parameters.erase(name2arg.first);
                });
76
  return all_parameters;
77 78 79 80 81
}

void CinnLaunchContext::CheckTensorEquivalent(const std::string& paddle_name,
                                              const LoDTensor& paddle_tensor,
                                              const CinnTensor& cinn_tensor) {
82 83
  // check dimension
  auto cinn_dims = framework::make_ddim(cinn_tensor->shape().data());
84 85 86 87 88
  PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
                    platform::errors::PreconditionNotMet(
                        "Tensors' shape in variable(%s) are not equivalent, "
                        "paddle's shape = [%s], but cinn's shape = [%s].",
                        paddle_name, paddle_tensor.dims(), cinn_dims));
89 90 91 92

  // TODO(CtfGo): check the underlying data type after CINN ready
}

93
void CinnLaunchContext::AssignExternalVariable(const std::string& paddle_name) {
94 95 96
  PADDLE_ENFORCE_EQ(IsVariableUsed(paddle_name), true,
                    platform::errors::InvalidArgument(
                        "Paddle variable(%s) not used by cinn", paddle_name));
97

98
  const auto& cinn_name = paddle2cinn_varmap_.at(paddle_name);
99 100
  const auto& paddle_tensor =
      cached_scope_->GetVar(paddle_name)->Get<LoDTensor>();
101
  CinnTensor cinn_tensor = GetCinnTensor(cinn_name);
102 103
  if (paddle_tensor.IsInitialized()) {
    CheckTensorEquivalent(paddle_name, paddle_tensor, cinn_tensor);
104
  }
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127

  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, paddle_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor =
            cached_scope_->GetVar(paddle_name)->GetMutable<LoDTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // external variables will be recycled by global gc, so do nothing here
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [](void* ctx, cinn_buffer_t* buffer) {
        // Do nothing
        return 0;
      });

  return SetArgument(cinn_name, std::move(cinn_buffer));
128
}
129

130
void CinnLaunchContext::AssignInternalVariable(const std::string& cinn_name) {
131 132 133
  PADDLE_ENFORCE_GT(cinn_variable_names_.count(cinn_name), 0,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not found in cinn socpe.", cinn_name));
134
  CinnTensor cinn_tensor = GetCinnTensor(cinn_name);
135
  auto cinn_buffer = std::make_unique<cinn_buffer_t>();
136 137 138
  // assign dimensions and alloc/free callback of cinn_buffer_t
  cinn_buffer->resize(cinn_tensor->shape().data().data(),
                      cinn_tensor->shape().data().size());
139 140

  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
141 142 143 144 145 146
      [this, cinn_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor =
            cached_temp_scope_->Var(cinn_name)->GetMutable<LoDTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
147 148 149
        return 0;
      });

150 151
  // internal variables should release its buffer immediately
  // if no instruction use it
152
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
153 154 155 156
      [this, cinn_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor =
            cached_temp_scope_->GetVar(cinn_name)->GetMutable<LoDTensor>();
        tensor->clear();
157 158
        return 0;
      });
159
  return SetArgument(cinn_name, std::move(cinn_buffer));
160 161
}

162
void CinnLaunchContext::SetArgument(const std::string& cinn_name,
163 164 165 166 167
                                    std::unique_ptr<cinn_buffer_t>&& buffer) {
  VLOG(4) << "SetArgument-" << name2argument_.size() << ": name(" << cinn_name
          << "), dims(" << framework::DDim(buffer->dims, buffer->dimensions)
          << ").";

168 169 170 171 172 173 174 175 176 177 178 179 180 181
  name2argument_.emplace(cinn_name, buffer.get());
  hold_buffers_.emplace_back(std::move(buffer));
}

const std::map<std::string, cinn_pod_value_t>&
CinnLaunchContext::FinalizeArguments() const {
  // Check all execution parameters are assigned valued.
  std::for_each(cinn_variable_names_.begin(), cinn_variable_names_.end(),
                [this](const auto& var_name) {
                  PADDLE_ENFORCE_GT(name2argument_.count(var_name), 0,
                                    platform::errors::InvalidArgument(
                                        "Variable(%s) is missed for launching "
                                        "compiled program execution",
                                        var_name));
182
                });
183
  return name2argument_;
184 185 186
}

}  // namespace details
187 188
}  // namespace operators
}  // namespace paddle