// cinn_launch_op.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/flags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"

namespace paddle {
namespace operators {

// Slot and attribute names of the cinn_launch operator.

// Input slot whose variables are bound to the compiled CINN program.
constexpr char kX[] = "X";
// Input slot whose variables must NOT be used by the compiled program
// (CinnLaunchOpKernel::Compute enforces this).
constexpr char kNoNeedBufferX[] = "NoNeedBufferX";
// Output slot; every variable listed here must be used by the program.
constexpr char kOutputs[] = "Out";
// Attribute holding the key used to look up the subgraph in CinnCompiler.
constexpr char kCompilationKey[] = "compilation_key";

// Short aliases for the Paddle and CINN types used by the cinn_launch kernel.
using LoDTensor = framework::LoDTensor;
using CinnTensor = ::cinn::hlir::framework::Tensor;
using CinnScope = ::cinn::hlir::framework::Scope;
using CinnCompiler = framework::paddle2cinn::CinnCompiler;
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;

namespace details {

// Tranform Paddle place to CINN target
const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place);

// Print detailed compilation result of graph for debug
void DebugCinnCompiledResult(const CinnCompiledObject& result);

53 54
// Launch cinn to execute compiled executable program and wait done
void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
55
                         const CinnLaunchContext& context, void* stream);
56 57 58

// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void SetCinnRuntimeFlags();
59 60 61 62 63 64 65 66 67 68 69 70

template <typename DeviceContext>
void* GetStream(const framework::ExecutionContext& ctx) {
  return nullptr;
}

#ifdef PADDLE_WITH_CUDA
template <>
void* GetStream<platform::CUDADeviceContext>(
    const framework::ExecutionContext& ctx);
#endif

71
}  // namespace details
72 73 74 75 76

// Kernel of the cinn_launch operator.
// It obtains the compilation result of the subgraph identified by the
// `compilation_key` attribute from CinnCompiler, binds the Paddle input,
// output and internal variables to the compiled program's arguments (once,
// until the launch context reports them initialized), and then launches the
// program on the stream returned by details::GetStream.
template <typename DeviceContext, typename T>
class CinnLaunchOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto& scope = ctx.scope();
    const auto& place = ctx.GetPlace();
    // nullptr unless a specialization of GetStream exists for DeviceContext.
    void* stream = details::GetStream<DeviceContext>(ctx);

    // Step 1. Find graph object and prepare input
    PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey), true,
                      platform::errors::NotFound(
                          "No Attribute(%s) found for CinnLaunchOp operator.",
                          kCompilationKey));
    const auto& compilation_key =
        ctx.template Attr<std::string>(kCompilationKey);
    VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") "
            << "value:\n"
            << CinnCompiler::GetInstance()->ReadableKey(compilation_key);

    // Name -> tensor map of all inputs (X and NoNeedBufferX), passed to
    // CinnCompiler::Compile below.
    std::map<std::string, const LoDTensor*> inputs_name2tensor;
    std::vector<std::string> input_x_variable_names;
    std::vector<std::string> input_no_need_buffer_variable_names;
    auto add_name2tensor_fn = [&inputs_name2tensor](
        const std::vector<std::string>& variable_names,
        const std::vector<const LoDTensor*>& tensors) {
      // The binary std::transform reads one tensor per name; guard against
      // reading past the end of `tensors`.
      PADDLE_ENFORCE_EQ(
          variable_names.size(), tensors.size(),
          platform::errors::InvalidArgument(
              "The number of input variable names(%d) is not equal to "
              "the number of input tensors(%d).",
              variable_names.size(), tensors.size()));
      std::transform(
          variable_names.begin(), variable_names.end(), tensors.begin(),
          std::inserter(inputs_name2tensor, inputs_name2tensor.end()),
          [](const std::string& name, const LoDTensor* tensor) {
            return std::make_pair(name, tensor);
          });
    };

    auto input_x_tensors = ctx.MultiInput<LoDTensor>(kX);
    if (!input_x_tensors.empty()) {
      // Plain assignment: a temporary result is moved from automatically,
      // so wrapping the call in std::move adds nothing.
      input_x_variable_names = ctx.InputNames(kX);
      add_name2tensor_fn(input_x_variable_names, input_x_tensors);
    }
    auto input_no_need_buffer_tensors =
        ctx.MultiInput<LoDTensor>(kNoNeedBufferX);
    if (!input_no_need_buffer_tensors.empty()) {
      input_no_need_buffer_variable_names = ctx.InputNames(kNoNeedBufferX);
      add_name2tensor_fn(input_no_need_buffer_variable_names,
                         input_no_need_buffer_tensors);
    }

    // Step 2. Get compilation result of the graph
    // PlaceToCinnTarget returns a const reference; bind it instead of copying.
    const auto& target = details::PlaceToCinnTarget(place);
    const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
        compilation_key, inputs_name2tensor, target, stream);
    details::DebugCinnCompiledResult(cinn_compiled_object);

    auto* launch_context = cinn_compiled_object.launch_context.get();
    // Step 3. Prepare arguments needed for the compiled executable program.
    launch_context->UpdateCapturedEnv(scope, place);
    if (!launch_context->IsArgumentsInitialized()) {
      VLOG(4) << "CinnLaunchOp prepare arguments";

      // 3.1 Prepare input variables: tensors of input variables have
      //     been initialized before graph compiled, just check the
      //     equality between tensors of paddle and cinn.
      for (const auto& var_name : input_no_need_buffer_variable_names) {
        // the input variable declared as 'no need buffer' can not be used
        PADDLE_ENFORCE_EQ(
            launch_context->IsVariableUsed(var_name), false,
            platform::errors::InvalidArgument(
                "Input variable(%s) should not be used by cinn in execution",
                var_name));
      }

      for (const auto& var_name : input_x_variable_names) {
        // Some input variables are not needed by cinn because they are
        // eliminated by optimization passes or some cinn operators use
        // fewer variables.
        if (!launch_context->IsVariableUsed(var_name)) {
          VLOG(4) << "Input variable(" << var_name << ") not used by cinn";
          continue;
        }

        launch_context->AssignExternalVariable(var_name);
      }

      // 3.2 Prepare output variables: all output variables should
      //     be initialized and allocated buffer before
      //     the runtime program start execution, the compilation result
      //     includes details of their buffer assignment and we use that to
      //     allocate space in Paddle. For those variables allocated yet,
      //     like persistable parameters, just check the equality between
      //     Paddle allocation and CINN buffer assignment.
      auto output_variable_names = ctx.OutputNames(kOutputs);
      // Iterate by reference: a by-value loop variable copied each name.
      for (const auto& var_name : output_variable_names) {
        PADDLE_ENFORCE_EQ(
            launch_context->IsVariableUsed(var_name), true,
            platform::errors::InvalidArgument(
                "Output variable(%s) not used by cinn", var_name));

        launch_context->AssignExternalVariable(var_name);
      }

      // 3.3 Prepare internal or temporary variables: Create a temporary
      //     scope to keep internal variables within graph or temporary
      //     variables needed by the compiled runtime program in addition.
      //     Here we directly use the names from CinnScope as Paddle variable
      //     names, because they will not be used outside the graph
      //     and should be destructed after computation finished.
      auto internal_variable_names = launch_context->GetInternalVariableNames();
      for (const auto& var_name : internal_variable_names) {
        launch_context->AssignInternalVariable(var_name);
      }
    }

    // Step 4. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
    details::SetCinnRuntimeFlags();

    // Step 5. Launch CINN to execute the compiled executable program
    VLOG(4) << "Run Cinn compiled executable program with stream: " << stream;
    details::LaunchCinnExecution(cinn_compiled_object, *launch_context, stream);
    VLOG(4) << "CinnLaunchOp launch execution done.";
  }
};

}  // namespace operators
}  // namespace paddle