提交 a77d1ebe 编写于 作者: X Xinqi Li

SyncDevice

上级 3a30230d
......@@ -13,6 +13,7 @@ class CpuDeviceCtx final : public DeviceCtx {
std::unique_ptr<DeviceCtx> Copy() const { return std::unique_ptr<DeviceCtx>(new CpuDeviceCtx()); }
void SyncDevice() override {}
void AddCallBack(std::function<void()> callback) const override { callback(); }
private:
......
......@@ -29,6 +29,8 @@ class CudaDeviceCtx : public DeviceCtx {
}
const cudnnHandle_t& cudnn_handle() const override { return *(cuda_handler_->cudnn_handle()); }
void SyncDevice() override { CudaCheck(cudaStreamSynchronize(cuda_stream())); }
void AddCallBack(std::function<void()> callback) const override {
cuda_handler_->AddCallBack(callback);
}
......
......@@ -19,6 +19,7 @@ class DeviceCtx {
virtual const ncclComm_t& nccl_handle() const { UNIMPLEMENTED(); }
#endif
virtual void SyncDevice() { UNIMPLEMENTED(); }
virtual void AddCallBack(std::function<void()>) const = 0;
protected:
......
#include "oneflow/core/kernel/input_kernel.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
namespace {
template<DeviceType device_type>
class InputKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(InputKernel);
InputKernel() = default;
~InputKernel() = default;
private:
void Forward(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
};
} // namespace
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kInputConf, InputKernel);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type>
class InputKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(InputKernel);
InputKernel() = default;
~InputKernel() = default;
private:
void Forward(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_INPUT_KERNEL_H_
......@@ -6,6 +6,7 @@ template<DeviceType device_type>
void ReturnKernel<device_type>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in"));
ctx.device_ctx->SyncDevice();
}
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReturnConf, ReturnKernel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册