input_kernel.cpp 2.0 KB
Newer Older
S
Shenghang Tsai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
L
Li Xinqi 已提交
16

X
Xinqi Li 已提交
17
#include "oneflow/core/kernel/kernel.h"
L
Li Xinqi 已提交
18 19 20
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
X
Xinqi 已提交
21 22 23

namespace oneflow {

X
Xinqi Li 已提交
24 25 26
namespace {

template<DeviceType device_type>
J
Juncheng 已提交
27
class InputKernel final : public Kernel {
X
Xinqi Li 已提交
28 29 30 31 32 33
 public:
  OF_DISALLOW_COPY_AND_MOVE(InputKernel);
  InputKernel() = default;
  ~InputKernel() = default;

 private:
J
Juncheng 已提交
34
  void ForwardDataContent(const KernelContext* ctx) const override {
L
Li Xinqi 已提交
35
    if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
J
Juncheng 已提交
36 37
      CHECK(this->op_conf().input_conf().has_job_name());
      const auto& job_name = this->op_conf().input_conf().job_name();
L
Li Xinqi 已提交
38 39 40 41 42 43 44
      const auto& op_name = this->op_conf().name();
      auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
      auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name));
      std::shared_ptr<JobInstance> job_instance;
      BufferStatus buffer_status = buffer->TryReceive(&job_instance);
      CHECK_NE(buffer_status, kBufferStatusEmpty);
      if (buffer_status == kBufferStatusSuccess) {
J
Juncheng 已提交
45
        OfBlob ofblob(ctx->device_ctx(), ctx->BnInOp2Blob("out"));
L
Li Xinqi 已提交
46 47 48 49
        job_instance->PushBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name);
      }
    }
  }
J
Juncheng 已提交
50
  void ForwardHeader(const KernelContext* ctx) const override {}
X
Xinqi Li 已提交
51 52 53 54
};

}  // namespace

X
Xinqi 已提交
55
// Registers InputKernel as the kernel for ops whose conf is OperatorConf::kInputConf.
// NOTE(review): the macro presumably instantiates InputKernel<device_type> for each
// DeviceType; confirm against the ADD_DEVICE_TYPE_KERNEL_CREATOR definition.
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kInputConf, InputKernel);
X
Xinqi Li 已提交
56 57

}  // namespace oneflow