Unverified commit b2704837, authored by 王明冬, committed by GitHub

add xpu support for new static alone executor. test=develop (#43076)

Parent: 0d17c047
......@@ -298,6 +298,15 @@ elseif(WITH_ROCM)
data_type_transform_test
SRCS data_type_transform_test.cc data_type_transform_test.cu
DEPS data_type_transform)
elseif(WITH_XPU)
cc_library(
data_type_transform
SRCS data_type_transform.cc
DEPS tensor xpulib)
cc_test(
data_type_transform_test
SRCS data_type_transform_test.cc
DEPS data_type_transform)
else()
cc_library(
data_type_transform
......
......@@ -277,13 +277,33 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
// 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}};
VariableNameMap out_name_map = {{"Out", {*new_var_name}}};
int dst_place_type = platform::is_cpu_place(dst_place) ? 0
: platform::is_gpu_place(dst_place) ? 1
: -1;
AttributeMap attr_map = {{"dst_place_type", dst_place_type}};
// 3. Create memcpy_d2h_op or memcpy_h2d_op
std::string op_type = get_memcpy_type(src_place, dst_place);
std::string op_type;
AttributeMap attr_map;
PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false,
platform::errors::PreconditionNotMet(
"Required src_place shall be different with dst_place, "
"but received same place: %s",
src_place));
if (IsSupportedHetePlace(dst_place)) {
op_type = kMemcpyH2D;
int dst_place_type = platform::is_gpu_place(dst_place) ? 0
: platform::is_npu_place(dst_place) ? 1
: platform::is_xpu_place(dst_place) ? 2
: -1;
attr_map = {{"dst_place_type", dst_place_type}};
} else if (IsSupportedHetePlace(src_place)) {
op_type = kMemcpyD2H;
int dst_place_type = platform::is_cpu_place(dst_place) ? 0
: platform::is_cuda_pinned_place(dst_place) ? 1
: -1;
attr_map = {{"dst_place_type", dst_place_type}};
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Not support Memcpy typ : %s -> %s", src_place, dst_place));
}
auto& op_info = OpInfoMap::Instance().Get(op_type);
auto op = std::shared_ptr<OperatorBase>(
op_info.Creator()(op_type, in_name_map, out_name_map, attr_map));
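The hunk above routes every cross-device copy through either memcpy_h2d or memcpy_d2h and encodes the destination in the dst_place_type attribute (H2D: GPU=0, NPU=1, XPU=2; D2H: CPU=0, CUDAPinned=1). The following standalone sketch is illustration only, not part of the commit; Place, MemcpyChoice and ChooseMemcpy are made-up names that merely mirror the branch logic of TransferDevice:

#include <stdexcept>
#include <string>

enum class Place { kCPU, kCUDAPinned, kGPU, kNPU, kXPU };

struct MemcpyChoice {
  std::string op_type;  // "memcpy_h2d" or "memcpy_d2h"
  int dst_place_type;   // value stored in the op's attribute map
};

// Mirrors TransferDevice: heterogeneous destination -> memcpy_h2d,
// heterogeneous source -> memcpy_d2h, anything else is rejected.
inline MemcpyChoice ChooseMemcpy(Place src, Place dst) {
  auto is_hete = [](Place p) { return p == Place::kGPU || p == Place::kXPU; };
  if (is_hete(dst)) {
    int t = dst == Place::kGPU   ? 0
            : dst == Place::kNPU ? 1
            : dst == Place::kXPU ? 2
                                 : -1;
    return {"memcpy_h2d", t};
  }
  if (is_hete(src)) {
    int t = dst == Place::kCPU ? 0 : dst == Place::kCUDAPinned ? 1 : -1;
    return {"memcpy_d2h", t};
  }
  throw std::runtime_error("unsupported memcpy direction");
}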
......@@ -434,31 +454,13 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
if (transfered) {
// NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
// with instruction. (hot fix, it is not good design here)
op_func_node->operator_base_ =
std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
op_base->Type(), new_ins, new_outs, op_base->Attrs()));
// with instruction.
op_base->Inputs() = new_ins;
op_base->Outputs() = new_outs;
}
op_func_node->no_data_transform_index = std::move(no_data_transform_index);
}
std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place) {
PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false,
platform::errors::PreconditionNotMet(
"Required src_place shall be different with dst_place, "
"but received same place: %s",
src_place));
if (platform::is_gpu_place(dst_place)) {
return kMemcpyH2D;
} else if (platform::is_gpu_place(src_place)) {
return kMemcpyD2H;
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Not support Memcpy typ : %s -> %s", src_place, dst_place));
}
}
void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
const platform::Place& place,
const VariableNameMap& out_names,
......
......@@ -68,9 +68,6 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope);
std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place);
inline bool need_device_transform(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
auto& src_place = kernel_type_for_var.place_;
......
......@@ -348,7 +348,7 @@ void deal_operator_base(const platform::Place& place,
auto* dev_ctx = pool.Get(place);
// input, output is prepared. set the other attributes.
op_func_node->operator_base_ = op_base;
if (platform::is_gpu_place(place)) {
if (IsSupportedHetePlace(place)) {
op_func_node->type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(place)) {
op_func_node->type_ = OpFuncType::kQueueSync;
......@@ -379,7 +379,6 @@ void build_op_func_list(const platform::Place& place,
bool use_local_scope) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<std::unique_ptr<OperatorBase>>
ops_unique; // its elements will be moved to vec_func_list
// Step 1: create all ops for current block.
......@@ -429,7 +428,7 @@ void build_op_func_list(const platform::Place& place,
std::tie(outs_map, outs_name2id) =
build_variable_map(outputs_names, var_scope, enforce_exist);
// step 2: build OpFuncNode
// step 1: build OpFuncNode
OpFuncNode op_func_node;
op_func_node.operator_base_ = ops[i];
op_func_node.input_index = ins_name2id;
......@@ -449,11 +448,7 @@ void build_op_func_list(const platform::Place& place,
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
Scope scope;
Scope* runtime_scope = &scope;
Scope scope, *runtime_scope = &scope;
// NOTE(Ruibiao): We do not encourage directly using scope in OP kernel.
// But some OPs do have such behavior (e.g., cinn_launch OP). Here special
// treatment for them.
......@@ -465,63 +460,17 @@ void build_op_func_list(const platform::Place& place,
runtime_scope = local_scope;
}
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context));
op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key));
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope,
*dev_ctx, runtime_context);
auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx);
// change device by the device_guard()
apply_device_guard(op, place, &expected_kernel_key);
VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
// step 3. apply data transforms and insert data transfer ops
VariableValueMap& ins_map_temp = runtime_context.inputs;
VariableValueMap& outs_map_temp = runtime_context.outputs;
// NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
// ApplyDataTransform
ApplyDataTransform(expected_kernel_key,
place,
&ins_map_temp,
&outs_map_temp,
var_scope,
&op_func_node,
vec_func_list,
use_local_scope);
op_with_kernel = const_cast<framework::OperatorWithKernel*>(
static_cast<const framework::OperatorWithKernel*>(
op_func_node.operator_base_.get()));
// step 4. Run op kernel
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
if (platform::is_gpu_place(expected_kernel_key.place_)) {
op_func_node.type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(expected_kernel_key.place_)) {
op_func_node.type_ = OpFuncType::kQueueSync;
} else {
PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
expected_kernel_key.place_));
}
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_);
}
op_func_node.dev_ctx_ = dev_ctx;
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
// see OperatorWithKernel::RunImpl in operator.cc for why
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inheritted from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// step 2. select op kernel
auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_with_kernel->Type())) {
......@@ -531,10 +480,7 @@ void build_op_func_list(const platform::Place& place,
if (op_with_kernel->PhiKernel()->IsValid()) {
run_phi_kernel = true;
} else {
auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
if (kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end()) {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
auto pt_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, pt_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel(
......@@ -545,55 +491,76 @@ void build_op_func_list(const platform::Place& place,
<< pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *(op_with_kernel->PhiKernel());
op_with_kernel->ResetKernelType(new OpKernelType(
TransPhiKernelKeyToOpKernelType(pt_cpu_kernel_key)));
run_phi_kernel = true;
}
}
}
}
if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx);
op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
} else {
op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
}
auto kernel_type = *(op_with_kernel->kernel_type());
if (kernel_type.place_ != dev_ctx->GetPlace()) {
dev_ctx = pool.Get(kernel_type.place_);
}
op_func_node.dev_ctx_ = dev_ctx;
if (IsSupportedHetePlace(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueAsync;
} else if (platform::is_cpu_place(kernel_type.place_)) {
op_func_node.type_ = OpFuncType::kQueueSync;
} else {
PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
kernel_type.place_));
}
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
<< " : finally selected kernel_key: " << kernel_type;
// step 3. data transform
VariableValueMap& ins_map_temp = runtime_context.inputs;
VariableValueMap& outs_map_temp = runtime_context.outputs;
ApplyDataTransform(kernel_type, place, &ins_map_temp, &outs_map_temp,
var_scope, &op_func_node, vec_func_list,
use_local_scope);
// step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc for
// why.
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
// step 5. run kernel
if (run_phi_kernel) {
phi::KernelContext pt_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
&pt_kernel_context);
(*op_func_node.pt_kernel_)(&pt_kernel_context);
} else {
auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE(
kernels_iter,
all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
op->Type()));
OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE(kernel_iter,
kernels.end(),
platform::errors::NotFound(
"Operator (%s) does not have kernel for %s.",
op->Type(),
KernelTypeToString(expected_kernel_key)));
// TODO(zhiqiu): add fallback logic
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx);
// the place of exec_ctx may have changed.
op_func_node.kernel_func_(ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context));
}
// post-process grad_op.outputs if need cast complex grad into real grad.
// post-process grad_op.outputs if need cast complex grad into real
// grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (framework::IsComplexType(expected_kernel_key.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node,
place,
outputs_names,
&runtime_context.outputs,
var_scope,
vec_func_list,
local_scope);
if (framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(
op_func_node, place, outputs_names, &runtime_context.outputs,
var_scope, vec_func_list, local_scope);
}
if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map;
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in
// operator.cc
for (auto& p : m) {
auto* transformed_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(
......@@ -899,17 +866,17 @@ std::map<int, std::list<int>> build_op_downstream_map(
// step2: update 2 var2xxxx data structure
for (auto& item :
vec_instruction[op_idx].Inputs()) { // for all inputs(read only)
vec_instruction[op_idx].Outputs()) { // for all write vars
for (auto var : item.second) {
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
var2recent_write_op[var] = op_idx;
var2min_rw_op[var] = {static_cast<int>(op_idx)};
remove_duplicate.insert(var);
}
}
for (auto& item :
vec_instruction[op_idx].Outputs()) { // for all write vars
vec_instruction[op_idx].Inputs()) { // for all inputs(read only)
for (auto var : item.second) {
var2recent_write_op[var] = op_idx;
if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
......
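The reordered kernel-selection step above (step 2 of build_op_func_list) prefers a valid phi kernel, falls back to a CPU phi kernel when SupportsKernelType reports no matching fluid kernel for the expected key, and otherwise chooses a fluid kernel; only afterwards are the device context and queue type derived from the finally selected kernel key. A condensed sketch of that control flow, with stand-in types and trivial bodies in place of the real Paddle classes:

struct KernelKey { bool is_cpu = true; };  // stand-in for OpKernelType

struct OpStub {
  bool has_phi_kernel = false;          // ~ HasCompatiblePhiKernel(type)
  bool phi_kernel_valid = false;        // ~ PhiKernel()->IsValid()
  bool fluid_kernel_registered = true;  // ~ SupportsKernelType(expected)
  void FallBackToCpuPhiKernel() { phi_kernel_valid = true; }  // ~ ResetPhiKernel(FallBackToCpu(...))
  void ChooseFluidKernel(const KernelKey&) {}                 // ~ ChooseKernel(exec_ctx)
};

// Returns run_phi_kernel, mirroring the branches in build_op_func_list.
inline bool SelectKernel(OpStub* op, const KernelKey& expected) {
  bool run_phi_kernel = false;
  if (op->has_phi_kernel) {
    if (op->phi_kernel_valid) {
      run_phi_kernel = true;
    } else if (!op->fluid_kernel_registered) {
      op->FallBackToCpuPhiKernel();  // re-select the phi kernel on CPUPlace
      run_phi_kernel = op->phi_kernel_valid;
    }
  }
  if (!run_phi_kernel) op->ChooseFluidKernel(expected);
  return run_phi_kernel;
}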
......@@ -297,7 +297,7 @@ struct InstructionInfo {
enum class OpFuncType {
kQueueSync = 0, // CPU kernel, block host
kQueueAsync = 1, // GPU Kernel or d2h, h2d, send, recv, broadcast
kQueueAsync = 1, // GPU/XPU Kernel or d2h, h2d, send, recv, broadcast
};
class RuntimeInferShapeContext;
......@@ -417,6 +417,11 @@ static bool IsCpuOp(const Instruction& instr) {
return platform::is_cpu_place(instr.DeviceContext().GetPlace());
}
// Returns true if the place is a supported heterogeneous place (GPU or XPU).
static bool IsSupportedHetePlace(const phi::Place& place) {
return platform::is_gpu_place(place) || platform::is_xpu_place(place);
}
} // namespace interpreter
} // namespace framework
......
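IsSupportedHetePlace ties the hunks above together: GPU and XPU kernels are queued asynchronously, CPU kernels synchronously, as in deal_operator_base and build_op_func_list. A tiny sketch of that mapping with stand-in enums (only the helper name matches the code above):

#include <stdexcept>

enum class Place { kCPU, kGPU, kXPU };
enum class OpFuncType { kQueueSync, kQueueAsync };

inline bool IsSupportedHetePlace(Place p) {
  return p == Place::kGPU || p == Place::kXPU;
}

// Queue-type assignment used when building OpFuncNodes.
inline OpFuncType QueueTypeFor(Place kernel_place) {
  if (IsSupportedHetePlace(kernel_place)) return OpFuncType::kQueueAsync;
  if (kernel_place == Place::kCPU) return OpFuncType::kQueueSync;
  throw std::runtime_error("unsupported place");  // ~ PADDLE_THROW(Fatal)
}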
......@@ -155,20 +155,24 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctxs_[place_].get().get();
} else if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctxs_[place_].get().get();
// Only GPU needs this update; XPU does not, because the XPU memcpy op kernel
// is synchronous.
if (platform::is_gpu_place(place_)) {
if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctxs_[place_].get().get();
} else if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctxs_[place_].get().get();
}
}
return dev_ctx;
}
/*
* NOTE(dev): The following cases are considered as directly run:
*
* 0. in XPU place, because the XPU memcpy op kernel is synchronous.
* 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
* 2. CPU -> any (it is possible: CPU op->VAR->GPU op, when var is no need
* buffer or no need data transform)
......@@ -177,7 +181,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
*/
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
return platform::is_xpu_place(place_) ||
(&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpreter::IsCpuOp(cur_instr) ||
interpreter::IsMemcpyD2H(cur_instr) ||
interpreter::IsMemcpyH2D(next_instr));
......@@ -187,6 +192,9 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
if (instr.KernelType() == OpFuncType::kQueueSync) {
return platform::kCPU;
} else {
if (platform::is_xpu_place(place_)) {
return platform::kXPU;
}
return platform::kCUDA;
}
}
......
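All three StreamAnalyzer changes above follow from the same fact noted in the comments, namely that the XPU memcpy op kernels are synchronous: XPU needs no separate d2h/h2d device contexts, every adjacent instruction pair on XPU is a direct run, and when an event wait is still needed the waiter is kXPU. A compact sketch of those decisions with stand-in enums (illustration only):

enum class Place { kCPU, kGPU, kXPU };
enum class OpFuncType { kQueueSync, kQueueAsync };
enum class DeviceType { kCPU, kCUDA, kXPU };

// ParseDeviceContext: only GPU swaps in the d2h/h2d context pools.
inline bool UseMemcpyContextPool(Place place) { return place == Place::kGPU; }

// IsDirectRun: on XPU every adjacent pair runs directly, with no event wait.
inline bool AlwaysDirectRun(Place place) { return place == Place::kXPU; }

// GetWaiterType: sync ops wait on the CPU; async ops wait on their device.
inline DeviceType WaiterType(OpFuncType type, Place place) {
  if (type == OpFuncType::kQueueSync) return DeviceType::kCPU;
  return place == Place::kXPU ? DeviceType::kXPU : DeviceType::kCUDA;
}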
......@@ -1296,6 +1296,23 @@ bool OperatorWithKernel::SupportsMKLDNN(
});
}
bool OperatorWithKernel::SupportsKernelType(
const OpKernelType& kernel_type) const {
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
bool support =
kernels_iter != all_op_kernels.end() &&
kernels_iter->second.find(kernel_type) != kernels_iter->second.end();
#if defined(PADDLE_WITH_XPU)
if (paddle::platform::is_xpu_place(kernel_type.place_)) {
support = support &&
paddle::platform::is_xpu_support_op(type_, kernel_type) &&
!paddle::platform::is_in_xpu_black_list(type_);
}
#endif
return support;
}
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
const auto& attrs_map = ctx.Attrs();
......
......@@ -193,6 +193,8 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
VariableNameMap& Inputs() { return inputs_; }
VariableNameMap& Outputs() { return outputs_; }
const OpInfo& Info() const {
PADDLE_ENFORCE_NOT_NULL(
......@@ -579,6 +581,8 @@ class OperatorWithKernel : public OperatorBase {
}
bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
......@@ -621,6 +625,7 @@ class OperatorWithKernel : public OperatorBase {
/* member functions for adapting to phi lib */
phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const;
void ChooseKernel(const ExecutionContext& ctx) const;
/**
* Transfer data place for phi kernel
* Is this really needed?
......@@ -644,6 +649,7 @@ class OperatorWithKernel : public OperatorBase {
}
const OpKernelType* kernel_type() const { return kernel_type_.get(); }
const OpKernelFunc* kernel_func() const { return kernel_func_.get(); }
void ResetKernelType(OpKernelType* kernel_type) {
kernel_type_.reset(kernel_type);
......@@ -672,8 +678,6 @@ class OperatorWithKernel : public OperatorBase {
OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const;
void ChooseKernel(const ExecutionContext& ctx) const;
void HandleComplexGradToRealGrad(const Scope& scope,
RuntimeContext* ctx) const;
......
......@@ -704,7 +704,8 @@ class AllocatorFacadePrivate {
if (platform::is_gpu_place(place)) {
std::shared_ptr<StreamSafeCUDAAllocator>&& allocator =
std::make_shared<StreamSafeCUDAAllocator>(
pair.second, place, /* default_stream = */ nullptr,
pair.second, place,
/* default_stream = */ nullptr,
/* in_cuda_graph_capturing = */ !allow_free_idle_chunk_);
pair.second = allocator;
......@@ -1044,8 +1045,11 @@ AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, size_t size,
} else {
return m->GetAllocator(p, size)->Allocate(size);
}
#elif defined PADDLE_WITH_XPU
return GetAllocator(place)->Allocate(size);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet("Not compiled with GPU."));
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with GPU or XPU."));
#endif
}
......
......@@ -174,6 +174,9 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
if (WITH_GPU OR WITH_ROCM)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
endif()
if(WITH_XPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xpulib)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter)
......
......@@ -95,7 +95,7 @@ class MemcpyD2HOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>(
"dst_place_type",
"Determine the dst place of tensor copy. "
"By Now it ONLY support NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU"
"By Now it ONLY support XPU/NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU"
"Other place type is Unimplemented and will cause ERROR."
"0: dst is on CPUPlace. "
"1: dst is on CUDAPinnedPlace. ");
......@@ -140,6 +140,17 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(
ops::MemcpyD2HKernel, int16_t, ops::MemcpyD2HKernel);
#endif
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL_FUNCTOR(
memcpy_d2h, float, ops::MemcpyD2HKernel, double, ops::MemcpyD2HKernel,
int8_t, ops::MemcpyD2HKernel, uint8_t, ops::MemcpyD2HKernel, int,
ops::MemcpyD2HKernel, int64_t, ops::MemcpyD2HKernel, bool,
ops::MemcpyD2HKernel, paddle::platform::bfloat16, ops::MemcpyD2HKernel,
paddle::platform::complex<float>, ops::MemcpyD2HKernel,
paddle::platform::complex<double>, ops::MemcpyD2HKernel, plat::float16,
ops::MemcpyD2HKernel, int16_t, ops::MemcpyD2HKernel);
#endif
#ifdef PADDLE_WITH_ASCEND_CL
REGISTER_OP_NPU_KERNEL_FUNCTOR(
memcpy_d2h, float, ops::MemcpyD2HKernel, double, ops::MemcpyD2HKernel,
......
......@@ -98,7 +98,8 @@ class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"By Now it ONLY support CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace "
"Other place type is Unimplemented and will cause ERROR."
"0: dst is on CUDAPlace. "
"1: dst is on NPUPlace. ");
"1: dst is on NPUPlace. "
"2: dst is on XPUPlace. ");
AddComment(R"DOC(
MemcpyD2H Operator.
By now, it ONLY supports the memcpy between CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace.
......@@ -140,6 +141,17 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(
ops::MemcpyH2DKernel, int16_t, ops::MemcpyH2DKernel);
#endif
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL_FUNCTOR(
memcpy_h2d, float, ops::MemcpyH2DKernel, double, ops::MemcpyH2DKernel,
int8_t, ops::MemcpyH2DKernel, uint8_t, ops::MemcpyH2DKernel, int,
ops::MemcpyH2DKernel, int64_t, ops::MemcpyH2DKernel, bool,
ops::MemcpyH2DKernel, paddle::platform::bfloat16, ops::MemcpyH2DKernel,
paddle::platform::complex<float>, ops::MemcpyH2DKernel,
paddle::platform::complex<double>, ops::MemcpyH2DKernel, plat::float16,
ops::MemcpyH2DKernel, int16_t, ops::MemcpyH2DKernel);
#endif
#ifdef PADDLE_WITH_ASCEND_CL
REGISTER_OP_NPU_KERNEL_FUNCTOR(
memcpy_h2d, float, ops::MemcpyH2DKernel, double, ops::MemcpyH2DKernel,
......
......@@ -49,7 +49,7 @@ class MemcpyH2DFunctor {
dev_ctx_.GetPlace(), lod_tensor.dtype(),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
if (dst_place_type_ == 0 || dst_place_type_ == 1) {
if (dst_place_type_ == 0 || dst_place_type_ == 1 || dst_place_type_ == 2) {
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
&out_tensor);
} else {
......
......@@ -270,6 +270,15 @@ cc_library(
set(DEVICE_EVENT_LIBS
device_event_base
CACHE INTERNAL "device event libs")
if(WITH_XPU)
cc_library(
device_event_xpu
SRCS device_event_xpu.cc
DEPS device_event_base xpu_info)
set(DEVICE_EVENT_LIBS
device_event_xpu
CACHE INTERNAL "device event libs")
endif()
if(WITH_GPU)
nv_library(
......
......@@ -25,6 +25,7 @@
using ::paddle::platform::kCPU;
using ::paddle::platform::kCUDA;
using ::paddle::platform::kXPU;
USE_EVENT(kCPU)
USE_EVENT_WAIT(kCPU, kCPU)
......@@ -34,3 +35,9 @@ USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
#endif
#ifdef PADDLE_WITH_XPU
USE_EVENT(kXPU);
USE_EVENT_WAIT(kXPU, kXPU)
USE_EVENT_WAIT(kCPU, kXPU)
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device_event_base.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace platform {
struct XPUDeviceEventWrapper {
explicit XPUDeviceEventWrapper(const platform::Place& place) {
PADDLE_ENFORCE_EQ(
platform::is_xpu_place(place), true,
platform::errors::PreconditionNotMet(
"Required device shall be XPUPlace, but received %d. ", place));
device_id_ = place.device;
PADDLE_ENFORCE_GT(
device_id_, -1,
platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ",
device_id_));
xpu_event_create(&handle_);
}
xpuEventHandle handle_;
int device_id_;
};
void DeviceEventCreateXPU(DeviceEvent* event, const platform::Place& place,
unsigned int) {
event->InitEvent(std::make_shared<XPUDeviceEventWrapper>(place));
}
void DeviceEventRecordXPU(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<XPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper, platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into XPUDeviceEventWrapper."));
auto* xpu_dev_ctx = dynamic_cast<const platform::XPUDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
xpu_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into XPUDeviceContext."));
xpu_event_record(wrapper->handle_, xpu_dev_ctx->stream());
}
void DeviceEventFinishXPU(const DeviceEvent* event) {
auto* wrapper = static_cast<XPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper, platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into XPUDeviceEventWrapper."));
xpu_event_wait(wrapper->handle_);
}
// XPU does not currently support event query, so wait is used instead.
bool DeviceEventQueryXPU(const DeviceEvent* event) {
DeviceEventFinishXPU(event);
return true;
}
void DeviceEventXPUWaitXPU(const DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper = static_cast<XPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper, platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into XPUDeviceEventWrapper."));
auto* xpu_dev_ctx = dynamic_cast<const platform::XPUDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
xpu_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into XOUDeviceContext."));
xpu_stream_wait_event(xpu_dev_ctx->stream(), wrapper->handle_);
}
void DeviceEventCPUWaitXPU(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishXPU(event);
}
void DeviceEventSetFinishedXPU(const DeviceEvent* event) {
// do nothing
}
void EventResetXPU(const DeviceEvent* event) {
// do nothing
}
} // namespace platform
} // namespace paddle
using ::paddle::platform::kCPU;
using ::paddle::platform::kXPU;
REGISTER_EVENT_CREATE_FUNCTION(kXPU, paddle::platform::DeviceEventCreateXPU)
REGISTER_EVENT_RECORD_FUNCTION(kXPU, paddle::platform::DeviceEventRecordXPU)
REGISTER_EVENT_QUERY_FUNCTION(kXPU, paddle::platform::DeviceEventQueryXPU)
REGISTER_EVENT_FINISH_FUNCTION(kXPU, paddle::platform::DeviceEventFinishXPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
kXPU, paddle::platform::DeviceEventSetFinishedXPU)
REGISTER_EVENT_WAIT_FUNCTION(kXPU, kXPU,
paddle::platform::DeviceEventXPUWaitXPU)
REGISTER_EVENT_WAIT_FUNCTION(kCPU, kXPU,
paddle::platform::DeviceEventCPUWaitXPU)
REGISTER_EVENT_RESET_FUNCTION(kXPU, paddle::platform::EventResetXPU)
#endif
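For orientation, the new device_event_xpu.cc above wraps the raw XPU runtime event API. The sketch below shows the record/wait lifecycle those callbacks implement, using only the calls that appear in this file; the xpu/runtime.h include path and the templated stream parameter are assumptions, since the real code obtains the stream from XPUDeviceContext::stream():

#include "xpu/runtime.h"  // assumed header declaring xpuEventHandle and the xpu_event_* calls

// producer_stream: stream on which the event is recorded (e.g. a memcpy_h2d).
// consumer_stream: stream that must not proceed until the event completes.
template <typename XPUStream>
void RecordAndWait(XPUStream producer_stream, XPUStream consumer_stream) {
  xpuEventHandle event;
  xpu_event_create(&event);                       // DeviceEventCreateXPU
  xpu_event_record(event, producer_stream);       // DeviceEventRecordXPU
  xpu_stream_wait_event(consumer_stream, event);  // DeviceEventXPUWaitXPU (device-side wait)
  xpu_event_wait(event);                          // DeviceEventFinishXPU (host-side wait)
}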
......@@ -1392,9 +1392,9 @@ class Executor(object):
program = pruned_program
def _can_use_interpreter_core(program, place):
if core.is_compiled_with_npu() or core.is_compiled_with_xpu(
) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu(
) or isinstance(place, core.CustomPlace):
if core.is_compiled_with_npu() or core.is_compiled_with_mlu(
) or core.is_compiled_with_ipu() or isinstance(
place, core.CustomPlace):
return False
compiled = isinstance(program, compiler.CompiledProgram)
......