// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace distributed {

ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node) {
  // Build the ready/stop/buffer bookkeeping from the task graph first so the
  // message handler registered below always observes consistent state.
  PrepareDeps();
  RegisterMsgHandle(
      [this](const InterceptorMessage& interceptor_msg) { Compute(interceptor_msg); });
}

void ComputeInterceptor::PrepareDeps() {
32 33
  auto& upstream = node_->upstream();
  auto& downstream = node_->downstream();
34

35 36 37
  for (auto up : upstream) {
    in_readys_.emplace(up.first, std::make_pair(up.second, 0));
    in_stops_.emplace(up.first, false);
38
  }
39 40
  for (auto down : downstream) {
    out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
41
  }
42 43 44 45

  // source compute node, should we add a new SourceInterceptor?
  if (upstream.empty()) {
    is_source_ = true;
46 47
    PADDLE_ENFORCE_GT(node_->max_run_times(),
                      0,
48 49 50 51
                      platform::errors::InvalidArgument(
                          "Source ComputeInterceptor must run at least one "
                          "times, but now max_run_times=%ld",
                          node_->max_run_times()));
Y
Yuang Liu 已提交
52 53
    in_readys_.emplace(-1,
                       std::make_pair(std::numeric_limits<int64_t>::max(), 0));
54
  }
55 56 57 58 59

  // If there is no downstream or every downstream is in different rank,
  // then this interceptor is the last one for current rank.
  // This can be get during init, can be cached for later use.
  is_last_ = downstream.empty();
60 61 62 63
}

void ComputeInterceptor::IncreaseReady(int64_t up_id) {
  // Bump the ready counter for one upstream after its DATA_IS_READY arrives.
  auto iter = in_readys_.find(up_id);
  PADDLE_ENFORCE_NE(iter,
                    in_readys_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_readys.", up_id));

  // source node has no upstream, data_is_ready is send by carrier or others
  if (is_source_ && up_id == -1) {
    iter->second.second += GetTaskNode()->max_run_times();
    return;
  }

  const auto max_ready_size = iter->second.first;
  const auto new_ready_size = iter->second.second + 1;
  PADDLE_ENFORCE_LE(new_ready_size,
                    max_ready_size,
                    platform::errors::OutOfRange(
                        "upstream=%lld ready_size must <= max_ready_size, but "
                        "now ready_size=%lld, max_ready_size=%lld",
                        up_id,
                        new_ready_size,
                        max_ready_size));
  iter->second.second = new_ready_size;
}

void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
  // One downstream consumed a micro-batch; release one slot in its buffer.
  auto iter = out_buffs_.find(down_id);
  PADDLE_ENFORCE_NE(iter,
                    out_buffs_.end(),
                    platform::errors::NotFound(
                        "Cannot find downstream=%lld in out_buffs.", down_id));
  const auto new_used_size = iter->second.second - 1;
  PADDLE_ENFORCE_GE(
      new_used_size,
      0,
      platform::errors::OutOfRange(
          "downstream=%lld used buff size must >= 0, but now equal %lld",
          down_id,
          new_used_size));
  iter->second.second = new_used_size;
}

bool ComputeInterceptor::IsInputReady() {
  for (auto& ins : in_readys_) {
    auto ready_size = ins.second.second;
    // not ready, return false
Y
Yuang Liu 已提交
111 112 113 114 115
    if (ready_size == 0) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << "'s upstreams aren't all ready.";
      return false;
    }
116 117 118 119 120 121 122 123 124
  }
  return true;
}

bool ComputeInterceptor::CanWriteOutput() {
  for (auto& outs : out_buffs_) {
    auto max_buffer_size = outs.second.first;
    auto used_size = outs.second.second;
    // full, return false
Y
Yuang Liu 已提交
125 126 127 128 129
    if (used_size == max_buffer_size) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << "'s out buffer is full.";
      return false;
    }
130 131
  }
  return true;
132 133 134
}

void ComputeInterceptor::SendDataReadyToDownStream() {
135 136 137 138 139 140
  for (auto& outs : out_buffs_) {
    auto down_id = outs.first;
    auto max_buff_size = outs.second.first;
    auto used_size = outs.second.second;
    used_size += 1;
    PADDLE_ENFORCE_LE(
141 142
        used_size,
        max_buff_size,
143 144 145
        platform::errors::OutOfRange("downstream=%lld used buff size must <= "
                                     "max_buff_size, but now used_size=%lld, "
                                     "max_buff_size=%lld",
146 147 148
                                     down_id,
                                     used_size,
                                     max_buff_size));
149 150 151 152
    outs.second.second = used_size;

    InterceptorMessage ready_msg;
    ready_msg.set_message_type(DATA_IS_READY);
153
    VLOG(3) << "ComputeInterceptor " << interceptor_id_
Y
Yuang Liu 已提交
154 155
            << " Send data_is_ready msg to " << down_id
            << " for step: " << step_;
156 157 158 159 160 161 162 163 164 165
    Send(down_id, ready_msg);
  }
}

void ComputeInterceptor::ReplyCompletedToUpStream() {
  for (auto& ins : in_readys_) {
    auto up_id = ins.first;
    auto ready_size = ins.second.second;
    ready_size -= 1;
    PADDLE_ENFORCE_GE(
166 167
        ready_size,
        0,
168
        platform::errors::OutOfRange(
169 170
            "upstream=%lld ready_size must >= 0, but now got %lld",
            up_id,
171 172 173
            ready_size));
    ins.second.second = ready_size;

Y
Yuang Liu 已提交
174 175 176
    VLOG(3) << "ComputeInterceptor " << interceptor_id_
            << " Reply data_is_useless msg to " << up_id
            << " for step: " << step_;
177
    if (is_source_ && up_id == -1) return;
Y
Yuang Liu 已提交
178

179
    InterceptorMessage reply_msg;
180
    reply_msg.set_message_type(DATA_IS_USELESS);
181 182 183 184
    Send(up_id, reply_msg);
  }
}

185
void ComputeInterceptor::RunOps() {
186
  VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
Y
Yuang Liu 已提交
187
          << step_ + 1 << " time.";
188 189
  for (auto op : node_->ops()) {
    op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
190 191
    if (gc_) {
      framework::DeleteUnusedTensors(
192 193 194 195
          *microbatch_scopes_[step_ % node_->max_run_times()],
          op,
          node_->unused_vars(),
          gc_.get());
196
    }
197 198 199
  }
}

200
void ComputeInterceptor::Run() {
201
  while (IsInputReady() && CanWriteOutput()) {
202
    VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
203

204
    RunOps();
205
    ++step_;
206 207 208 209 210

    // send to downstream and increase buff used
    SendDataReadyToDownStream();
    // reply to upstream and decrease ready data
    ReplyCompletedToUpStream();
211
    // Try to stop Carrier
212
    if (is_last_ && (step_ % node_->max_run_times() == 0)) {
213 214
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << " is stopping carrier.";
215
      // FIXME(wangxi): with multi sink interceptor
216 217
      StopCarrier();
    }
218 219 220
  }
}

221 222 223 224
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
  received_stop_ = true;

  // source node has no upstream, stop is send by carrier or others
225
  if (is_source_ && up_id == -1) return;
226 227

  auto it = in_stops_.find(up_id);
228 229
  PADDLE_ENFORCE_NE(it,
                    in_stops_.end(),
230 231 232
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_stops.", up_id));
  PADDLE_ENFORCE_EQ(
233 234
      it->second,
      false,
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
      platform::errors::AlreadyExists("Already received stop from %lld, stop "
                                      "cannot be send more than once."));
  it->second = true;
}

void ComputeInterceptor::TryStop() {
  if (!received_stop_) return;

  // Stopping is only legal once every upstream has sent STOP and every
  // downstream has drained its buffer.
  for (const auto& in_stop : in_stops_) {
    if (!in_stop.second) return;
  }
  for (const auto& out_buff : out_buffs_) {
    if (out_buff.second.second != 0) return;
  }

  // Propagate STOP to each downstream, then mark this interceptor stopped.
  for (const auto& out : out_buffs_) {
    InterceptorMessage stop;
    stop.set_message_type(STOP);
    Send(out.first, stop);
  }
  stop_ = true;
}

263 264
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
265 266
    IncreaseReady(msg.src_id());
    Run();
267
  } else if (msg.message_type() == DATA_IS_USELESS) {
268 269
    DecreaseBuff(msg.src_id());
    Run();
270 271
  } else if (msg.message_type() == STOP) {
    ReceivedStop(msg.src_id());
272
  }
273 274

  TryStop();
275 276 277 278 279 280
}

REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);

}  // namespace distributed
}  // namespace paddle