// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "cinn/common/target.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/core/flags.h"

PHI_DECLARE_bool(enable_pe_launch_cinn);
PHI_DECLARE_bool(enable_interpretercore_launch_cinn);
namespace paddle {
namespace operators {

38 39 40 41 42 43 44 45 46 47 48
using CinnCompiler = framework::paddle2cinn::CinnCompiler;
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;

namespace details {

// Tranform Paddle place to CINN target
const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place);

// Print detailed compilation result of graph for debug
void DebugCinnCompiledResult(const CinnCompiledObject& result);

49 50
// Launch cinn to execute compiled executable program and wait done
void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
51 52
                         const CinnLaunchContext& context,
                         void* stream);
53 54 55

// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void SetCinnRuntimeFlags();
56

57 58 59 60
// set CINN global random seed
template <typename DeviceContext>
void SetCinnRandomSeed();

61 62 63
// set CINN compile target
void SetCinnTarget(const ::cinn::common::Target& target);

64
}  // namespace details
65

template <typename T, typename DeviceContext>
67 68 69
class CinnLaunchOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
70 71
    const auto& scope = ctx.scope();
    const auto& place = ctx.GetPlace();
72
    void* stream = details::GetStream<DeviceContext>(ctx);
73 74
    platform::RecordEvent record_event_1(
        "Step 1. Find graph object and prepare input");
75
    // Step 1. Find graph object and prepare input
76 77
    PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey),
                      true,
78 79 80
                      platform::errors::NotFound(
                          "No Attribute(%s) found for CinnLaunchOp operator.",
                          kCompilationKey));
81
    const auto& compilation_key = ctx.template Attr<int64_t>(kCompilationKey);
82
    VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") "
83 84
            << "value:\n"
            << CinnCompiler::GetInstance()->ReadableKey(compilation_key);
85

86
    std::map<std::string, const phi::DenseTensor*> inputs_name2tensor;
87 88
    std::vector<std::string> input_x_variable_names;
    std::vector<std::string> input_no_need_buffer_variable_names;
89
    auto add_name2tensor_fn =
90 91 92
        [&inputs_name2tensor](
            const std::vector<std::string>& variable_names,
            const std::vector<const phi::DenseTensor*>& tensors) {
93
          std::transform(
94 95 96
              variable_names.begin(),
              variable_names.end(),
              tensors.begin(),
97
              std::inserter(inputs_name2tensor, inputs_name2tensor.end()),
98
              [](const std::string& name, const phi::DenseTensor* tensor) {
99 100 101
                return std::make_pair(name, tensor);
              });
        };
102

103
    auto input_x_tensors = ctx.MultiInput<phi::DenseTensor>(kX);
104 105 106 107 108
    if (!input_x_tensors.empty()) {
      input_x_variable_names = std::move(ctx.InputNames(kX));
      add_name2tensor_fn(input_x_variable_names, input_x_tensors);
    }
    auto input_no_need_buffer_tensors =
109
        ctx.MultiInput<phi::DenseTensor>(kNoNeedBufferX);
110 111 112 113 114 115
    if (!input_no_need_buffer_tensors.empty()) {
      input_no_need_buffer_variable_names =
          std::move(ctx.InputNames(kNoNeedBufferX));
      add_name2tensor_fn(input_no_need_buffer_variable_names,
                         input_no_need_buffer_tensors);
    }
116

117 118
    platform::RecordEvent record_event_2(
        "Step 2. Get compilation result of the graph");
119
    // Step 2. Get compilation result of the graph
120
    auto target = details::PlaceToCinnTarget(place);
121
    details::SetCinnTarget(target);
S
sneaxiy 已提交
122 123 124 125 126 127
    using ClockType = std::chrono::steady_clock;
    std::chrono::time_point<ClockType> start_t, end_t;
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Starts to compile at thread " << std::this_thread::get_id();
      start_t = ClockType::now();
    }
128
    const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
129
        compilation_key, inputs_name2tensor, target, stream);
S
sneaxiy 已提交
130 131 132 133 134 135 136
    if (VLOG_IS_ON(1)) {
      end_t = ClockType::now();
      auto time_sec = std::chrono::duration_cast<std::chrono::milliseconds>(
          end_t - start_t);
      VLOG(1) << "Ends to compile at thread " << std::this_thread::get_id()
              << " , time cost : " << time_sec.count() << " ms";
    }
137
    details::DebugCinnCompiledResult(cinn_compiled_object);
138
    auto* launch_context = cinn_compiled_object.launch_context.get();
139

140
    platform::RecordEvent record_event_3("Step 3. Set CINN runtime FLAGS.");
141
    // Step 3. Set CINN runtime FLAGS, such as FLAGS_cinn_cudnn_deterministic.
142 143
    details::SetCinnRuntimeFlags();

144 145 146
    // set CINN global random seed
    details::SetCinnRandomSeed<DeviceContext>();

147 148 149
    // Step 4. Execute the compiled CINN instructions by a PE or
    //         by the CINN compiled program in sequential order
    if (FLAGS_enable_pe_launch_cinn) {
150 151 152 153 154 155
      if (FLAGS_enable_interpretercore_launch_cinn) {
        platform::RecordEvent record_event_4(
            "Step 4. Execute the runtime program by InterpreterCore.");
        VLOG(4) << "Execute the runtime program by InterpreterCore";
        auto* interpreter_core = launch_context->InitializeInterpreterCore(
            place, const_cast<framework::Scope*>(&scope));
156
        interpreter_core->Run({}, false);
157 158 159 160 161 162 163 164
      } else {
        platform::RecordEvent record_event_4(
            "Step 4. Execute the runtime graph by PE.");
        VLOG(4) << "Execute the runtime graph by PE";
        framework::Scope& exec_scope = scope.NewScope();
        auto* pe = launch_context->InitializePE(place, &exec_scope);
        pe->RunWithoutFetch(launch_context->GetSkipEagerVars());
      }
165
    } else {
166 167
      platform::RecordEvent record_event_4(
          "Step 4. Execute the compiled executable program.");
168 169 170 171
      VLOG(4) << "Execute the compiled executable program";
      launch_context->UpdateCapturedEnv(scope, place);
      LaunchCinnExecution(cinn_compiled_object, *launch_context, stream);
    }
172
    VLOG(4) << "CinnLaunchOp launch execution done.";
173 174 175 176 177
  }
};

}  // namespace operators
}  // namespace paddle