Unverified commit bf50784c, authored by Ruibiao Chen and committed by GitHub

New executor static build for fluid kernel (#50670)

* Check structured kernel for new executor static build

* Update code

* Ready for resnet50

* Move transfer_dtype to phi

* Ready for transformer

* Fix CI errors

* Fix layer_norm InferMeta

* Remove layer_norm infermeta fix
Parent 819f8939
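At a high level, this commit gates the new executor's static build on two conditions: the exported flag FLAGS_new_executor_static_build (moved into interpreter_core.cc below) and a per-block capability check, interpreter::BlockCanBeStaticBuilt. When both hold, BuildOpFuncList fake-initializes kernel outputs instead of running the kernels. Below is a simplified, self-contained C++ sketch of that control flow, with stub functions standing in for the real Paddle components; it is an illustration reconstructed from this diff, not the actual implementation.

    // Simplified sketch of the control flow introduced by this commit
    // (stubs only; names follow the diff, behavior is illustrative).
    #include <iostream>
    #include <string>
    #include <vector>

    namespace sketch {

    bool FLAGS_new_executor_static_build = false;  // exported gflag in the real code

    // Stub: the real check inspects every op's kernel registration in the block.
    bool BlockCanBeStaticBuilt(const std::vector<std::string>& block_ops) {
      return !block_ops.empty();
    }

    // Stub for BuildOpFuncList: either runs kernels or only fake-allocates outputs.
    void BuildOpFuncList(const std::vector<std::string>& block_ops, bool static_build) {
      for (const auto& op : block_ops) {
        if (static_build) {
          std::cout << "fake-initialize outputs of " << op << "\n";
        } else {
          std::cout << "run kernel of " << op << "\n";
        }
      }
    }

    }  // namespace sketch

    int main() {
      sketch::FLAGS_new_executor_static_build = true;
      std::vector<std::string> block_ops = {"matmul", "relu", "softmax"};
      // Mirrors InterpreterCore: static_build_ = flag && BlockCanBeStaticBuilt(block).
      bool static_build = sketch::FLAGS_new_executor_static_build &&
                          sketch::BlockCanBeStaticBuilt(block_ops);
      sketch::BuildOpFuncList(block_ops, static_build);
      return 0;
    }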
...@@ -176,26 +176,47 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -176,26 +176,47 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)}; new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)};
new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)}; new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)};
new_op_func_node.dev_ctx_ = dev_ctx;
new_op_func_node.operator_base_ = op;
const phi::Place& place = dev_ctx->GetPlace();
if (platform::is_cpu_place(place)) {
new_op_func_node.type_ = OpFuncType::kCpuSync;
} else if (platform::is_gpu_place(place)) {
// MemcpyD2H in gpu is synchronous, see
// https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-async
// for more detail.
new_op_func_node.type_ =
(op_type == kMemcpyD2H ? OpFuncType::kGpuSync : OpFuncType::kGpuAsync);
} else if (platform::is_xpu_place(place)) {
// Memcpy in xpu is synchronous
new_op_func_node.type_ = OpFuncType::kGpuSync;
} else {
// Memcpy in npu and custom devices is asynchronous
new_op_func_node.type_ = OpFuncType::kGpuAsync;
}
if (!run_phi_kernel) { if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx); op_with_kernel->ChooseKernel(exec_ctx);
new_op_func_node.kernel_func_ = *op_with_kernel->kernel_func(); new_op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
new_op_func_node.kernel_func_(exec_ctx); new_op_func_node.kernel_func_(exec_ctx);
} else { } else {
new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel(); new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
if (skip_run) {
FakeInitializeOutputsForFunctionKernel(
*(new_op_func_node.phi_kernel_),
*(op_with_kernel->PhiKernelSignature()),
runtime_context,
*dev_ctx);
} else {
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext( op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context); runtime_context, dev_ctx, &phi_kernel_context);
if (!skip_run) {
(*new_op_func_node.phi_kernel_)(&phi_kernel_context); (*new_op_func_node.phi_kernel_)(&phi_kernel_context);
} else {
FakeInitializeOutputs(new_op_func_node.phi_kernel_,
op_with_kernel->PhiKernelSignature(),
&phi_kernel_context);
} }
} }
const phi::Place& place = dev_ctx->GetPlace();
// NOTE(winter-wang): in npu and custom device, D2H kernel is asynchronous. // NOTE(winter-wang): in npu and custom device, D2H kernel is asynchronous.
// need to explicit synchronization. // need to explicit synchronization.
if ((platform::is_npu_place(place) || platform::is_custom_place(place)) && if ((platform::is_npu_place(place) || platform::is_custom_place(place)) &&
...@@ -203,23 +224,6 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -203,23 +224,6 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
dev_ctx->Wait(); dev_ctx->Wait();
} }
if (platform::is_cpu_place(place)) {
new_op_func_node.type_ = OpFuncType::kCpuSync;
} else if (platform::is_gpu_place(place)) {
// MemcpyD2H in gpu is synchronous, see
// https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-async
// for more detial.
new_op_func_node.type_ =
(op_type == kMemcpyD2H ? OpFuncType::kGpuSync : OpFuncType::kGpuAsync);
} else if (platform::is_xpu_place(place)) {
// Memcpy in xpu is synchronous
new_op_func_node.type_ = OpFuncType::kGpuSync;
} else {
// Memcpy in npu and custom devices is asynchronous
new_op_func_node.type_ = OpFuncType::kGpuAsync;
}
new_op_func_node.dev_ctx_ = dev_ctx;
new_op_func_node.operator_base_ = op;
VLOG(3) << "Run " << op_type << " done."; VLOG(3) << "Run " << op_type << " done.";
new_op_func_nodes->emplace_back(std::move(new_op_func_node)); new_op_func_nodes->emplace_back(std::move(new_op_func_node));
......
...@@ -42,18 +42,18 @@ class DataTranferHelper { ...@@ -42,18 +42,18 @@ class DataTranferHelper {
std::vector<OpFuncNode>* new_op_func_nodes, std::vector<OpFuncNode>* new_op_func_nodes,
bool use_local_scope, bool use_local_scope,
bool is_fetch_v2, bool is_fetch_v2,
bool skip_run = false); bool static_build = false);
void RunAndConstructShareNode(const std::string& src_var_name, void RunAndConstructShareNode(const std::string& src_var_name,
const std::string& dst_var_name, const std::string& dst_var_name,
std::vector<OpFuncNode>* op_func_nodes, std::vector<OpFuncNode>* op_func_nodes,
bool skip_run = false); bool static_build = false);
void RunAndConstructOpFuncNode(const std::shared_ptr<OperatorBase>& op, void RunAndConstructOpFuncNode(const std::shared_ptr<OperatorBase>& op,
const std::string& var_name, const std::string& var_name,
const std::string& new_var_name, const std::string& new_var_name,
std::vector<OpFuncNode>* op_func_nodes, std::vector<OpFuncNode>* op_func_nodes,
bool skip_run = false); bool static_build = false);
private: private:
platform::Place place_; platform::Place place_;
...@@ -69,7 +69,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -69,7 +69,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
OpFuncNode* op_func_node, OpFuncNode* op_func_node,
std::vector<OpFuncNode>* op_func_nodes, std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope = true, bool use_local_scope = true,
bool skip_run = false); bool static_build = false);
void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
const platform::Place& place, const platform::Place& place,
...@@ -78,7 +78,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, ...@@ -78,7 +78,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
VariableScope* var_scope, VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes, std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope, framework::Scope* local_scope,
bool skip_run = false); bool static_build = false);
inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var, inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
const phi::DenseTensor* tensor, const phi::DenseTensor* tensor,
......
...@@ -38,11 +38,6 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -38,11 +38,6 @@ PADDLE_DEFINE_EXPORTED_bool(
false, false,
"Log memory stats after each op runs, just used for debug."); "Log memory stats after each op runs, just used for debug.");
PADDLE_DEFINE_EXPORTED_bool(
new_executor_static_build,
false,
"Build the interpreterCore statically without running.");
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
...@@ -52,6 +47,94 @@ namespace interpreter { ...@@ -52,6 +47,94 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>; using VariableIdMap = std::map<std::string, std::vector<int>>;
// These Ops need to set their output dtype when registering the phi kernel, but they didn't
static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"abs",
"accuracy",
"adam",
"adamw",
"all_close",
"all_raw",
"angle",
"any_raw",
"arg_sort",
"argmax",
"argmin",
"as_real",
"atan2",
"auc",
"bincount",
"clip_by_norm",
"complex",
"conv3d_coo",
"distribute_fpn_proposals",
"edit_distance",
"eig",
"eig_grad",
"eigh",
"eigvals",
"ftt_c2r",
"ftt_r2c",
"fused_adam",
"fused_matmul",
"generate_proposals",
"graph_sample_neighbors",
"group_norm",
"histogram",
"instance_norm",
"is_empty",
"is_finite",
"kthvalue",
"lamb",
"layer_norm",
"layer_norm_grad",
"less_equal",
"less_than",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"lstsq",
"lu",
"matrix_nms",
"matrix_rank_tol",
"merged_adam",
"mode",
"momentum",
"multiclass_nms3",
"multinomial",
"nanmedian",
"nms",
"nonzero",
"numl",
"qr",
"qr_grad",
"rnn",
"roi_pool",
"search_sort",
"select",
"send_recv",
"send_ue_recv",
"sgd",
"svd",
"sync_batch_norm_grad",
"top_k",
"unique",
"unique_consecutive_flattened_tensor",
"unique_raw",
"viterbi_decode",
"viterbi_devode",
"yolo_loss"};
// These Ops can use InferMeta to infer the output dtype
static std::set<std::string> OpsWithAvailablePhiInferMeta = {
"abs", "adam", "adamw", "layer_norm", "layer_norm_grad", "merged_adam"};
// Cannot statically analyze these Ops' output dtype or backend because their
// kernels have not been moved to PHI yet.
static std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
"fused_batch_norm_act", "fused_batch_norm_act_grad"};
// NOTE(Ruibiao): SingleStreamGuard make some multi-strem op (i.e., // NOTE(Ruibiao): SingleStreamGuard make some multi-strem op (i.e.,
// c_allreduce_sum) run in single stream. It is dedicated to BuildOpFuncList // c_allreduce_sum) run in single stream. It is dedicated to BuildOpFuncList
// which run kernel without stream synchronization. // which run kernel without stream synchronization.
...@@ -121,6 +204,48 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, ...@@ -121,6 +204,48 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
queue_group_->AddTask(op_func_type == OpFuncType::kGpuAsync, std::move(fn)); queue_group_->AddTask(op_func_type == OpFuncType::kGpuAsync, std::move(fn));
} }
bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
// has_fluid_kernel = (kernel_code >> 3) & 1
// has_structured_kernel = (kernel_code >> 2) & 1
// need_move_to_phi = (kernel_code >> 1) & 1
// need_set_dtype = kernel_code & 1
using KernelCode = int8_t;
std::set<std::pair<std::string, KernelCode>> invalid_ops;
for (auto& op : block.AllOps()) {
auto op_type = op->Type();
bool has_fluid_kernel = OperatorWithKernel::AllOpKernels().count(op_type);
bool has_structured_kernel =
phi::KernelFactory::Instance().HasStructuredKernel(op_type);
bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel) &&
OpsWithFluidKernelNeedMoveToPhi.count(op_type);
bool need_set_dtype =
!has_fluid_kernel && !has_structured_kernel &&
OpsNeedSetOutputDtypeWhenRegisterPhiKernel.count(op_type) &&
!OpsWithAvailablePhiInferMeta.count(op_type);
KernelCode kernel_code = (has_fluid_kernel << 3) +
(has_structured_kernel << 2) +
(need_move_to_phi << 1) + need_set_dtype;
if (need_move_to_phi || need_set_dtype) {
invalid_ops.insert(std::make_pair(op_type, kernel_code));
}
}
if (!invalid_ops.empty()) {
std::stringstream ss;
ss << "The following OPs are unable to static build:\n";
for (auto& item : invalid_ops) {
ss << item.first << " [has_fluid_kernel = " << (item.second >> 3 & 1)
<< ", has_structed_kerenl = " << (item.second >> 2 & 1)
<< ", need_move_to_phi = " << (item.second >> 1 & 1)
<< ", need_set_dtype = " << (item.second & 1) << "]\n";
}
VLOG(0) << ss.str();
}
return invalid_ops.empty();
}
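For readers skimming the function above, the following is a minimal standalone C++ illustration of the KernelCode bit layout it uses (bit 3 = has_fluid_kernel, bit 2 = has_structured_kernel, bit 1 = need_move_to_phi, bit 0 = need_set_dtype). The encoding matches the code above; the program itself is only a sketch.

    // Standalone illustration of the KernelCode bit packing/unpacking.
    #include <cstdint>
    #include <iostream>

    int main() {
      using KernelCode = int8_t;
      bool has_fluid_kernel = true;
      bool has_structured_kernel = false;
      bool need_move_to_phi = true;
      bool need_set_dtype = false;

      // Pack the four flags into one byte, as BlockCanBeStaticBuilt does.
      KernelCode kernel_code = static_cast<KernelCode>(
          (has_fluid_kernel << 3) + (has_structured_kernel << 2) +
          (need_move_to_phi << 1) + need_set_dtype);  // here: 0b1010

      // Unpack the flags again, as the diagnostic log message does.
      std::cout << "has_fluid_kernel = " << ((kernel_code >> 3) & 1)
                << ", has_structured_kernel = " << ((kernel_code >> 2) & 1)
                << ", need_move_to_phi = " << ((kernel_code >> 1) & 1)
                << ", need_set_dtype = " << (kernel_code & 1) << std::endl;
      return 0;
    }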
bool IsCommunicationOp(const std::string& op_name) { bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = { const std::set<std::string> special_comm_op_set = {
"send", "send",
...@@ -144,6 +269,10 @@ bool IsCpuOp(const Instruction& instr) { ...@@ -144,6 +269,10 @@ bool IsCpuOp(const Instruction& instr) {
return platform::is_cpu_place(instr.DeviceContext().GetPlace()); return platform::is_cpu_place(instr.DeviceContext().GetPlace());
} }
bool IsGradOp(const std::string& op_name) {
return paddle::string::ends_with(op_name, "_grad");
}
bool IsSupportedHeterPlace(const phi::Place& place) { bool IsSupportedHeterPlace(const phi::Place& place) {
return platform::is_gpu_place(place) || platform::is_npu_place(place) || return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_xpu_place(place) || platform::is_ipu_place(place) || platform::is_xpu_place(place) || platform::is_ipu_place(place) ||
...@@ -162,33 +291,6 @@ bool IsMemcpyOp(const Instruction& instr) { ...@@ -162,33 +291,6 @@ bool IsMemcpyOp(const Instruction& instr) {
return IsMemcpyD2H(instr) || IsMemcpyH2D(instr); return IsMemcpyD2H(instr) || IsMemcpyH2D(instr);
} }
bool IsBlockContainsOnlyPhiKernel(const framework::BlockDesc& block) {
bool res = true;
for (auto& op : block.AllOps()) {
auto op_type = op->Type();
if (op_type == "feed" || op_type == "fetch_v2") {
continue;
}
auto has_phi_kernel =
!phi::KernelFactory::Instance()
.SelectKernelMap(phi::TransToPhiKernelName(op_type))
.empty();
if (!has_phi_kernel) {
auto kernel_iter = OperatorWithKernel::AllOpKernels().find(op_type);
if (kernel_iter != OperatorWithKernel::AllOpKernels().end()) {
VLOG(4) << op_type << " has no phi kernel, but has fluid kernel.";
res = false;
} else {
VLOG(4) << op_type << " has no phi kernel, and no fluid kernel.";
}
} else {
VLOG(4) << op_type << " has phi kernel";
}
}
return res;
}
void AddFetch(const std::vector<std::string>& fetch_names, void AddFetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block) { framework::BlockDesc* block) {
auto* fetch_holder = block->Var(kFetchVarName); auto* fetch_holder = block->Var(kFetchVarName);
...@@ -275,53 +377,6 @@ GetUnusedVars(const BlockDesc& block, ...@@ -275,53 +377,6 @@ GetUnusedVars(const BlockDesc& block,
return result; return result;
} }
void BuildVariableScope(const framework::BlockDesc& block,
const ExecutionConfig& execution_config,
VariableScope* var_scope) {
VLOG(3) << "Creating Variables";
auto inner_scope = var_scope->GetMutableScope();
// NOTE(zhiqiu): if create_local_scope_ is true, the persistable is
// created in var_scope.scope_ , and other scope is created in local scope.
Scope* local_scope = execution_config.create_local_scope
? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name();
// TODO(xiongkun): user may create a variable with name that exists before.
// under such circumstances, we should raise a error. Currently we can't
// get the var_desc of startup_program, so leave it later.
if (var_name == framework::kEmptyVarName) {
continue;
}
if (var_desc->Persistable() ||
execution_config.force_root_scope_vars.count(var_name)) {
// In principle, we should put all trainable parameters in global scope,
// which means the root of the scope tree. Some cases like quantization
// will look up these parameters in global scope.
const Scope* ancestor_scope = inner_scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var_name);
// NOTE(zhiqiu): if var exists in scope and the type is right,
// InitializeVariable will not create a new variable.
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
} else {
auto* ptr = local_scope->Var(var_name);
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
}
var_scope->AddVar(var_name, var_desc);
}
}
OpFuncType AnalyseOpFuncType(const OpFuncNode& op_func_node, OpFuncType AnalyseOpFuncType(const OpFuncNode& op_func_node,
const platform::Place& place) { const platform::Place& place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
...@@ -508,72 +563,14 @@ void HandleOperatorBase(const platform::Place& place, ...@@ -508,72 +563,14 @@ void HandleOperatorBase(const platform::Place& place,
op_func_node->dev_ctx_ = dev_ctx; op_func_node->dev_ctx_ = dev_ctx;
} }
void FakeInitializeOutputs(phi::Kernel* phi_kernel, void BuildOpFuncList(const platform::Place& place,
phi::KernelSignature* kernel_sig,
phi::KernelContext* phi_kernel_context) {
auto output_defs = phi_kernel->args_def().output_defs();
auto out_names = kernel_sig->output_names;
for (size_t i = 0; i < out_names.size(); ++i) {
VLOG(4) << out_names[i];
// calcute the start and end index of the output tensors
size_t start_idx = phi_kernel_context->OutputRangeAt(i).first;
size_t end_idx = phi_kernel_context->OutputRangeAt(i).second;
for (size_t j = start_idx; j < end_idx; ++j) {
auto* out_tensor = phi_kernel_context->MutableOutputAt(j);
if (out_tensor == nullptr) {
VLOG(4) << "Output" << out_names[i] << " is nullptr";
continue;
}
auto backend = output_defs[j].backend;
auto* dev_ctx =
&(phi_kernel_context->GetDeviceContext<phi::DeviceContext>());
if (phi::DenseTensor::classof(out_tensor)) {
if (!out_tensor->initialized()) {
VLOG(4) << "DenseTensor fake alloc 0 bytes of type "
<< out_tensor->dtype() << " on backend " << backend << " "
<< out_tensor;
if (backend == phi::TransToPhiBackend(dev_ctx->GetPlace())) {
dev_ctx->Alloc(out_tensor,
out_tensor->dtype(),
/*requested_size=*/0,
/*pinned=*/false,
/*fake_alloc=*/true);
} else {
if (backend == phi::Backend::CPU ||
backend == phi::Backend::ONEDNN) {
dev_ctx->HostAlloc(out_tensor,
out_tensor->dtype(),
/*requested_size=*/0,
/*fake_alloc=*/true);
}
}
}
} else if (phi::SparseCooTensor::classof(out_tensor)) {
// todo
VLOG(4) << "SparseCooTensor";
} else if (phi::SparseCsrTensor::classof(out_tensor)) {
// todo
VLOG(4) << "SparseCsrTensor";
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support "
"DenseTensor/SparseCooTensor/SparseCsrTensor "
"now"));
VLOG(4) << "SparseCooTensor";
}
}
}
}
bool BuildOpFuncList(const platform::Place& place,
const framework::BlockDesc& block, const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope, VariableScope* var_scope,
const ExecutionConfig& execution_config, const ExecutionConfig& execution_config,
bool use_local_scope) { bool use_local_scope,
bool static_build) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope(); : var_scope->GetMutableScope();
std::vector<std::unique_ptr<OperatorBase>> std::vector<std::unique_ptr<OperatorBase>>
...@@ -581,9 +578,7 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -581,9 +578,7 @@ bool BuildOpFuncList(const platform::Place& place,
// Step 1: create all ops for current block. // Step 1: create all ops for current block.
CreateAllOps(block, &ops_unique); CreateAllOps(block, &ops_unique);
auto skip_run = VLOG(4) << "Static build: " << static_build;
FLAGS_new_executor_static_build && IsBlockContainsOnlyPhiKernel(block);
VLOG(4) << "Static build: " << skip_run;
if (!execution_config.used_for_jit) { if (!execution_config.used_for_jit) {
// If gc is enabled and block size > 1 // If gc is enabled and block size > 1
...@@ -822,7 +817,7 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -822,7 +817,7 @@ bool BuildOpFuncList(const platform::Place& place,
&op_func_node, &op_func_node,
vec_func_list, vec_func_list,
use_local_scope, use_local_scope,
skip_run); static_build);
VLOG(4) << "apply data transform done. "; VLOG(4) << "apply data transform done. ";
// step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc
// for why. // for why.
...@@ -840,43 +835,57 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -840,43 +835,57 @@ bool BuildOpFuncList(const platform::Place& place,
if (run_phi_kernel && if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() == op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) { phi::KernelRegisteredType::FUNCTION) {
VLOG(6) << op_type << " run function kernel";
if (static_build) {
FakeInitializeOutputsForFunctionKernel(
*(op_func_node.phi_kernel_),
*(op_with_kernel->PhiKernelSignature()),
runtime_context,
*dev_ctx);
} else {
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext( op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context); runtime_context, dev_ctx, &phi_kernel_context);
if (!skip_run) {
(*op_func_node.phi_kernel_)(&phi_kernel_context); (*op_func_node.phi_kernel_)(&phi_kernel_context);
} else {
FakeInitializeOutputs(op_func_node.phi_kernel_,
op_with_kernel->PhiKernelSignature(),
&phi_kernel_context);
} }
} else if (run_phi_kernel && } else if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() == op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) { phi::KernelRegisteredType::STRUCTURE) {
VLOG(6) << op_type << " run structure kernel";
ExecutionContext execution_context( ExecutionContext execution_context(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
if (static_build) {
FakeInitializeOutputsForStructureKernel(kernel_type,
&execution_context);
} else {
(*op_func_node.phi_kernel_)(&execution_context); (*op_func_node.phi_kernel_)(&execution_context);
}
} else { } else {
VLOG(6) << op_type << " run fluid kernel";
// the place of exec_ctx maybe has changed. // the place of exec_ctx maybe has changed.
if (!skip_run) { ExecutionContext execution_context(
op_func_node.kernel_func_(ExecutionContext( *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context)); if (static_build) {
FakeInitializeOutputsForStructureKernel(kernel_type,
&execution_context);
} else { } else {
// TODO(zhiqiu): is it needed to support fluid kernel? op_func_node.kernel_func_(execution_context);
} }
} }
// post-process grad_op.outputs if need cast complex grad into real // post-process grad_op.outputs if need cast complex grad into real
// grad. // grad.
// NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it.
if (framework::IsComplexType(kernel_type.data_type_)) { if (IsGradOp(op_type) &&
framework::IsComplexType(kernel_type.data_type_)) {
interpreter::HandleComplexGradToRealGrad(op_func_node, interpreter::HandleComplexGradToRealGrad(op_func_node,
place, place,
output_name_map, output_name_map,
&runtime_context.outputs, &runtime_context.outputs,
var_scope, var_scope,
vec_func_list, vec_func_list,
local_scope); local_scope,
static_build);
} }
if (!op_func_node.inplace_back_map.empty()) { if (!op_func_node.inplace_back_map.empty()) {
auto& m = op_func_node.inplace_back_map; auto& m = op_func_node.inplace_back_map;
...@@ -949,7 +958,205 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -949,7 +958,205 @@ bool BuildOpFuncList(const platform::Place& place,
interpreter::LogDeviceMemoryStats(place); interpreter::LogDeviceMemoryStats(place);
} }
return skip_run; }
void BuildVariableScope(const framework::BlockDesc& block,
const ExecutionConfig& execution_config,
VariableScope* var_scope) {
VLOG(3) << "Creating Variables";
auto inner_scope = var_scope->GetMutableScope();
// NOTE(zhiqiu): if create_local_scope_ is true, the persistable is
// created in var_scope.scope_ , and other scope is created in local scope.
Scope* local_scope = execution_config.create_local_scope
? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name();
// TODO(xiongkun): user may create a variable with name that exists before.
// under such circumstances, we should raise a error. Currently we can't
// get the var_desc of startup_program, so leave it later.
if (var_name == framework::kEmptyVarName) {
continue;
}
if (var_desc->Persistable() ||
execution_config.force_root_scope_vars.count(var_name)) {
// In principle, we should put all trainable parameters in global scope,
// which means the root of the scope tree. Some cases like quantization
// will look up these parameters in global scope.
const Scope* ancestor_scope = inner_scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var_name);
// NOTE(zhiqiu): if var exists in scope and the type is right,
// InitializeVariable will not create a new variable.
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
} else {
auto* ptr = local_scope->Var(var_name);
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
}
var_scope->AddVar(var_name, var_desc);
}
}
phi::TensorBase* GetTensorFormVar(framework::Variable* var) {
if (var) {
if (var->template IsType<phi::DenseTensor>()) {
return var->template GetMutable<phi::DenseTensor>();
} else if (var->template IsType<phi::SelectedRows>()) {
return var->template GetMutable<phi::SelectedRows>();
} else if (var->template IsType<phi::SparseCooTensor>()) {
return var->template GetMutable<phi::SparseCooTensor>();
} else if (var->template IsType<framework::LoDTensorArray>()) {
return var->template GetMutable<framework::LoDTensorArray>();
} else if (var->template IsType<framework::Strings>()) {
return var->template GetMutable<framework::Strings>();
} else if (var->template IsType<paddle::framework::RawTensor>()) {
return var->template GetMutable<paddle::framework::RawTensor>();
} else if (!var->IsInitialized()) {
// The following is for RAW type of var
return var->template GetMutable<paddle::framework::RawTensor>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type when get tensor.",
framework::ToTypeName(var->Type())));
}
} else {
VLOG(4) << "Var is nullptr";
return nullptr;
}
}
void FakeInitializeTensor(const platform::DeviceContext& dev_ctx,
const phi::DataType& dtype,
const phi::Place& place,
phi::TensorBase* tensor) {
PADDLE_ENFORCE_NOT_NULL(
tensor,
phi::errors::InvalidArgument(
"The tensor to fake intialize should not be null."));
if (place == phi::CPUPlace()) {
dev_ctx.HostAlloc(tensor,
dtype,
/*requested_size=*/0,
/*fake_alloc=*/true);
} else {
PADDLE_ENFORCE_EQ(
place,
dev_ctx.GetPlace(),
phi::errors::Unavailable("The place %s for fack alloc is not equal to "
"the place %s of DeviceContext.",
place,
dev_ctx.GetPlace()));
dev_ctx.Alloc(tensor,
dtype,
/*requested_size=*/0,
/*pinned=*/false,
/*fake_alloc=*/true);
}
}
void FakeInitializeOutputsForFunctionKernel(
const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_sig,
const RuntimeContext& ctx,
const platform::DeviceContext& dev_ctx) {
std::string op_name = std::string(kernel_sig.name);
if (OpsNeedSetOutputDtypeWhenRegisterPhiKernel.count(op_name)) {
PADDLE_ENFORCE_GT(
OpsWithAvailablePhiInferMeta.count(op_name),
0,
phi::errors::Unavailable(
"Cannot static build for op %s because it did not set output dtype "
"in phi kernel register. Please set its output dtype and remove it "
"from OpsNeedSetOutputDtypeWhenRegisterPhiKernel set, or add it to "
" OpsWithAvailablePhiInferMeta set if its InferMeta is available.",
op_name));
}
auto output_names = kernel_sig.output_names;
auto output_defs = phi_kernel.args_def().output_defs();
PADDLE_ENFORCE_EQ(output_names.size(),
output_defs.size(),
platform::errors::InvalidArgument(
"The size of outputs_args names (%d) must be equal to "
"the size of kernel output_defs (%d).",
output_names.size(),
output_defs.size()));
size_t start_idx = 0;
for (size_t i = 0; i < output_names.size(); ++i) {
auto it = ctx.outputs.find(output_names[i]);
// Deal with the case that some outputs are not found or are NULL when running
// the kernel. For example: the outputs of matmul_grad are dx and dy, and
// sometimes dx or dy may be NULL.
if (it == ctx.outputs.end() || it->second.empty()) {
VLOG(4) << "Output " << output_names[i] << " not found";
++start_idx;
continue;
}
auto& outs_vector = it->second;
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* out_tensor = GetTensorFormVar(outs_vector[offset]);
if (out_tensor && !out_tensor->initialized()) {
phi::TensorArgDef& tensor_arg_def = output_defs[start_idx + offset];
phi::DataType dtype = tensor_arg_def.dtype;
phi::Place place = phi::TransToPhiPlace(tensor_arg_def.backend);
if (dtype == DataType::UNDEFINED ||
OpsNeedSetOutputDtypeWhenRegisterPhiKernel.count(
std::string(kernel_sig.name))) {
VLOG(4) << "Get dtype result from InferMeta";
dtype = out_tensor->dtype(); // dtype from InferMeta
}
VLOG(4) << output_names[i] << " fake alloc with type " << dtype
<< " on place " << place << " " << out_tensor;
FakeInitializeTensor(dev_ctx, dtype, place, out_tensor);
}
}
start_idx += outs_vector.size();
}
}
void FakeInitializeOutputsForStructureKernel(
const framework::OpKernelType& op_kernel_type,
ExecutionContext* execution_context) {
const std::string& op_type = execution_context->Type();
if (op_type == "fetch_v2") {
return;
}
const VariableNameMap& outputs = execution_context->GetOp().Outputs();
for (auto& item : outputs) {
const std::string& parameter_name = item.first;
auto multi_output_var = execution_context->MultiOutputVar(parameter_name);
for (Variable* var : multi_output_var) {
phi::TensorBase* out_tensor = GetTensorFormVar(var);
if (out_tensor && !out_tensor->initialized()) {
phi::DataType dtype =
phi::TransToPhiDataType(op_kernel_type.data_type_);
phi::Place place = execution_context->GetPlace();
VLOG(4) << parameter_name << " fake alloc with type " << dtype
<< " on place " << place << " " << out_tensor;
FakeInitializeTensor(
execution_context->device_context(), dtype, place, out_tensor);
}
}
}
} }
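The fake-initialize helpers above never execute a kernel: they only stamp each output with the dtype and place promised by the kernel registration (or InferMeta) through a zero-byte "fake" allocation, so that later passes can query output metadata without running anything. A toy sketch of that idea follows; FakeTensor and FakeInitialize are hypothetical stand-ins, not Paddle types or Paddle's allocator.

    // Toy illustration of "fake initialization": record dtype/place metadata
    // with a zero-byte allocation so a static pass can reason about outputs
    // without executing the kernel.
    #include <cstddef>
    #include <iostream>
    #include <string>

    struct FakeTensor {
      std::string dtype = "undefined";
      std::string place = "undefined";
      size_t bytes = 0;
      bool initialized = false;
    };

    // Rough analogue of FakeInitializeTensor: allocate 0 bytes, keep only metadata.
    void FakeInitialize(FakeTensor* t, const std::string& dtype, const std::string& place) {
      t->dtype = dtype;
      t->place = place;
      t->bytes = 0;           // no real memory is requested
      t->initialized = true;  // downstream passes can now query dtype/place
    }

    int main() {
      FakeTensor out;
      FakeInitialize(&out, "float32", "GPU:0");
      std::cout << out.dtype << " on " << out.place << ", bytes = " << out.bytes << "\n";
      return 0;
    }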
void LogDeviceMemoryStats(const platform::Place& place) { void LogDeviceMemoryStats(const platform::Place& place) {
......
...@@ -65,12 +65,16 @@ class AsyncWorkQueue { ...@@ -65,12 +65,16 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_; std::unique_ptr<WorkQueueGroup> queue_group_;
}; };
bool BlockCanBeStaticBuilt(const framework::BlockDesc& block);
bool IsCommunicationOp(const std::string& op_name); bool IsCommunicationOp(const std::string& op_name);
bool IsCommunicationOp(const Instruction& instr); bool IsCommunicationOp(const Instruction& instr);
bool IsCpuOp(const Instruction& instr); bool IsCpuOp(const Instruction& instr);
bool IsGradOp(const std::string& op_name);
bool IsMemcpyD2H(const Instruction& instr); bool IsMemcpyD2H(const Instruction& instr);
bool IsMemcpyH2D(const Instruction& instr); bool IsMemcpyH2D(const Instruction& instr);
...@@ -82,23 +86,30 @@ bool IsSupportedHeterPlace(const phi::Place& place); ...@@ -82,23 +86,30 @@ bool IsSupportedHeterPlace(const phi::Place& place);
void AddFetch(const std::vector<std::string>& fetch_names, void AddFetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block); framework::BlockDesc* block);
bool BuildOpFuncList(const platform::Place& place, void BuildOpFuncList(const platform::Place& place,
const framework::BlockDesc& block, const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* scope, VariableScope* scope,
const ExecutionConfig& execution_config, const ExecutionConfig& execution_config,
bool use_local_scope = true); bool use_local_scope = true,
bool static_build = false);
void BuildVariableScope(const framework::BlockDesc& block, void BuildVariableScope(const framework::BlockDesc& block,
const ExecutionConfig& execution_config, const ExecutionConfig& execution_config,
VariableScope* var_scope); VariableScope* var_scope);
void LogDeviceMemoryStats(const platform::Place& place); void FakeInitializeOutputsForFunctionKernel(
const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_sig,
const RuntimeContext& ctx,
const platform::DeviceContext& dev_ctx);
void FakeInitializeOutputs(phi::Kernel* phi_kernel, void FakeInitializeOutputsForStructureKernel(
phi::KernelSignature* kernel_sig, const framework::OpKernelType& op_kernel_type,
phi::KernelContext* phi_kernel_context); ExecutionContext* execution_context);
void LogDeviceMemoryStats(const platform::Place& place);
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
......
...@@ -38,6 +38,10 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -38,6 +38,10 @@ PADDLE_DEFINE_EXPORTED_bool(
new_executor_serial_run, new_executor_serial_run,
false, false,
"Enable serial execution for standalone executor, used for debug."); "Enable serial execution for standalone executor, used for debug.");
PADDLE_DEFINE_EXPORTED_bool(
new_executor_static_build,
false,
"Build the interpreterCore statically without running kernels.");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
false, false,
"Use inplace in new executor"); "Use inplace in new executor");
...@@ -117,6 +121,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -117,6 +121,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
var_scope_(scope) { var_scope_(scope) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_; VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
static_build_ = FLAGS_new_executor_static_build &&
interpreter::BlockCanBeStaticBuilt(block);
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught); exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion); completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
...@@ -275,20 +282,21 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -275,20 +282,21 @@ paddle::framework::FetchList InterpreterCore::Run(
block_, execution_config_, &var_scope_); block_, execution_config_, &var_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
auto skip_run = paddle::framework::interpreter::BuildOpFuncList( paddle::framework::interpreter::BuildOpFuncList(
place_, place_,
block_, block_,
execution_config_.skip_gc_vars, execution_config_.skip_gc_vars,
&op_func_nodes, &op_func_nodes,
&var_scope_, &var_scope_,
execution_config_, execution_config_,
HasLocalScope()); HasLocalScope(),
static_build_);
SetFeedVarsInplaceSkip(feed_names); SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph // convert vec func_list to graph
Convert(&op_func_nodes); Convert(&op_func_nodes);
is_build_ = true; is_build_ = true;
UpdateSyncOpNum(); UpdateSyncOpNum();
if (skip_run) { if (static_build_) {
VLOG(4) << "RUN impl"; VLOG(4) << "RUN impl";
RunImpl(); RunImpl();
} }
...@@ -1270,20 +1278,21 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names, ...@@ -1270,20 +1278,21 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
block_, execution_config_, &var_scope_); block_, execution_config_, &var_scope_);
FeedInput(); FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
auto skip_run = paddle::framework::interpreter::BuildOpFuncList( paddle::framework::interpreter::BuildOpFuncList(
place_, place_,
block_, block_,
execution_config_.skip_gc_vars, execution_config_.skip_gc_vars,
&op_func_nodes, &op_func_nodes,
&var_scope_, &var_scope_,
execution_config_, execution_config_,
HasLocalScope()); HasLocalScope(),
static_build_);
SetFeedVarsInplaceSkip(feed_names); SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph // convert vec func_list to graph
Convert(&op_func_nodes); Convert(&op_func_nodes);
UpdateSyncOpNum(); UpdateSyncOpNum();
is_build_ = true; is_build_ = true;
if (skip_run) { if (static_build_) {
VLOG(4) << "RUN impl"; VLOG(4) << "RUN impl";
RunImpl(); RunImpl();
} }
......
...@@ -133,6 +133,7 @@ class InterpreterCore { ...@@ -133,6 +133,7 @@ class InterpreterCore {
private: private:
bool is_build_{false}; bool is_build_{false};
bool static_build_{false};
const platform::Place place_; const platform::Place place_;
const BlockDesc& block_; // not owned const BlockDesc& block_; // not owned
......
...@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif #endif
...@@ -89,13 +91,6 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -89,13 +91,6 @@ class CastOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "cast");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "cast");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
phi::KernelKey GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
// CastOp kernel's device type is decided by input tensor place // CastOp kernel's device type is decided by input tensor place
...@@ -150,13 +145,18 @@ class CastOp : public framework::OperatorWithKernel { ...@@ -150,13 +145,18 @@ class CastOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = phi::CPUContext; using CPU = phi::CPUContext;
DECLARE_INFER_SHAPE_FUNCTOR(cast,
CastInferShapeFunctor,
PD_INFER_META(phi::CastInferMeta));
// cast use phi kernel, so no need to REGISTER_OP_CPU_KERNEL here. // cast use phi kernel, so no need to REGISTER_OP_CPU_KERNEL here.
REGISTER_OPERATOR(cast, REGISTER_OPERATOR(cast,
ops::CastOp, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>, ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>, ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastCompositeGradOpMaker, ops::CastCompositeGradOpMaker,
ops::CastOpProtoMaker); ops::CastOpProtoMaker,
CastInferShapeFunctor);
// [ why register transfer_dtype_op alias with cast_op? ] // [ why register transfer_dtype_op alias with cast_op? ]
// In case of InterpreterCore, if we reuse cast_op, we cannot distinguish // In case of InterpreterCore, if we reuse cast_op, we cannot distinguish
...@@ -165,19 +165,5 @@ REGISTER_OPERATOR(transfer_dtype, ...@@ -165,19 +165,5 @@ REGISTER_OPERATOR(transfer_dtype,
ops::CastOp, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>, ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>, ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpProtoMaker); ops::CastOpProtoMaker,
REGISTER_OP_CPU_KERNEL( CastInferShapeFunctor);
transfer_dtype,
ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int16_t>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, int8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
using CUDA = phi::GPUContext;
// See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc
REGISTER_OP_CUDA_KERNEL(transfer_dtype,
ops::CastOpKernel<CUDA, float>,
ops::CastOpKernel<CUDA, double>,
ops::CastOpKernel<CUDA, int>,
ops::CastOpKernel<CUDA, int64_t>,
ops::CastOpKernel<CUDA, int16_t>,
ops::CastOpKernel<CUDA, bool>,
ops::CastOpKernel<CUDA, uint8_t>,
ops::CastOpKernel<CUDA, plat::float16>,
ops::CastOpKernel<CUDA, plat::complex<float>>,
ops::CastOpKernel<CUDA, plat::complex<double>>,
ops::CastOpKernel<CUDA, plat::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace paddle {
namespace operators {
template <typename InT, typename OutT>
struct CastOpTransformFunctor {
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename DeviceContext, typename InT>
struct CastOpFunctor {
const phi::DenseTensor* in_;
phi::DenseTensor* out_;
const DeviceContext& ctx_;
CastOpFunctor(const phi::DenseTensor* in,
phi::DenseTensor* out,
const DeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* in_begin = in_->data<InT>();
auto numel = in_->numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
phi::Transform<DeviceContext> trans;
trans(
ctx_, in_begin, in_end, out_begin, CastOpTransformFunctor<InT, OutT>());
}
};
template <typename DeviceContext, typename InT>
class CastOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
auto out_dtype = context.Attr<int>("out_dtype");
auto& dev_ctx = context.device_context<DeviceContext>();
out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_out_dtype = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel
phi::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*in,
pt_out_dtype,
out);
}
};
} // namespace operators
} // namespace paddle
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h" #include "paddle/fluid/platform/device/mlu/device_context.h"
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle { namespace paddle {
......
...@@ -121,29 +121,38 @@ class FeedOp : public framework::OperatorWithKernel { ...@@ -121,29 +121,38 @@ class FeedOp : public framework::OperatorWithKernel {
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
framework::Variable* x_var = framework::Variable* x_var =
PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]); PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
auto& x = x_var->Get<framework::FeedList>(); auto& x = x_var->Get<framework::FeedList>();
int col = ctx->Attrs().Get<int>("col"); int col = ctx->Attrs().Get<int>("col");
auto& feed_item = x[col];
if (feed_item.index() == 0) {
const auto& feed_item = CheckAndGetFeedItem(x, col); const auto& feed_item = CheckAndGetFeedItem(x, col);
if (feed_item.index() == 0) { // DenseTensor
auto& feed_tensor = PADDLE_GET_CONST(phi::DenseTensor, feed_item); auto& feed_tensor = PADDLE_GET_CONST(phi::DenseTensor, feed_item);
ctx->SetOutputDim("Out", feed_tensor.dims()); phi::DenseTensor* out_tensor = out_var->GetMutable<phi::DenseTensor>();
} else if (feed_item.index() == 1) { phi::DenseTensorMeta meta = out_tensor->meta();
meta.dims = feed_tensor.dims();
meta.dtype = feed_tensor.dtype();
meta.layout = feed_tensor.layout();
meta.lod = feed_tensor.lod();
out_tensor->set_meta(meta);
} else if (feed_item.index() == 1) { // Strings
auto& feed_str = PADDLE_GET_CONST(framework::Strings, feed_item); auto& feed_str = PADDLE_GET_CONST(framework::Strings, feed_item);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
out_var->GetMutable<framework::Strings>()->resize(feed_str.size()); out_var->GetMutable<framework::Strings>()->resize(feed_str.size());
} else { } else if (feed_item.index() == 2) { // SparseCooTensor
auto& feed_sparse_tensor = auto& feed_sparse_tensor =
PADDLE_GET_CONST(phi::SparseCooTensor, feed_item); PADDLE_GET_CONST(phi::SparseCooTensor, feed_item);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
out_var->GetMutable<phi::SparseCooTensor>()->set_meta( out_var->GetMutable<phi::SparseCooTensor>()->set_meta(
feed_sparse_tensor.meta()); feed_sparse_tensor.meta());
out_var->GetMutable<phi::SparseCooTensor>()->SetCoalesced( out_var->GetMutable<phi::SparseCooTensor>()->SetCoalesced(
feed_sparse_tensor.coalesced()); feed_sparse_tensor.coalesced());
out_var->GetMutable<phi::SparseCooTensor>()->SetIndicesDict( out_var->GetMutable<phi::SparseCooTensor>()->SetIndicesDict(
feed_sparse_tensor.GetIndicesDict()); feed_sparse_tensor.GetIndicesDict());
} else {
PADDLE_THROW(
phi::errors::Unimplemented("Only support DenseTnesor, Strings, and "
"SparseCooTensor for feed op now."));
} }
} }
} }
...@@ -151,7 +160,23 @@ class FeedOp : public framework::OperatorWithKernel { ...@@ -151,7 +160,23 @@ class FeedOp : public framework::OperatorWithKernel {
protected: protected:
phi::KernelKey GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); const framework::Variable* x_var = ctx.InputVar("X");
auto& x = x_var->Get<framework::FeedList>();
int col = ctx.Attr<int>("col");
auto& feed_item = x[col];
framework::proto::VarType::Type expected_data_type;
if (feed_item.index() == 0) { // DenseTensor
expected_data_type = framework::TransToProtoVarType(
PADDLE_GET_CONST(phi::DenseTensor, feed_item).dtype());
} else if (feed_item.index() == 2) { // SparseCooTensor
expected_data_type = framework::TransToProtoVarType(
PADDLE_GET_CONST(phi::SparseCooTensor, feed_item).dtype());
} else { // Strings
expected_data_type = framework::proto::VarType::FP32;
}
return phi::KernelKey(expected_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" #include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
// only can include the headers in paddle/phi/api dirs // only can include the headers in paddle/phi/api dirs
......
...@@ -111,7 +111,9 @@ PD_REGISTER_KERNEL(check_finite_and_unscale, ...@@ -111,7 +111,9 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
ALL_LAYOUT, ALL_LAYOUT,
phi::CheckFiniteAndUnscaleKernel, phi::CheckFiniteAndUnscaleKernel,
float, float,
double) {} double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(update_loss_scaling, PD_REGISTER_KERNEL(update_loss_scaling,
CPU, CPU,
......
...@@ -79,7 +79,9 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -79,7 +79,9 @@ PD_REGISTER_KERNEL(equal_all,
int, int,
int64_t, int64_t,
float, float,
double) {} double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \ #define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \ PD_REGISTER_KERNEL(name, \
...@@ -92,7 +94,9 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -92,7 +94,9 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, \ int64_t, \
float, \ float, \
double, \ double, \
phi::dtype::float16) {} \ phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \ PD_REGISTER_KERNEL(name##_raw, \
CPU, \ CPU, \
ALL_LAYOUT, \ ALL_LAYOUT, \
...@@ -103,7 +107,9 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -103,7 +107,9 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, \ int64_t, \
float, \ float, \
double, \ double, \
phi::dtype::float16) {} phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
......
...@@ -209,7 +209,11 @@ PD_REGISTER_KERNEL(dropout, ...@@ -209,7 +209,11 @@ PD_REGISTER_KERNEL(dropout,
phi::DropoutRawKernel, phi::DropoutRawKernel,
float, float,
double, double,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {} dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
}
...@@ -83,4 +83,6 @@ void OneHotRawKernel(const Context& dev_ctx, ...@@ -83,4 +83,6 @@ void OneHotRawKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {} one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
...@@ -358,7 +358,9 @@ PD_REGISTER_KERNEL(check_finite_and_unscale, ...@@ -358,7 +358,9 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
float, float,
double, double,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(update_loss_scaling, PD_REGISTER_KERNEL(update_loss_scaling,
GPU, GPU,
......
...@@ -90,6 +90,7 @@ PD_REGISTER_KERNEL(dropout, ...@@ -90,6 +90,7 @@ PD_REGISTER_KERNEL(dropout,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
} }
PD_REGISTER_KERNEL(dropout_nd, PD_REGISTER_KERNEL(dropout_nd,
...@@ -101,4 +102,5 @@ PD_REGISTER_KERNEL(dropout_nd, ...@@ -101,4 +102,5 @@ PD_REGISTER_KERNEL(dropout_nd,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
} }
...@@ -91,4 +91,6 @@ void OneHotRawKernel(const Context& dev_ctx, ...@@ -91,4 +91,6 @@ void OneHotRawKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {} one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
@@ -95,26 +95,49 @@ inline void CompareAllKernelImpl(const Context& ctx,
 }  // namespace phi
 #ifdef PADDLE_WITH_XPU_KP
-PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {}
-PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {}
+PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
+PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
 }
 PD_REGISTER_KERNEL(
-    greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {}
-PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {}
-PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {}
+    greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
+PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
+PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(
-    less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {}
+    less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(
-    less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {}
+    less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(
-    greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {}
+    greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(
-    greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {}
-PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {}
+    greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
+PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(
-    not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {}
+    not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 #else
@@ -126,7 +149,9 @@ PD_REGISTER_KERNEL(equal_all,
                    int,
                    int64_t,
                    float,
-                   double) {}
+                   double) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 #define PD_REGISTER_COMPARE_KERNEL(name, func) \
   PD_REGISTER_KERNEL(name, \
@@ -140,7 +165,9 @@ PD_REGISTER_KERNEL(equal_all,
                      float, \
                      double, \
                      phi::dtype::float16, \
-                     phi::dtype::bfloat16) {} \
+                     phi::dtype::bfloat16) { \
+    kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
+  } \
   PD_REGISTER_KERNEL(name##_raw, \
                      KPS, \
                      ALL_LAYOUT, \
@@ -152,7 +179,9 @@ PD_REGISTER_KERNEL(equal_all,
                      float, \
                      double, \
                      phi::dtype::float16, \
-                     phi::dtype::bfloat16) {}
+                     phi::dtype::bfloat16) { \
+    kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
+  }
 PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
...
@@ -30,12 +30,18 @@ void OneHotKernel(const Context& dev_ctx,
 }  // namespace phi
-PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
+PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
+}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
+PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
+}
 #endif
 #ifdef PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
+PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
+}
 #endif
@@ -47,6 +47,8 @@ PD_REGISTER_KERNEL(shape,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
 }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -65,6 +67,8 @@ PD_REGISTER_KERNEL(shape,
                    phi::dtype::complex<double>,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
 }
 #endif
@@ -80,5 +84,7 @@ PD_REGISTER_KERNEL(shape,
                    double,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
 }
 #endif
@@ -285,4 +285,6 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
                    ALL_LAYOUT,
                    phi::CheckFiniteAndUnscaleKernel,
                    float,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
+}
@@ -90,7 +90,9 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
 }  // namespace phi
 PD_REGISTER_KERNEL(
-    less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {}
+    less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 PD_REGISTER_KERNEL(less_than_raw,
                    XPU,
@@ -98,18 +100,24 @@ PD_REGISTER_KERNEL(less_than_raw,
                    phi::LessThanRawKernel,
                    int,
                    int64_t,
-                   float) {}
+                   float) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
+}
 #define PD_REGISTER_COMPARE_KERNEL(name, func) \
   PD_REGISTER_KERNEL( \
-      name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) {} \
+      name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) { \
+    kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
+  } \
   PD_REGISTER_KERNEL(name##_raw, \
                      XPU, \
                      ALL_LAYOUT, \
                      phi::func##RawKernel, \
                      int, \
                      int64_t, \
-                     float) {}
+                     float) { \
+    kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
+  }
 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
 PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
...
@@ -62,4 +62,6 @@ void OneHotRawKernel(const Context& dev_ctx,
 }  // namespace phi
 PD_REGISTER_KERNEL(
-    one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
+    one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
+}
@@ -22,4 +22,7 @@ KernelSignature CastOpArgumentMapping(const ArgumentMappingContext& ctx) {
 }  // namespace phi
+PD_REGISTER_BASE_KERNEL_NAME(transfer_dtype, cast);
 PD_REGISTER_ARG_MAPPING_FN(cast, phi::CastOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(transfer_dtype, phi::CastOpArgumentMapping);
@@ -22,7 +22,7 @@ from paddle.framework import set_flags
 paddle.enable_static()
-def build_resnet50():
+def build_resnet50(use_amp=False):
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
@@ -36,49 +36,80 @@ def build_resnet50():
     loss = paddle.nn.functional.cross_entropy(input=prediction, label=label)
     loss = paddle.mean(loss)
     adam = paddle.optimizer.Adam(learning_rate=0.001)
+    if use_amp:
+        adam = paddle.static.amp.decorate(
+            optimizer=adam,
+            init_loss_scaling=1.0,
+            use_dynamic_loss_scaling=False,
+            use_pure_fp16=True,
+            use_fp16_guard=False,
+        )
     adam.minimize(loss)
-    return main_program, startup_program, loss
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.enable_addto = True
+    build_strategy.fuse_elewise_add_act_ops = True
+    if use_amp:
+        build_strategy.fuse_bn_act_ops = True
+        build_strategy.fuse_bn_add_act_ops = True
+    main_program = paddle.static.CompiledProgram(
+        main_program, build_strategy=build_strategy
+    )
+    return main_program, startup_program, loss, adam
-class TestAOTChooseKernel(unittest.TestCase):
-    def test_aot_choose_kernel(self):
-        if not paddle.fluid.core.is_compiled_with_cuda():
-            return
-        def run(aot_choose_kernel=None):
+def run_resnet50(aot_choose_kernel=False, use_amp=False):
     paddle.seed(2022)
     np.random.seed(2022)
-    main_program, startup_program, loss = build_resnet50()
+    main_program, startup_program, loss, optimizer = build_resnet50(use_amp)
+    place = paddle.CUDAPlace(0)
+    exe = paddle.static.Executor(place)
     scope = paddle.static.Scope()
-    exe = paddle.static.Executor()
     set_flags({'FLAGS_cudnn_deterministic': 1})
     if aot_choose_kernel:
         set_flags({'FLAGS_new_executor_static_build': 1})
-    else:
-        set_flags({'FLAGS_new_executor_static_build': 0})
+    if use_amp:
+        set_flags({'FLAGS_conv_workspace_size_limit': 1500})
+        set_flags({'FLAGS_max_inplace_grad_add': 8})
+        set_flags({'FLAGS_cudnn_batchnorm_spatial_persistent': 1})
     with paddle.static.scope_guard(scope):
         exe.run(startup_program)
+        if use_amp:
+            optimizer.amp_init(place)
-        for i in range(10):
+        feed_dtype = 'float16' if use_amp else 'float32'
+        for i in range(1):
             feed = {
                 'image': np.random.randint(
                     0, 256, size=[32, 3, 224, 224]
-                ).astype('float32'),
-                'label': np.random.randint(0, 1000, size=[32]).astype(
-                    'int64'
-                ),
+                ).astype(feed_dtype),
+                'label': np.random.randint(0, 1000, size=[32]).astype('int64'),
             }
             loss_ = exe.run(main_program, feed=feed, fetch_list=[loss])
         return loss_
-        loss1 = run(aot_choose_kernel=True)
-        loss2 = run(aot_choose_kernel=False)
+class TestAOTChooseKernel(unittest.TestCase):
+    def test_resnet50_aot_choose_kernel(self):
+        if not paddle.fluid.core.is_compiled_with_cuda():
+            return
+        loss1 = run_resnet50(aot_choose_kernel=True)
+        loss2 = run_resnet50(aot_choose_kernel=False)
+        self.assertEqual(loss1, loss2)
+    def test_resnet50_amp_aot_choose_kernel(self):
+        if not paddle.fluid.core.is_compiled_with_cuda():
+            return
+        loss1 = run_resnet50(aot_choose_kernel=True, use_amp=True)
+        loss2 = run_resnet50(aot_choose_kernel=False, use_amp=True)
         self.assertEqual(loss1, loss2)
...
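The test above exercises the static-build path only through the ResNet50 program it constructs. For readers who want to try the same flag on their own static-graph program, the following is a minimal sketch, not part of this commit: it assumes a Paddle build that includes the FLAGS_new_executor_static_build flag shown in the test, and the toy network, tensor names, and shapes (x, fc, batch of 4) are illustrative assumptions rather than anything defined in the diff.

# Minimal sketch (assumes a build with FLAGS_new_executor_static_build available).
import numpy as np
import paddle
from paddle.framework import set_flags

paddle.enable_static()
# Ask the new executor to pre-select kernels and pre-initialize outputs at build time.
set_flags({'FLAGS_new_executor_static_build': 1})

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    # Hypothetical toy network, used only to have something to run.
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    out = paddle.static.nn.fc(x, size=1)
    loss = paddle.mean(out)

place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
# The first run triggers the static build; later runs reuse the prepared kernels.
(loss_val,) = exe.run(
    main_program,
    feed={'x': np.random.rand(4, 8).astype('float32')},
    fetch_list=[loss],
)
print(loss_val)

The flag changes only how the executor prepares kernels ahead of the first run; as the test checks, the computed loss should match the default (non-static-build) path.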