未验证 提交 fbedf77e 编写于 作者: 王明冬 提交者: GitHub

add ipu support for standalone executor. (#44342)

上级 04e55582
......@@ -315,6 +315,7 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
op_type = kMemcpyH2D;
int dst_place_type = platform::is_gpu_place(dst_place) ? 0
: platform::is_npu_place(dst_place) ? 1
: platform::is_ipu_place(dst_place) ? 3
: platform::is_xpu_place(dst_place) ? 2
: -1;
attr_map = {{"dst_place_type", dst_place_type}};
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_context.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -475,8 +476,13 @@ void InterpreterCore::Convert(
BuildSkipShareLoDInfo();
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
#ifdef PADDLE_WITH_IPU
gc_event_.emplace_back(phi::CPUPlace(), 0);
#else
gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
#endif
}
bool inplaced = false;
for (auto inst : vec_instruction_) {
......
......@@ -390,7 +390,7 @@ static bool IsCpuOp(const Instruction& instr) {
// Returns whether `place` is a supported heterogeneous place.
static bool IsSupportedHetePlace(const phi::Place& place) {
return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_xpu_place(place);
platform::is_xpu_place(place) || platform::is_ipu_place(place);
}
} // namespace interpreter
......
......@@ -204,8 +204,9 @@ bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
if (&cur_instr.DeviceContext() == &next_instr.DeviceContext()) return true;
// xpu memcpy kernel is synchronous.
if (platform::is_xpu_place(place_)) return true;
// xpu&ipu memcpy kernel is synchronous.
if (platform::is_ipu_place(place_) || platform::is_xpu_place(place_))
return true;
// npu d2h kernel is asynchronous.
if (platform::is_npu_place(place_)) {
......
......@@ -408,6 +408,12 @@ struct OpKernelRegistrarFunctorEx<PlaceType,
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
// Registers kernel functors for `op_type` on the IPU place.
// Expands to REGISTER_OP_KERNEL_EX with IPU, ::paddle::platform::IPUPlace,
// DEFAULT_TYPE and OpKernelType::kDefaultCustomizedTypeValue as the fixed
// arguments; the variadic arguments are forwarded unchanged.
// NOTE(review): call sites pass alternating (data type, kernel functor)
// pairs -- confirm this is what REGISTER_OP_KERNEL_EX expects.
#define REGISTER_OP_IPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX( \
op_type, IPU, ::paddle::platform::IPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
/**
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to
......
......@@ -233,3 +233,31 @@ REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_d2h,
int16_t,
ops::MemcpyD2HKernel);
#endif
#ifdef PADDLE_WITH_IPU
// Registers the memcpy_d2h op for IPU; compiled only when PADDLE_WITH_IPU is
// defined. The arguments after the op name alternate between a data type and
// the functor to use for it; every listed type uses ops::MemcpyD2HKernel.
// NOTE(review): `plat` is presumably a namespace alias for paddle::platform
// declared elsewhere in the file -- confirm.
REGISTER_OP_IPU_KERNEL_FUNCTOR(memcpy_d2h,
float,
ops::MemcpyD2HKernel,
double,
ops::MemcpyD2HKernel,
int8_t,
ops::MemcpyD2HKernel,
uint8_t,
ops::MemcpyD2HKernel,
int,
ops::MemcpyD2HKernel,
int64_t,
ops::MemcpyD2HKernel,
bool,
ops::MemcpyD2HKernel,
paddle::platform::bfloat16,
ops::MemcpyD2HKernel,
paddle::platform::complex<float>,
ops::MemcpyD2HKernel,
paddle::platform::complex<double>,
ops::MemcpyD2HKernel,
plat::float16,
ops::MemcpyD2HKernel,
int16_t,
ops::MemcpyD2HKernel);
#endif
......@@ -100,6 +100,7 @@ class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"0. CUDAPinnedPlace/CPU <->CUDAPlace"
"1. NPUPinnedPlace/CPU <-> NPUPlace"
"2. CPU <->XPUPlace"
"3. CPU <->IPUPlace"
"Other place type is Unimplemented and will cause ERROR.");
AddComment(R"DOC(
MemcpyD2H Operator.
......@@ -233,3 +234,31 @@ REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_h2d,
int16_t,
ops::MemcpyH2DKernel);
#endif
#ifdef PADDLE_WITH_IPU
// Registers the memcpy_h2d op for IPU; compiled only when PADDLE_WITH_IPU is
// defined. The arguments after the op name alternate between a data type and
// the functor to use for it; every listed type uses ops::MemcpyH2DKernel.
// NOTE(review): `plat` is presumably a namespace alias for paddle::platform
// declared elsewhere in the file -- confirm.
REGISTER_OP_IPU_KERNEL_FUNCTOR(memcpy_h2d,
float,
ops::MemcpyH2DKernel,
double,
ops::MemcpyH2DKernel,
int8_t,
ops::MemcpyH2DKernel,
uint8_t,
ops::MemcpyH2DKernel,
int,
ops::MemcpyH2DKernel,
int64_t,
ops::MemcpyH2DKernel,
bool,
ops::MemcpyH2DKernel,
paddle::platform::bfloat16,
ops::MemcpyH2DKernel,
paddle::platform::complex<float>,
ops::MemcpyH2DKernel,
paddle::platform::complex<double>,
ops::MemcpyH2DKernel,
plat::float16,
ops::MemcpyH2DKernel,
int16_t,
ops::MemcpyH2DKernel);
#endif
......@@ -50,7 +50,7 @@ class MemcpyH2DFunctor {
lod_tensor.dtype(),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
if (dst_place_type_ == 0 || dst_place_type_ == 1 || dst_place_type_ == 2) {
if (dst_place_type_ >= 0 && dst_place_type_ <= 3) {
framework::TensorCopy(
lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor);
} else {
......
......@@ -64,7 +64,7 @@ class DeviceEvent {
"Required type < %d, but received type = %d",
MaxDeviceTypes,
type_id_));
// TODO(Aurelius84): only support CPU/CUDA, need consider XPU/NPU later
// TODO(Aurelius84): only support CPU/CUDA/XPU/NPU.
PADDLE_ENFORCE_LT(type_id_,
4,
platform::errors::Unavailable(
......
......@@ -1388,8 +1388,8 @@ class Executor(object):
program = pruned_program
def _can_use_interpreter_core(program, place):
if core.is_compiled_with_mlu() or core.is_compiled_with_ipu(
) or isinstance(place, core.CustomPlace):
if core.is_compiled_with_mlu() or isinstance(
place, core.CustomPlace):
return False
compiled = isinstance(program, compiler.CompiledProgram)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册