// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/stream_analyzer.h"

#include <future>
#include <unordered_set>

#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
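
// Process-wide pools of device contexts dedicated to D2H and H2D copies.
// They are shared by every StreamAnalyzer instance, keyed by Place, and
// guarded by ctx_mtx.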
namespace {
std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>*
    d2h_ctxs = nullptr;
std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>*
    h2d_ctxs = nullptr;
std::mutex ctx_mtx;
}  // namespace

StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) {
  if (platform::is_gpu_place(place) || platform::is_npu_place(place) ||
      platform::is_custom_place(place)) {
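    // Lazily create the copy contexts for this place on first use; later
    // StreamAnalyzer instances for the same place reuse them.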
    std::lock_guard<std::mutex> lk(ctx_mtx);
    if (d2h_ctxs == nullptr) {
      d2h_ctxs = new std::map<
          Place,
          std::shared_future<std::unique_ptr<platform::DeviceContext>>>();
      h2d_ctxs = new std::map<
          Place,
          std::shared_future<std::unique_ptr<platform::DeviceContext>>>();
    }
    if (d2h_ctxs->find(place) == d2h_ctxs->end()) {
      platform::EmplaceDeviceContexts(
          d2h_ctxs,
          {place},
          /*disable_setting_default_stream_for_allocator=*/true);
      platform::EmplaceDeviceContexts(
          h2d_ctxs,
          {place},
          /*disable_setting_default_stream_for_allocator=*/true);
    }
    d2h_ctx_ = (*d2h_ctxs)[place];
    h2d_ctx_ = (*h2d_ctxs)[place];
  }
}

/*
 * Parse the var_ids that need to be associated with an event.
 * The caller should guarantee that front_op and back_op satisfy one of the
 * following conditions:
 *   1. kQueueSync -> kQueueAsync
 *   2. kQueueAsync -> kQueueSync
 *
 * For example: matmul(gpu) -> out_var -> memcpy_d2h
 * out_var should be associated with an event.
 *
 * NOTE(zhiqiu): There are two special cases in which no event is needed:
 *  1. the variable is marked as NoDataTransformVar
 *  2. the variable is marked as NoNeedDataBuffer
 */
std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
    const Instruction& cur_instr, const Instruction& next_instr) {
  std::unordered_set<size_t> unique_var_ids;
  for (auto& item : cur_instr.Outputs()) {
    unique_var_ids.insert(item.second.begin(), item.second.end());
  }

  auto is_no_need_buffer = [&next_instr](std::string name) {
    auto* op = next_instr.OpBase();
    auto& inferer = op->Info().NoNeedBufferVarsInferer();
    if (inferer) {
      auto no_need_buffer_ins =
          inferer(op->Inputs(), op->Outputs(), op->Attrs());
      return no_need_buffer_ins.count(name) != 0;
    }
    return false;
  };

  std::vector<size_t> need_event_var_ids;
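  // A variable needs an event only if it is written by cur_instr and read by
  // next_instr, and is not exempted by the two special cases above.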
  for (auto& item : next_instr.Inputs()) {
    for (auto var_id : item.second) {
      if (unique_var_ids.count(var_id) > 0) {
        if (next_instr.NoDataTransformVars().count(var_id)) {
          VLOG(4) << "Skip inserting event at variable " << item.first
                  << " of operator " << next_instr.OpBase()->Type()
                  << " since it is NoDataTransform";
          continue;
        }
        if (is_no_need_buffer(item.first)) {
          VLOG(4) << "Skip inserting event at variable " << item.first
                  << " of operator " << next_instr.OpBase()->Type()
                  << " since it is NoNeedBufferVar";
          continue;
        }

        need_event_var_ids.push_back(var_id);
      }
    }
  }
  return need_event_var_ids;
}

void StreamAnalyzer::ConstructEventForVar(
    const std::vector<size_t>& new_event_var_id,
    Instruction* next_instr,
    platform::DeviceType waiter_type,
    const platform::Place& place) {
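  // One DeviceEvent is created per variable on first use and is reused by
  // every instruction that later waits on that variable.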
  for (auto var_id : new_event_var_id) {
    if (var_id2event_.count(var_id) == 0) {
      auto device_event = std::make_shared<platform::DeviceEvent>(
          place, platform::GenerateDeviceEventFlag());
      var_id2event_.emplace(var_id, std::move(device_event));
    }
    // Add events for next_instr.inputs
    next_instr->AddInputEvent(var_id, var_id2event_.at(var_id), waiter_type);
  }
}

void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
                              std::vector<Instruction>* instructions,
                              size_t op_index) {
  auto& cur_instr = instructions->at(op_index);
  auto& next_instruction = cur_instr.NextInstructions();
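  // Each downstream op is scheduled in one of three modes:
  //   DirectRun: no cross-stream synchronization is needed (see IsDirectRun);
  //   SyncRun:   the waiter is the CPU, so the host blocks on the events;
  //   EventRun:  another device stream waits on the recorded events.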
  std::vector<size_t> event_var_ids;
  for (auto next_op_id : downstream_ops) {
    auto& next_instr = instructions->at(next_op_id);
    if (IsDirectRun(cur_instr, next_instr)) {
      VLOG(4) << "DirectRun: " << cur_instr.OpBase()->Type() << "->"
              << next_instr.OpBase()->Type();
      next_instruction.AddDirectRun(next_op_id);
    } else {
      // Always insert events between different streams.
      auto need_event_var_ids = GetNeedEventVarIds(cur_instr, next_instr);
      event_var_ids.insert(event_var_ids.end(),
                           need_event_var_ids.begin(),
                           need_event_var_ids.end());

      auto waiter_type = GetWaiterType(next_instr);
      ConstructEventForVar(need_event_var_ids,
                           &next_instr,
                           waiter_type,
                           cur_instr.DeviceContext().GetPlace());

      if (waiter_type == platform::kCPU) {  // GPU -> CPU
        VLOG(4) << "SyncRun: " << cur_instr.OpBase()->Type() << "->"
                << next_instr.OpBase()->Type();
        next_instruction.AddSyncRun(next_op_id);
      } else {  // GPU -> GPU(different stream)
        VLOG(4) << "EventRun: " << cur_instr.OpBase()->Type() << "->"
                << next_instr.OpBase()->Type();
        next_instruction.ADDEventRun(next_op_id);
      }
    }
  }
  // Record these cross-stream vars' events as outputs of cur_instr.
  VLOG(3) << cur_instr.OpBase()->Type()
          << " event_var_ids.size: " << event_var_ids.size();
  for (auto var_id : event_var_ids) {
    cur_instr.AddOutputEvent(
        var_id, var_id2event_.at(var_id), platform::kCUDA /*not used*/);
  }
}

platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
    const OpFuncNode& op_func_node) {
  auto& op_type = op_func_node.operator_base_->Type();
  auto* dev_ctx = op_func_node.dev_ctx_;
  // Only GPU/NPU need this update; XPU does not, because the XPU memcpy
  // op kernel is synchronous.
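  // Routing memcpy_d2h/memcpy_h2d onto the dedicated copy contexts lets the
  // transfers run on their own streams and overlap with computation on the
  // default stream.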
  if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) ||
      platform::is_custom_place(place_)) {
    if (op_type == interpreter::kMemcpyD2H) {
      VLOG(3) << "Get dev_ctx from d2h_context_pool_";
      dev_ctx = d2h_ctx_.get().get();
    } else if (op_type == interpreter::kMemcpyH2D) {
      VLOG(3) << "Get dev_ctx from h2d_context_pool_";
      dev_ctx = h2d_ctx_.get().get();
    }
  }
  return dev_ctx;
}

/*
 * NOTE(dev): The following cases are considered direct runs:
 *
 *  0. in an XPU place, because the XPU memcpy op kernel is synchronous
 *  1. with the same dev_ctx_, such as CPU -> CPU or GPU -> GPU
 *  2. CPU -> any (CPU op -> VAR -> GPU op is possible when the var needs no
 *     buffer or no data transform)
 *  3. D2H -> CPU
 *  4. CPU -> H2D
 */
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
                                 const Instruction& next_instr) {
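  // Case 1: the two instructions share a DeviceContext and thus a stream.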
  if (&cur_instr.DeviceContext() == &next_instr.DeviceContext()) return true;

  // XPU and IPU memcpy kernels are synchronous.
  if (platform::is_ipu_place(place_) || platform::is_xpu_place(place_))
    return true;

  // npu d2h kernel is asynchronous.
  if (platform::is_npu_place(place_) || platform::is_custom_place(place_)) {
    return interpreter::IsCpuOp(cur_instr) ||
           interpreter::IsMemcpyH2D(next_instr);
  }
  // gpu or cpu
  return interpreter::IsCpuOp(cur_instr) ||
         interpreter::IsMemcpyD2H(cur_instr) ||
         interpreter::IsMemcpyH2D(next_instr);
}

platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
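  // Synchronous kernels are waited on by the host; asynchronous kernels are
  // waited on by the device stream of the current place.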
  if (instr.KernelType() == OpFuncType::kQueueSync) {
    return platform::kCPU;
  } else {
    if (platform::is_xpu_place(place_)) {
      return platform::kXPU;
    } else if (platform::is_npu_place(place_)) {
      return platform::kNPU;
    } else if (platform::is_custom_place(place_)) {
      return platform::kCUSTOM_DEVICE;
    }
    return platform::kCUDA;
  }
}

}  // namespace framework
}  // namespace paddle