未验证 提交 eb42dd52 编写于 作者: C Chen Weihang 提交者: GitHub

[Pten->Phi PR4] Rename pten in funcs to phi (#39961)

* rename pten_utils to phi_utils

* rename pten_utils target

* rename Pten to Phi

* replace pten with phi

* resolve conflict
......@@ -238,7 +238,7 @@ void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
void* tensor_data = tensor->mutable_data(
place,
framework::TransToPtenDataType(VarMessageToVarType(msg.data_type())));
framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
// IO Buffer
if (platform::is_cpu_place(place)) {
......@@ -281,7 +281,7 @@ void DeserializeSelectedRows(
tensor->Resize(phi::make_ddim(vec_dim));
void* tensor_data = tensor->mutable_data(
place,
framework::TransToPtenDataType(VarMessageToVarType(msg.data_type())));
framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
// IO Buffer
if (platform::is_cpu_place(place)) {
unsigned long data_len; // NOLINT
......
......@@ -33,36 +33,36 @@ static void ScaleDeviceDispatch(const phi::DenseTensor& dense_tensor,
phi::DenseTensor* dense_out) {
switch (dense_tensor.dtype()) {
case phi::DataType::FLOAT64: {
phi::ScaleKernel<double, typename paddle::framework::ConvertToPtenContext<
phi::ScaleKernel<double, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
case phi::DataType::FLOAT32: {
phi::ScaleKernel<float, typename paddle::framework::ConvertToPtenContext<
phi::ScaleKernel<float, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
case phi::DataType::INT64: {
phi::ScaleKernel<int64_t, typename paddle::framework::
ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
phi::ScaleKernel<int64_t, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
break;
}
case phi::DataType::INT32: {
phi::ScaleKernel<int32_t, typename paddle::framework::
ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
phi::ScaleKernel<int32_t, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
dense_tensor /* tensor */, scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */, dense_out /* out tensor */);
......
......@@ -22,7 +22,7 @@
#include "paddle/phi/api/all.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/variable.h"
namespace egr {
......@@ -43,7 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
bool is_leaf) {
paddle::experimental::Tensor out = paddle::experimental::full(
phi::vectorize(ddim), paddle::experimental::Scalar(value), dtype,
phi::TransToPtenBackend(place));
phi::TransToPhiBackend(place));
auto meta = EagerUtils::autograd_meta(&out);
if (is_leaf) {
......
......@@ -27,7 +27,7 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
#define NUM_CREATED_DUP_INPUTS 4
......@@ -544,7 +544,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// since only OperatorWithKernel can run in dygraph mode.
auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
if (!all_kernels.count(op_type) &&
!phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
!phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
return false;
}
......
......@@ -14,10 +14,10 @@
#pragma once
// framework deps
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
// pten deps
// Phi deps
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/api_declare.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
......@@ -31,7 +31,7 @@
* provide variable in
* paddle::framework::ExecutionContext to support it. We should remove this as
* soon as we finish our latest
* Pten Lib, and use paddle::experimental::Tensor instead.
* Phi Lib, and use paddle::experimental::Tensor instead.
*
* Note: Keep this class as clean as possible.
* This class should only support method declared in
......
......@@ -23,7 +23,7 @@
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/variable.h"
PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true,
......
......@@ -193,9 +193,9 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
IF(WITH_XPU)
cc_library(phi_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place phi var_type_traits phi_api_utils op_info xpu_op_list)
cc_library(phi_utils SRCS phi_utils.cc DEPS lod_tensor selected_rows_utils place phi var_type_traits phi_api_utils op_info xpu_op_list)
ELSE()
cc_library(phi_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place phi var_type_traits phi_api_utils op_info)
cc_library(phi_utils SRCS phi_utils.cc DEPS lod_tensor selected_rows_utils place phi var_type_traits phi_api_utils op_info)
ENDIF()
IF(WITH_XPU)
......@@ -450,7 +450,7 @@ if(WITH_TESTING AND TEST selected_rows_utils_test)
endif()
cc_test(scope_guard_test SRCS scope_guard_test.cc)
cc_test(phi_utils_test SRCS pten_utils_test.cc DEPS phi_utils)
cc_test(phi_utils_test SRCS phi_utils_test.cc DEPS phi_utils)
if(WITH_GPU OR WITH_ROCM)
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
......
......@@ -33,7 +33,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
namespace paddle {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
paddle::experimental::DataType TransToPtenDataType(
paddle::experimental::DataType TransToPhiDataType(
const paddle::framework::proto::VarType::Type& dtype) {
// Set the order of case branches according to the frequency with
// the data type is used
......
......@@ -32,7 +32,7 @@ namespace framework {
using DataType = paddle::experimental::DataType;
using DataLayout = paddle::experimental::DataLayout;
DataType TransToPtenDataType(
DataType TransToPhiDataType(
const paddle::framework::proto::VarType::Type& dtype);
paddle::framework::proto::VarType::Type TransToProtoVarType(
......
......@@ -43,35 +43,35 @@ TEST(ConvertUtils, DataType) {
CHECK(paddle::framework::TransToProtoVarType(paddle::DataType::FLOAT16) ==
paddle::framework::proto::VarType::FP16);
// proto -> enum
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::FP64) ==
paddle::DataType::FLOAT64);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::FP32) ==
paddle::DataType::FLOAT32);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT64) ==
paddle::DataType::INT64);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT32) ==
paddle::DataType::INT32);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT8) == paddle::DataType::INT8);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::UINT8) ==
paddle::DataType::UINT8);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::INT16) ==
paddle::DataType::INT16);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::COMPLEX64) ==
paddle::DataType::COMPLEX64);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::COMPLEX128) ==
paddle::DataType::COMPLEX128);
CHECK(paddle::framework::TransToPtenDataType(
CHECK(paddle::framework::TransToPhiDataType(
paddle::framework::proto::VarType::FP16) ==
paddle::DataType::FLOAT16);
}
......
......@@ -30,7 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -779,13 +779,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
for (size_t i = 0; i < ctx->InputSize(in_name); ++i) {
auto dtype = ctx->GetInputDataType(in_name, i);
vec_custom_dtype.emplace_back(
paddle::framework::TransToPtenDataType(dtype));
paddle::framework::TransToPhiDataType(dtype));
}
vec_input_dtypes.emplace_back(vec_custom_dtype);
} else {
auto dtype = ctx->GetInputDataType(in_name);
input_dtypes.emplace_back(
paddle::framework::TransToPtenDataType(dtype));
paddle::framework::TransToPhiDataType(dtype));
}
}
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include "paddle/fluid/platform/init.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
namespace paddle {
namespace framework {
......
......@@ -28,7 +28,7 @@ TEST(DataType, float16) {
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, f::TransToPtenDataType(dtype));
tensor.mutable_data(cpu, f::TransToPhiDataType(dtype));
// test fp16 tensor
EXPECT_EQ(f::TransToProtoVarType(tensor.dtype()),
......@@ -51,7 +51,7 @@ TEST(DataType, bfloat16) {
Tensor tensor;
CPUPlace cpu;
tensor.mutable_data(cpu, f::TransToPtenDataType(dtype));
tensor.mutable_data(cpu, f::TransToPhiDataType(dtype));
// test bf16 tensor
EXPECT_EQ(f::TransToProtoVarType(tensor.dtype()),
......
......@@ -34,7 +34,7 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h"
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
namespace paddle {
namespace framework {
......
......@@ -161,7 +161,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope,
tensor->set_lod(lod);
void* tensor_data = tensor->mutable_data(
place, framework::TransToPtenDataType(ToVarType(req_var.data_type())));
place, framework::TransToPhiDataType(ToVarType(req_var.data_type())));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
memory::Copy(place, tensor_data, platform::CPUPlace(), req_var.data().data(),
......@@ -202,7 +202,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope,
tensor->set_lod(lod);
void* tensor_data = tensor->mutable_data(
place, framework::TransToPtenDataType(ToVarType(req_var.data_type())));
place, framework::TransToPhiDataType(ToVarType(req_var.data_type())));
#ifdef PADDLE_WITH_XPU
memory::Copy(place, tensor_data, platform::CPUPlace(), req_var.data().data(),
......
......@@ -38,7 +38,7 @@ void SetMicroId(paddle::framework::Scope* scope,
std::vector<int> dims{1};
tensor->Resize(phi::make_ddim(dims));
void* tensor_data = tensor->mutable_data(
place, framework::TransToPtenDataType(framework::proto::VarType::FP32));
place, framework::TransToPhiDataType(framework::proto::VarType::FP32));
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
std::vector<char> temp;
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
......@@ -144,7 +144,7 @@ class CompatMetaTensor : public phi::MetaTensor {
}
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return paddle::framework::TransToPtenDataType(var->GetDataType());
return paddle::framework::TransToPhiDataType(var->GetDataType());
}
}
......@@ -341,10 +341,10 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
if (infershape_inputs.size() != 1) {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
std::move(experimental::MakePhiScalarArrayFromVarList(vars)));
} else {
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
std::move(experimental::MakePhiScalarArrayFromVar(*vars[0])));
}
} else {
// If is not in runtime, we will set default value(-1) for ScalarArray
......@@ -419,7 +419,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarFromVar(*var)));
std::move(experimental::MakePhiScalarFromVar(*var)));
} else {
phi::Scalar tensor_scalar(-1);
tensor_scalar.SetFromTensor(true);
......@@ -481,7 +481,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPtenDataType(
auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
infer_meta_context.EmplaceBackAttr(data_type);
......
......@@ -276,13 +276,13 @@ bool FuseOptimizerOpPass::OpWithKernelSupportCPUAndGPU(
bool support_gpu = false;
auto &kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPtenKernelName(op_type));
kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
bool has_op_kernel = kernel_key_map.size() > 0 ? true : false;
for (auto &kernel : kernel_key_map) {
if (platform::is_gpu_place(phi::TransToPtenPlace(kernel.first.backend()))) {
if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
support_gpu = true;
} else if (platform::is_cpu_place(
phi::TransToPtenPlace(kernel.first.backend()))) {
phi::TransToPhiPlace(kernel.first.backend()))) {
support_cpu = true;
}
}
......
......@@ -96,7 +96,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place,
framework::TransToPtenDataType(proto::VarType::FP32), 1);
framework::TransToPhiDataType(proto::VarType::FP32), 1);
}
void MainTest(bool convWithExistingBias) {
......
......@@ -126,7 +126,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place,
framework::TransToPtenDataType(proto::VarType::FP32), 1);
framework::TransToPhiDataType(proto::VarType::FP32), 1);
}
void PreparePass(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog,
......
......@@ -526,7 +526,7 @@ void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place,
framework::TransToPtenDataType(proto::VarType::FP32), 1);
framework::TransToPhiDataType(proto::VarType::FP32), 1);
}
void PrepareGraph(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog) {
......
......@@ -447,7 +447,7 @@ void MergeLoDTensor(LoDTensor *target,
target->set_layout(new_layout);
target->set_lod(new_lod);
target->mutable_data(dst_place,
paddle::framework::TransToPtenDataType(new_type));
paddle::framework::TransToPhiDataType(new_type));
int begin = 0;
for (auto *src : lod_tensors) {
......
......@@ -416,18 +416,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_);
} else {
// fit for pten
if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) {
VLOG(4) << "Run pten kernel: " << op->Type();
// fit for phi
if (instr_node.PhiKernel() && instr_node.PhiKernel()->IsValid()) {
VLOG(4) << "Run phi kernel: " << op->Type();
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
<< &instr_node.DeviceContext();
phi::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(
op_with_kernel->BuildPhiKernelContext(
*instr_node.InnerRuntimeContext().get(),
const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()),
&pt_kernel_context);
(*instr_node.PtenKernel())(&pt_kernel_context);
(*instr_node.PhiKernel())(&pt_kernel_context);
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
......
......@@ -407,14 +407,14 @@ void build_op_func_list(const platform::Place& place,
auto exec_ctx =
ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto run_pten_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(
auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_with_kernel->Type())) {
auto pt_kernel_key = op_with_kernel->ChoosePtenKernel(exec_ctx);
auto pt_kernel_name = op_with_kernel->PtenKernelSignature()->name;
auto pt_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
auto pt_kernel_name = op_with_kernel->PhiKernelSignature()->name;
if (op_with_kernel->PtenKernel()->IsValid()) {
run_pten_kernel = true;
if (op_with_kernel->PhiKernel()->IsValid()) {
run_phi_kernel = true;
} else {
auto kernels_iter = all_op_kernels.find(op_with_kernel->Type());
if (kernels_iter == all_op_kernels.end() ||
......@@ -422,26 +422,26 @@ void build_op_func_list(const platform::Place& place,
kernels_iter->second.end()) {
auto pt_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, pt_kernel_key, *op_with_kernel);
op_with_kernel->ResetPtenKernel(
op_with_kernel->ResetPhiKernel(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key)));
if (op_with_kernel->PtenKernel()->IsValid()) {
if (op_with_kernel->PhiKernel()->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: "
<< pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *(op_with_kernel->PtenKernel());
run_pten_kernel = true;
<< " | kernel: " << *(op_with_kernel->PhiKernel());
run_phi_kernel = true;
}
}
}
}
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
if (run_pten_kernel) {
if (run_phi_kernel) {
phi::KernelContext pt_kernel_context;
op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx,
&pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
&pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
(*op_func_node.pt_kernel_)(&pt_kernel_context);
} else {
......
......@@ -688,9 +688,7 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
return op_func_node_.kernel_func_;
}
phi::Kernel* Instruction::PtenKernel() const {
return op_func_node_.pt_kernel_;
}
phi::Kernel* Instruction::PhiKernel() const { return op_func_node_.pt_kernel_; }
OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
......
......@@ -300,7 +300,7 @@ struct OpFuncNode {
OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned
// fit for pten kernel
// fit for phi kernel
phi::Kernel* pt_kernel_{nullptr}; // not owned
OpFuncType type_;
......@@ -321,7 +321,7 @@ class Instruction {
OpKernelComputeFunc KernelFunc() const;
phi::Kernel* PtenKernel() const;
phi::Kernel* PhiKernel() const;
OpFuncType KernelType() const;
......
......@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
......@@ -616,9 +616,9 @@ bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first
auto& kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPtenKernelName(op_type));
kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
for (auto& kernel : kernel_key_map) {
if (platform::is_gpu_place(phi::TransToPtenPlace(kernel.first.backend()))) {
if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
return true;
}
}
......@@ -1186,10 +1186,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// phase
phi::KernelKey pt_kernel_key;
std::string pt_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx))));
new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *pt_kernel_signature_.get();
kernel_type_.reset(
......@@ -1197,17 +1197,17 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(kernel_type_->place_);
pt_kernel_name = pt_kernel_signature_->name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));
if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode ChoosePtenKernel - kernel name: "
VLOG(6) << "Static mode ChoosePhiKernel - kernel name: "
<< pt_kernel_name << " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_;
} else {
VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
VLOG(6) << "Static mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}
......@@ -1222,7 +1222,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
&& !is_xpu_unsupport
#endif
) {
run_pten_kernel_ = true;
run_phi_kernel_ = true;
} else {
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
......@@ -1244,12 +1244,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << *pt_kernel_;
run_pten_kernel_ = true;
run_phi_kernel_ = true;
}
}
}
}
if (!run_pten_kernel_) {
if (!run_phi_kernel_) {
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(exe_ctx);
dev_ctx = pool.Get(kernel_type_->place_);
......@@ -1290,13 +1290,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
if (run_pten_kernel_) {
if (run_phi_kernel_) {
phi::KernelContext pt_kernel_context;
// Do data transform before building KernelContext
// TODO(zhiqiu): support TransferInplaceVarsBack
PreparePtenData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
runtime_ctx);
BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
runtime_ctx);
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
(*pt_kernel_)(&pt_kernel_context);
} else {
(*kernel_func_)(
......@@ -1388,26 +1388,26 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
return expected_kernel_key;
}
phi::KernelKey OperatorWithKernel::ChoosePtenKernel(
phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx))));
new KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
VLOG(6) << *pt_kernel_signature_.get();
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto pt_kernel_name = pt_kernel_signature_->name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));
if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode ChoosePtenKernel - kernel name: " << pt_kernel_name
VLOG(6) << "Static mode ChoosePhiKernel - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_;
} else {
VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
VLOG(6) << "Static mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
return pt_kernel_key;
......@@ -1918,7 +1918,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout());
}
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const {
InitDefaultKernelSignatureMap();
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
......@@ -1926,7 +1926,7 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
arg_mapping_ctx);
}
Scope* OperatorWithKernel::PreparePtenData(
Scope* OperatorWithKernel::PreparePhiData(
const Scope& scope, const phi::Kernel& pt_kernel,
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
auto& input_names = std::get<0>(pt_kernel_signature.args);
......@@ -1981,12 +1981,12 @@ Scope* OperatorWithKernel::PreparePtenData(
if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue;
}
auto expected_place = phi::TransToPtenPlace(in_def.backend);
auto expected_place = phi::TransToPhiPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue;
}
VLOG(3) << "PTen Transform Variable " << input_names[i] << " from "
VLOG(3) << "phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place;
if (!new_scope) {
......@@ -2007,7 +2007,7 @@ Scope* OperatorWithKernel::PreparePtenData(
return new_scope;
}
void OperatorWithKernel::BuildPtenKernelContext(
void OperatorWithKernel::BuildPhiKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx,
phi::KernelContext* pt_kernel_context) const {
pt_kernel_context->SetDeviceContext(dev_ctx);
......@@ -2111,7 +2111,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
experimental::ResetTensorDtypeAndLayoutByArgDef(tensor_out,
output_defs.at(i));
SetAllocationForOutputTenosr(
tensor_out, phi::TransToPtenPlace(output_defs.at(i).backend));
tensor_out, phi::TransToPhiPlace(output_defs.at(i).backend));
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
}
......@@ -2145,10 +2145,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(*ins_vector.front())));
experimental::MakePhiScalarArrayFromVar(*ins_vector.front())));
} else { // ShapeTensorList
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(ins_vector)));
experimental::MakePhiScalarArrayFromVarList(ins_vector)));
}
}
} else if (attr_defs[i].type_index ==
......@@ -2178,8 +2178,8 @@ void OperatorWithKernel::BuildPtenKernelContext(
}
} else {
auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(*ins_vector.front())));
pt_kernel_context->EmplaceBackAttr(
std::move(experimental::MakePhiScalarFromVar(*ins_vector.front())));
}
} else {
......@@ -2198,7 +2198,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPtenDataType(
auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
pt_kernel_context->EmplaceBackAttr(data_type);
......@@ -2206,7 +2206,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args.
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
......
......@@ -30,7 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -423,7 +423,7 @@ class ExecutionContext {
"size(%d).",
allocation_ptr->size(), phi::product(dim) * sizeof(T)));
paddle::framework::Tensor temp_tensor(framework::TransToPtenDataType(
paddle::framework::Tensor temp_tensor(framework::TransToPhiDataType(
framework::ToDataType(std::type_index(typeid(T)))));
temp_tensor.Resize(dim);
temp_tensor.ResetHolder(std::move(shared_allocation));
......@@ -538,14 +538,14 @@ class OperatorWithKernel : public OperatorBase {
}
bool SupportGPU() const override {
auto pten_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPtenKernelName(type_));
auto has_pten_kernel =
std::any_of(pten_kernels.begin(), pten_kernels.end(),
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
auto has_phi_kernel =
std::any_of(phi_kernels.begin(), phi_kernels.end(),
[](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::GPU;
});
if (has_pten_kernel) {
if (has_phi_kernel) {
return true;
} else {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
......@@ -558,7 +558,7 @@ class OperatorWithKernel : public OperatorBase {
}
bool SupportNPU() const override {
// TODO(zhiqiu): support pten if needed?
// TODO(zhiqiu): support phi if needed?
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
......@@ -566,7 +566,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
bool SupportMLU() const override {
// TODO(zhiqiu): support pten if needed?
// TODO(zhiqiu): support phi if needed?
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
......@@ -603,39 +603,39 @@ class OperatorWithKernel : public OperatorBase {
return kernel_type_->place_;
}
/* member functions for adapting to pten lib */
/* member functions for adapting to phi lib */
/** In the Tensor calculation library, the new Kernel adopts a clearer and
* more streamlined design. The arguments of the Kernel and the input and
* output arguments registered in the original OpMaker do not match in some
* cases, so we use map to record the arguments required by the kernel.
* When selecting Kernel during Op execution, select the arguments of the
* original Op according to the GetExpectedPtenKernelArgs returned arguments.
* original Op according to the GetExpectedPhiKernelArgs returned arguments.
*/
phi::KernelSignature GetExpectedPtenKernelArgs(
phi::KernelSignature GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const;
/* member functions for adapting to pten lib */
phi::KernelKey ChoosePtenKernel(const ExecutionContext& ctx) const;
/* member functions for adapting to phi lib */
phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const;
/**
* Transfer data place for pten kernel
* Transfer data place for phi kernel
* Is this really needed?
*/
Scope* PreparePtenData(const Scope& scope, const phi::Kernel& pt_kernel,
const phi::KernelSignature& pt_kernel_signature,
RuntimeContext* ctx) const;
Scope* PreparePhiData(const Scope& scope, const phi::Kernel& pt_kernel,
const phi::KernelSignature& pt_kernel_signature,
RuntimeContext* ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx,
phi::KernelContext* pt_kernel_context) const;
void BuildPhiKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx,
phi::KernelContext* pt_kernel_context) const;
phi::KernelSignature* PtenKernelSignature() const {
phi::KernelSignature* PhiKernelSignature() const {
return pt_kernel_signature_.get();
}
phi::Kernel* PtenKernel() const { return pt_kernel_.get(); }
phi::Kernel* PhiKernel() const { return pt_kernel_.get(); }
void ResetPtenKernel(phi::Kernel* kernel) const {
void ResetPhiKernel(phi::Kernel* kernel) const {
return pt_kernel_.reset(kernel);
}
......@@ -692,9 +692,9 @@ class OperatorWithKernel : public OperatorBase {
mutable std::mutex cache_update_mutex_;
mutable bool enable_cache_transfer_scope_ = false;
// NOTE(chenweihang): Similar op members are used to adapt to
// new pten kernel, if there is a better design in the future,
// new phi kernel, if there is a better design in the future,
// we may polish the implementation here
mutable bool run_pten_kernel_ = false;
mutable bool run_phi_kernel_ = false;
mutable bool run_kp_kernel = false;
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<phi::Kernel> pt_kernel_;
......
......@@ -209,7 +209,7 @@ class CinnGraphSymbolizationTest : public ::testing::Test {
tensor.Resize(dims);
tensor.mutable_data(
platform::CPUPlace(),
framework::TransToPtenDataType(framework::proto::VarType::FP32));
framework::TransToPhiDataType(framework::proto::VarType::FP32));
return tensor;
};
#define FillFeedList(Name) feed_targets[#Name] = create_tensor();
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <sstream>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -57,12 +57,11 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
paddle::SmallVector<std::string> attr_names_;
};
OpKernelType TransPtenKernelKeyToOpKernelType(
const phi::KernelKey& kernel_key) {
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
proto::VarType::Type data_type =
paddle::framework::TransToProtoVarType(kernel_key.dtype());
// no need to set current device id here
platform::Place place = phi::TransToPtenPlace(kernel_key.backend(), false);
platform::Place place = phi::TransToPhiPlace(kernel_key.backend(), false);
DataLayout data_layout = kernel_key.layout();
LibraryType library_type = LibraryType::kPlain;
if (kernel_key.backend() == phi::Backend::MKLDNN) {
......@@ -76,9 +75,9 @@ OpKernelType TransPtenKernelKeyToOpKernelType(
return OpKernelType(data_type, place, data_layout, library_type);
}
phi::KernelKey TransOpKernelTypeToPtenKernelKey(
phi::KernelKey TransOpKernelTypeToPhiKernelKey(
const OpKernelType& kernel_type) {
phi::Backend backend = phi::TransToPtenBackend(kernel_type.place_);
phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
backend = phi::Backend::MKLDNN;
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
......@@ -88,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPtenKernelKey(
}
paddle::experimental::DataLayout layout = kernel_type.data_layout_;
paddle::experimental::DataType dtype =
paddle::framework::TransToPtenDataType(kernel_type.data_type_);
paddle::framework::TransToPhiDataType(kernel_type.data_type_);
return phi::KernelKey(backend, layout, dtype);
}
......@@ -98,8 +97,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "pten missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< "phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -107,8 +106,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing NPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< "phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -116,8 +115,8 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#endif
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "pten missing MLU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< "phipected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -132,17 +131,17 @@ KernelArgsNameMakerByOpProto::GetInputArgsNames() {
auto& in = op_proto_->inputs()[i];
auto& in_name = in.name();
if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
VLOG(6) << "Parse PtenKernel input: skip extra & quant input - "
VLOG(6) << "Parse PhiKernel input: skip extra & quant input - "
<< in_name;
continue;
}
// If contains dispensable input, we should override the
// OpArgumentMapping method self in phi/ops/compat dir
if (in.has_dispensable() && in.dispensable()) {
VLOG(6) << "Parse PtenKernel input: skip dispensable input - " << in_name;
VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name;
continue;
}
VLOG(6) << "Parse PtenKernel input: " << in_name;
VLOG(6) << "Parse PhiKernel input: " << in_name;
input_names_.emplace_back(in_name);
}
return input_names_;
......@@ -154,11 +153,11 @@ KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
auto& out = op_proto_->outputs()[i];
auto& out_name = out.name();
if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
VLOG(6) << "Parse PtenKernel output: skip extra & quant output - "
VLOG(6) << "Parse PhiKernel output: skip extra & quant output - "
<< out_name;
continue;
}
VLOG(6) << "Parse PtenKernel output: " << out_name;
VLOG(6) << "Parse PhiKernel output: " << out_name;
output_names_.emplace_back(out_name);
}
return output_names_;
......@@ -173,17 +172,17 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
attr_name == "op_role" || attr_name == "op_role_var" ||
attr_name == "op_namescope" || attr_name == "op_callstack" ||
attr_name == "op_device") {
VLOG(6) << "Parse PtenKernel attribute: skip needless attr - "
VLOG(6) << "Parse PhiKernel attribute: skip needless attr - "
<< attr_name;
continue;
}
if ((attr.has_extra() && attr.extra()) ||
(attr.has_quant() && attr.quant())) {
VLOG(6) << "Parse PtenKernel attribute: skip extra & quant attr - "
VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - "
<< attr_name;
continue;
}
VLOG(6) << "Parse PtenKernel attribute: " << attr_name;
VLOG(6) << "Parse PhiKernel attribute: " << attr_name;
attr_names_.emplace_back(attr_name);
}
......@@ -191,7 +190,7 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
}
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(phi::TransToPtenKernelName(op_proto_->type()),
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
}
......@@ -203,7 +202,7 @@ void InitDefaultKernelSignatureMap() {
for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
......
......@@ -44,9 +44,8 @@ using KernelSignature = phi::KernelSignature;
/* Kernel Key translate */
OpKernelType TransPtenKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
phi::KernelKey TransOpKernelTypeToPtenKernelKey(
const OpKernelType& kernel_type);
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type);
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
const phi::KernelKey& kernel_key,
const framework::OperatorBase& op);
......@@ -68,25 +67,25 @@ void SetAllocationForOutputTenosr(phi::TensorBase* tensor,
// TODO(Wilber): support others device context.
template <typename T>
struct ConvertToPtenContext {
struct ConvertToPhiContext {
using TYPE = T;
};
template <>
struct ConvertToPtenContext<platform::CPUDeviceContext> {
struct ConvertToPhiContext<platform::CPUDeviceContext> {
using TYPE = phi::CPUContext;
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <>
struct ConvertToPtenContext<platform::CUDADeviceContext> {
struct ConvertToPhiContext<platform::CUDADeviceContext> {
using TYPE = phi::GPUContext;
};
#endif
#ifdef PADDLE_WITH_XPU
template <>
struct ConvertToPtenContext<platform::XPUDeviceContext> {
struct ConvertToPhiContext<platform::XPUDeviceContext> {
using TYPE = phi::XPUContext;
};
#endif
......
......@@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/variable.h"
TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
TEST(PhiUtils, TransPhiKernelKeyToOpKernelType) {
phi::KernelKey kernel_key(phi::Backend::CPU, phi::DataLayout::NCHW,
phi::DataType::FLOAT32);
auto op_kernel_type =
paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key);
paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW);
ASSERT_TRUE(paddle::platform::is_cpu_place(op_kernel_type.place_));
......@@ -33,7 +33,7 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
phi::KernelKey kernel_key_mkldnn(phi::Backend::MKLDNN, phi::DataLayout::NCHW,
phi::DataType::FLOAT32);
op_kernel_type =
paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key_mkldnn);
paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key_mkldnn);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW);
ASSERT_TRUE(paddle::platform::is_cpu_place(op_kernel_type.place_));
......@@ -45,7 +45,7 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
phi::KernelKey kernel_key_cudnn(phi::Backend::GPUDNN, phi::DataLayout::NCHW,
phi::DataType::FLOAT32);
op_kernel_type =
paddle::framework::TransPtenKernelKeyToOpKernelType(kernel_key_cudnn);
paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key_cudnn);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
ASSERT_EQ(op_kernel_type.data_layout_, paddle::framework::DataLayout::kNCHW);
ASSERT_TRUE(paddle::platform::is_gpu_place(op_kernel_type.place_));
......@@ -54,12 +54,12 @@ TEST(PtenUtils, TransPtenKernelKeyToOpKernelType) {
#endif
}
TEST(PtenUtils, TransOpKernelTypeToPtenKernelKey) {
TEST(PhiUtils, TransOpKernelTypeToPhiKernelKey) {
paddle::framework::OpKernelType op_kernel_type(
paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
paddle::framework::DataLayout::kNCHW);
auto kernel_key =
paddle::framework::TransOpKernelTypeToPtenKernelKey(op_kernel_type);
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type);
ASSERT_EQ(kernel_key.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key.layout(), phi::DataLayout::NCHW);
ASSERT_EQ(kernel_key.backend(), phi::Backend::CPU);
......@@ -69,8 +69,8 @@ TEST(PtenUtils, TransOpKernelTypeToPtenKernelKey) {
paddle::framework::proto::VarType::FP32, paddle::platform::CPUPlace(),
paddle::framework::DataLayout::kMKLDNN,
paddle::framework::LibraryType::kMKLDNN);
auto kernel_key_mkldnn = paddle::framework::TransOpKernelTypeToPtenKernelKey(
op_kernel_type_mkldnn);
auto kernel_key_mkldnn =
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn);
ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::MKLDNN);
ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::MKLDNN);
......@@ -82,7 +82,7 @@ TEST(PtenUtils, TransOpKernelTypeToPtenKernelKey) {
paddle::framework::DataLayout::kNCHW,
paddle::framework::LibraryType::kCUDNN);
auto kernel_key_cudnn =
paddle::framework::TransOpKernelTypeToPtenKernelKey(op_kernel_type_cudnn);
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_cudnn);
ASSERT_EQ(kernel_key_cudnn.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key_cudnn.layout(), phi::DataLayout::NCHW);
ASSERT_EQ(kernel_key_cudnn.backend(), phi::Backend::GPUDNN);
......
......@@ -1457,7 +1457,7 @@ std::ostream& print_tensor<paddle::platform::complex<double>>(
std::ostream& operator<<(std::ostream& os, const LoD& lod) {
// NOTE(xiongkun):
// https://stackoverflow.com/questions/5195512/namespaces-and-operator-resolution
// if we don't redefine, the operator << of pten / framework LoD is not found.
// if we don't redefine, the operator << of phi / framework LoD is not found.
paddle::string::operator<<(os, lod);
return os;
}
......
......@@ -70,12 +70,12 @@ OpSupportedInfos(const std::string& place,
}
}
auto pten_kernels = phi::KernelFactory::Instance().kernels();
for (auto& kernel_pair : pten_kernels) {
auto phi_kernels = phi::KernelFactory::Instance().kernels();
for (auto& kernel_pair : phi_kernels) {
auto op_type = phi::TransToFluidOpName(kernel_pair.first);
for (auto& info_pair : kernel_pair.second) {
framework::OpKernelType kernel_type =
framework::TransPtenKernelKeyToOpKernelType(info_pair.first);
framework::TransPhiKernelKeyToOpKernelType(info_pair.first);
if (is_target_place[query_place](kernel_type.place_) &&
kernel_type.data_type_ == dtype && all_ops.count(op_type)) {
VLOG(4) << op_type << " " << supported_ops.size();
......
......@@ -154,7 +154,7 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
// Here, we use the type of the corresponding forward datatype.
tensor->mutable_data(
op.place(), framework::TransToPtenDataType(var->ForwardDataType()));
op.place(), framework::TransToPhiDataType(var->ForwardDataType()));
VLOG(6) << "Set ungenerated Grad: " << var->Name()
<< " as zero with dtype "
<< framework::DataTypeToString(var->ForwardDataType());
......
......@@ -791,13 +791,13 @@ void EagerGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
<< var->Var().Get<framework::LoDTensor>().dims();
tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType()));
framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
} else {
auto* tensor =
dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType()));
framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
}
}
......@@ -925,13 +925,13 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr<VariableWrapper> var,
<< var->Var().Get<framework::LoDTensor>().dims();
tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType()));
framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
} else {
auto* tensor =
dst_var->MutableVar()->GetMutable<framework::LoDTensor>();
tensor->mutable_data(place,
framework::TransToPtenDataType(var->DataType()));
framework::TransToPhiDataType(var->DataType()));
phi::funcs::set_constant(*dev_ctx, tensor, 0.0);
}
}
......
......@@ -314,10 +314,10 @@ static void FillConstantLike(const VariableWrapper &ref_var,
// default data_type for now.
if (ref_var.ForwardDataType() != -1) {
dst_tensor->mutable_data(
place, framework::TransToPtenDataType(ref_var.ForwardDataType()));
place, framework::TransToPhiDataType(ref_var.ForwardDataType()));
} else {
dst_tensor->mutable_data(
place, framework::TransToPtenDataType(ref_var.DataType()));
dst_tensor->mutable_data(place,
framework::TransToPhiDataType(ref_var.DataType()));
}
phi::funcs::set_constant(*dev_ctx, dst_tensor, value);
}
......
......@@ -121,7 +121,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
kernel_type_(kernel_type),
func_(nullptr),
dev_ctx_(dev_ctx),
run_pten_kernel_(true),
run_phi_kernel_(true),
pt_kernel_signature_(kernel_signature),
pt_kernel_(pt_kernel) {}
......@@ -151,7 +151,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
// NOTE(zhiqiu): for kernels on given device, for example NPU, the order to
// choose is:
// pten npu kernel > fluid npu kernel > pten cpu kernel > fluid cpu kernel
// phi npu kernel > fluid npu kernel > phi cpu kernel > fluid cpu kernel
// 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
......@@ -168,12 +168,12 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type());
#endif
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx);
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key);
......@@ -195,7 +195,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}
......@@ -211,7 +211,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
|| is_xpu_unsupport
#endif
) {
if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
......@@ -423,12 +423,12 @@ static void PreparedOpRunPtImpl(
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
PreparePtenData<VarType>(pt_kernel, pt_kernel_signature, ins);
PreparePhiData<VarType>(pt_kernel, pt_kernel_signature, ins);
phi::KernelContext pt_kernel_context;
BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
outs, attrs, default_attrs, dev_ctx,
&pt_kernel_context);
BuildDygraphPhiKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
outs, attrs, default_attrs, dev_ctx,
&pt_kernel_context);
pt_kernel(&pt_kernel_context);
}
......@@ -451,7 +451,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, dev_ctx_, ins, outs, attrs,
default_attrs);
......@@ -465,7 +465,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs);
......@@ -479,7 +479,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerVariable>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
if (run_phi_kernel_) {
PreparedOpRunPtImpl<egr::EagerVariable>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs);
......
......@@ -22,7 +22,7 @@
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h"
......@@ -201,9 +201,9 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_;
// NOTE(chenweihang): Similar op members are used to adapt to
// new pten kernel, if there is a better design in the future,
// new phi kernel, if there is a better design in the future,
// we may polish the implementation here
bool run_pten_kernel_{false};
bool run_phi_kernel_{false};
bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_;
phi::Kernel pt_kernel_;
......@@ -225,7 +225,7 @@ const inline framework::Attribute& GetAttr(
}
template <typename VarType>
void BuildDygraphPtenKernelContext(
void BuildDygraphPhiKernelContext(
const framework::KernelSignature& pt_kernel_signature,
const phi::Kernel& pt_kernel, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
......@@ -327,7 +327,7 @@ void BuildDygraphPtenKernelContext(
experimental::ResetTensorDtypeAndLayoutByArgDef(tensor_out,
output_defs.at(i));
framework::SetAllocationForOutputTenosr(
tensor_out, phi::TransToPtenPlace(output_defs.at(i).backend));
tensor_out, phi::TransToPhiPlace(output_defs.at(i).backend));
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
......@@ -369,7 +369,7 @@ void BuildDygraphPtenKernelContext(
auto& ins_vector = ins.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var())));
experimental::MakePhiScalarArrayFromVar(ins_vector[0]->Var())));
} else { // ShapeTensorList
std::vector<framework::Variable*> variables;
variables.reserve(ins_vector.size());
......@@ -377,7 +377,7 @@ void BuildDygraphPtenKernelContext(
variables.push_back(var_base->MutableVar());
}
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(variables)));
experimental::MakePhiScalarArrayFromVarList(variables)));
}
}
} else if (attr_defs[i].type_index ==
......@@ -409,7 +409,7 @@ void BuildDygraphPtenKernelContext(
} else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(ins_vector[0]->Var())));
experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
}
} else {
......@@ -428,7 +428,7 @@ void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = framework::TransToPtenDataType(
auto data_type = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
kernel_ctx->EmplaceBackAttr(data_type);
......@@ -436,7 +436,7 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Pten_Kernel args.
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
......@@ -456,9 +456,9 @@ void BuildDygraphPtenKernelContext(
}
template <typename VarType>
void PreparePtenData(const phi::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) {
void PreparePhiData(const phi::Kernel& pt_kernel,
const framework::KernelSignature& pt_kernel_signature,
const NameVarMap<VarType>& ins) {
auto& input_names = std::get<0>(pt_kernel_signature.args);
auto& input_defs = pt_kernel.args_def().input_defs();
......@@ -482,12 +482,12 @@ void PreparePtenData(const phi::Kernel& pt_kernel,
if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue;
}
auto expected_place = phi::TransToPtenPlace(in_def.backend);
auto expected_place = phi::TransToPhiPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue;
}
VLOG(3) << "Pten Transform Variable " << input_names[i] << " from "
VLOG(3) << "Phi Transform Variable " << input_names[i] << " from "
<< tensor_in->place() << " to " << expected_place;
framework::Tensor tmp_tensor;
......
......@@ -446,7 +446,7 @@ void Reducer::InitializeGroups(
InitializeDenseGroups(variable_indices_, &group);
auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(phi::make_ddim({group.all_length_}))
.mutable_data(place_, framework::TransToPtenDataType(group.dtype_));
.mutable_data(place_, framework::TransToPhiDataType(group.dtype_));
}
// map variables to this group by VariableLocator
......@@ -738,7 +738,7 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
if (!group_tensor.IsInitialized()) {
group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(place_,
framework::TransToPtenDataType(group.dtype_));
framework::TransToPhiDataType(group.dtype_));
}
#ifdef PADDLE_WITH_XPU_BKCL
......
......@@ -96,7 +96,7 @@ void GroupConcatSplit(Place place, size_t size) {
{ // concat
auto* tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(phi::make_ddim({group.all_length_}))
.mutable_data(place, framework::TransToPtenDataType(group.dtype_));
.mutable_data(place, framework::TransToPhiDataType(group.dtype_));
group.ConcatTensors(*dev_ctx);
group.DivNRanks(*dev_ctx, 1);
......
......@@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/pybind.h"
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
......
......@@ -198,7 +198,7 @@ void InitDstTensor(framework::LoDTensor* dst,
const paddle::lite_api::Tensor& src) {
dst->mutable_data(
inference::lite::utils::GetNativePlace(src.target()),
framework::TransToPtenDataType(GetNativePrecisionType(src.precision())));
framework::TransToPhiDataType(GetNativePrecisionType(src.precision())));
SetLoD(dst->mutable_lod(), src.lod());
}
......@@ -269,7 +269,7 @@ void TensorDataShare(framework::LoDTensor* dst, paddle::lite_api::Tensor* src) {
SetLoD(dst->mutable_lod(), src->lod());
dst->ResetHolderWithType(
holder,
framework::TransToPtenDataType(GetNativePrecisionType(src->precision())));
framework::TransToPhiDataType(GetNativePrecisionType(src->precision())));
}
} // namespace utils
......
......@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/pybind/pybind.h"
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
namespace paddle {
......
......@@ -138,7 +138,7 @@ class CastOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
// cast use pten kernel, so no need to REGISTER_OP_CPU_KERNEL here.
// cast use phi kernel, so no need to REGISTER_OP_CPU_KERNEL here.
REGISTER_OPERATOR(cast, ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
......
......@@ -63,12 +63,12 @@ class CastOpKernel : public framework::OpKernel<InT> {
out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_out_dtype = framework::TransToPtenDataType(
auto pt_out_dtype = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel
phi::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, pt_out_dtype, out);
}
......
......@@ -46,11 +46,11 @@ class CastXPUKernel : public framework::OpKernel<InT> {
out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_out_dtype = framework::TransToPtenDataType(
auto pt_out_dtype = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call pten kernel
// call phi kernel
phi::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, pt_out_dtype, out);
}
......
......@@ -203,7 +203,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
commonterm_conj = helper.Transpose(commonterm_conj);
phi::AddRawKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
commonterm, commonterm_conj, -1, &commonterm);
......
......@@ -54,7 +54,7 @@ struct FillConstantVisitor {
* = nullptr) const {
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(dev_ctx_.GetPlace())) {
Tensor tensor_tmp(framework::TransToPtenDataType(dtype_));
Tensor tensor_tmp(framework::TransToPhiDataType(dtype_));
tensor_tmp.mutable_data<T>({1}, context_.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, static_cast<T>(value_));
......@@ -194,7 +194,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
void *fused_tensor_ptr =
fused_tensor->Resize(phi::make_ddim({static_cast<int64_t>(numel)}))
.mutable_data(context.GetPlace(),
framework::TransToPtenDataType(dtype));
framework::TransToPhiDataType(dtype));
VLOG(10) << "Fused tensor addr " << fused_tensor_ptr;
// Init the continuous space
......
......@@ -37,7 +37,7 @@ class ConjKernel : public framework::OpKernel<T> {
// call new kernel
phi::ConjKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, out);
}
......
......@@ -41,9 +41,9 @@ class DotKernel : public framework::OpKernel<T> {
out->mutable_data<T>(x->place());
// call new kernel
phi::DotKernel<T, typename paddle::framework::ConvertToPtenContext<
phi::DotKernel<T, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, out);
}
......@@ -66,7 +66,7 @@ class DotGradKernel : public framework::OpKernel<T> {
// call new kernel
phi::DotGradKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*tensor_x, *tensor_y, *tensor_dout, tensor_dx, tensor_dy);
}
......
......@@ -55,7 +55,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
phi::AddRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
*x, *y, axis, z);
#endif
......
......@@ -63,11 +63,11 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
auto pt_x = paddle::experimental::MakePhiDenseTensor(*x);
auto pt_y = paddle::experimental::MakePhiDenseTensor(*y);
auto pt_z = paddle::experimental::MakePhiDenseTensor(*z);
phi::DivideRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
}
......
......@@ -49,9 +49,9 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
z_lod->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePhiDenseTensor(*y_lod);
auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod);
phi::MultiplyRawKernel<T>(static_cast<const phi::GPUContext&>(cuda_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} else {
......
......@@ -122,11 +122,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
auto pt_x = paddle::experimental::MakePhiDenseTensor(*x_lod);
auto pt_y = paddle::experimental::MakePhiDenseTensor(*y);
auto pt_z = paddle::experimental::MakePhiDenseTensor(*z_lod);
phi::MultiplyRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
} else {
......
......@@ -31,18 +31,18 @@ void LaunchElementwiseCudaKernel(
std::vector<phi::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// generated by MakePhiDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<phi::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<phi::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
pt_inputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*in)));
std::move(paddle::experimental::MakePhiDenseTensor(*in)));
}
for (auto out : *outs) {
pt_outputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*out)));
std::move(paddle::experimental::MakePhiDenseTensor(*out)));
}
for (int i = 0; i < pt_inputs_tmp.size(); i++) {
pt_inputs.push_back(pt_inputs_tmp[i].get());
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
// only can include the headers in paddle/top/api dirs
......@@ -34,18 +34,18 @@ void LaunchSameDimsElementwiseCudaKernel(
std::vector<phi::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// generated by MakePhiDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<phi::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<phi::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
pt_inputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*in)));
std::move(paddle::experimental::MakePhiDenseTensor(*in)));
}
for (auto out : *outs) {
pt_outputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*out)));
std::move(paddle::experimental::MakePhiDenseTensor(*out)));
}
for (int i = 0; i < pt_inputs_tmp.size(); i++) {
pt_inputs.push_back(pt_inputs_tmp[i].get());
......
......@@ -34,7 +34,7 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
phi::SubtractRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, axis, z);
}
......@@ -56,7 +56,7 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
phi::SubtractGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, axis, dx, dy);
}
......@@ -86,7 +86,7 @@ class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
ddy_optional = *ddy;
}
phi::SubtractDoubleGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*y, ddx_optional, ddy_optional, *dout, axis, ddout);
}
......
......@@ -39,7 +39,7 @@ class EmptyKernel : public framework::OpKernel<T> {
out_tensor->Resize(shape);
out_tensor->mutable_data(context.GetPlace(),
framework::TransToPtenDataType(dtype));
framework::TransToPhiDataType(dtype));
}
};
......
......@@ -54,7 +54,7 @@ class FillAnyLikeNPUKernel : public framework::OpKernel<T> {
std::isnan(value), false,
platform::errors::InvalidArgument("The filled value is NaN."));
Tensor tensor_tmp(framework::TransToPtenDataType(data_type));
Tensor tensor_tmp(framework::TransToPhiDataType(data_type));
tensor_tmp.mutable_data<T>({1}, context.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, static_cast<T>(value));
......
......@@ -60,9 +60,9 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
// call pten kernel
// call phi kernel
phi::FullLikeKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
paddle::platform::XPUDeviceContext>::TYPE&>(dev_ctx),
*x, value, phi::DataType::UNDEFINED, out);
}
......
......@@ -63,7 +63,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
auto &dev_ctx = *pool.Get(platform::CPUPlace());
phi::funcs::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
out, static_cast<T>(value));
}
......@@ -72,7 +72,7 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
auto &dev_ctx = *pool.Get(ctx.GetPlace());
phi::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
out->mutable_data(ctx.GetPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
out, static_cast<T>(value));
}
......
......@@ -72,13 +72,13 @@ class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = *pool.Get(platform::CPUPlace());
phi::funcs::SetConstant<platform::CPUDeviceContext, T> functor;
out->mutable_data(platform::CPUPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
out, static_cast<T>(value));
} else {
out->mutable_data(ctx.GetPlace(),
framework::TransToPtenDataType(data_type));
Tensor tensor_tmp(framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
Tensor tensor_tmp(framework::TransToPhiDataType(data_type));
tensor_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, value);
......
......@@ -122,7 +122,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
<< ((data_type == framework::proto::VarType::BF16) ? "<bfloat16>"
: "<T>");
tensor->mutable_data(platform::CPUPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
phi::funcs::SetConstant<platform::CPUDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(platform::CPUPlace());
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
......@@ -130,7 +130,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
} else if (actual_place == 1) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
tensor->mutable_data(ctx.GetPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
phi::funcs::SetConstant<platform::CUDADeviceContext, T> functor;
auto &dev_ctx = *pool.Get(ctx.GetPlace());
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
......@@ -142,7 +142,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
} else if (actual_place == 2) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
tensor->mutable_data(platform::CUDAPinnedPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
phi::funcs::SetConstant<platform::CUDAPinnedDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(platform::CUDAPinnedPlace());
functor(
......@@ -155,7 +155,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
} else if (actual_place == 3) {
#ifdef PADDLE_WITH_XPU
tensor->mutable_data(ctx.GetPlace(),
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
phi::funcs::SetConstant<platform::XPUDeviceContext, T> functor;
auto &dev_ctx = *pool.Get(ctx.GetPlace());
functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
......
......@@ -61,7 +61,7 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
out_var->mutable_data<T>(shape, ctx.GetPlace());
if (data_type != framework::proto::VarType::BOOL) {
Tensor tensor_value(framework::TransToPtenDataType(data_type));
Tensor tensor_value(framework::TransToPhiDataType(data_type));
tensor_value.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_value, value);
NpuOpRunner runner;
......
......@@ -49,10 +49,10 @@ class FillKernel : public framework::OpKernel<T> {
out.Resize(phi::make_ddim(ctx.Attr<std::vector<int>>("shape")));
auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto pten_dtype = framework::TransToPtenDataType(dtype);
auto phi_dtype = framework::TransToPhiDataType(dtype);
platform::CPUPlace cpu;
auto force_cpu = ctx.Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : ctx.GetPlace(), pten_dtype);
out.mutable_data(force_cpu ? cpu : ctx.GetPlace(), phi_dtype);
framework::LoDTensor tensor;
......@@ -61,7 +61,7 @@ class FillKernel : public framework::OpKernel<T> {
} else {
// Always make tensor in CPU memory.
tensor.Resize(out.dims());
tensor.mutable_data(cpu, pten_dtype);
tensor.mutable_data(cpu, phi_dtype);
}
framework::VisitDataType(
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/empty_kernel.h"
......@@ -132,9 +132,9 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.device_context<DeviceContext>();
// call new kernel
phi::FlattenKernel<T, typename paddle::framework::ConvertToPtenContext<
phi::FlattenKernel<T, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
*in, start_axis, stop_axis, out);
}
......@@ -153,9 +153,9 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
phi::FlattenGradKernel<T, typename paddle::framework::ConvertToPtenContext<
phi::FlattenGradKernel<T, typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE &>(dev_ctx),
*d_out, *xshape, d_x);
}
......
......@@ -34,9 +34,9 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
int input_num = static_cast<int>(ids.size());
framework::Tensor in_ids_(
framework::TransToPtenDataType(framework::proto::VarType::INT64)),
framework::TransToPhiDataType(framework::proto::VarType::INT64)),
in_embs_(
framework::TransToPtenDataType(framework::proto::VarType::INT64));
framework::TransToPhiDataType(framework::proto::VarType::INT64));
framework::DDim in_dim{input_num};
int device_id;
#ifdef PADDLE_WITH_HIP
......
......@@ -88,8 +88,8 @@ void SetValueCompute(const framework::ExecutionContext& ctx,
// set_value is what we want.
paddle::framework::TensorCopy(*in, place, out);
Tensor slice_tensor(framework::TransToPtenDataType(dtype)),
pad_tensor(framework::TransToPtenDataType(dtype));
Tensor slice_tensor(framework::TransToPhiDataType(dtype)),
pad_tensor(framework::TransToPhiDataType(dtype));
slice_tensor.mutable_data<T>(slice_dims, place);
pad_tensor.mutable_data<T>(in_dims, place);
......@@ -147,7 +147,7 @@ void SetValueCompute(const framework::ExecutionContext& ctx,
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
} else {
Tensor value_t(framework::TransToPtenDataType(dtype));
Tensor value_t(framework::TransToPhiDataType(dtype));
auto value_dims = phi::make_ddim(shape);
CheckIsDimsMatch(slice_dims_for_assign, value_dims);
......@@ -224,8 +224,8 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
out->mutable_data<T>(dev_ctx.GetPlace());
phi::AddRawKernel<
T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
T, typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
src1, src2, -1, out);
}
......@@ -237,8 +237,8 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
out->mutable_data<T>(dev_ctx.GetPlace());
phi::SubtractRawKernel<
T, typename paddle::framework::ConvertToPtenContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
T, typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
src1, src2, -1, out);
}
......
......@@ -35,8 +35,8 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
}
// cast `indices` or `label` if their type is not INT32
Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
Tensor label_int32(framework::TransToPtenDataType(VT::INT32));
Tensor indices_int32(framework::TransToPhiDataType(VT::INT32));
Tensor label_int32(framework::TransToPhiDataType(VT::INT32));
auto indices_type = framework::TransToProtoVarType(indices->type());
if (indices_type != VT::INT32) {
PADDLE_ENFORCE_EQ(MLUSupportsCast(indices_type, VT::INT32), true,
......@@ -78,7 +78,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// equal
MLUCnnlTensorDesc indices_int32_desc(indices_int32);
MLUCnnlTensorDesc label_int32_desc(label_int32);
Tensor equal_tensor(framework::TransToPtenDataType(VT::BOOL));
Tensor equal_tensor(framework::TransToPhiDataType(VT::BOOL));
equal_tensor.Resize(indices->dims());
equal_tensor.mutable_data<bool>(ctx.GetPlace());
MLUCnnlTensorDesc equal_tensor_desc(equal_tensor);
......@@ -88,7 +88,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
GetBasePtr(&equal_tensor));
// cast equal
Tensor equal_fp32(framework::TransToPtenDataType(VT::FP32));
Tensor equal_fp32(framework::TransToPhiDataType(VT::FP32));
equal_fp32.Resize(indices->dims());
equal_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc equal_fp32_desc(equal_fp32);
......@@ -99,7 +99,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// [correct]
// reduce_max
Tensor correct_max(framework::TransToPtenDataType(VT::FP32));
Tensor correct_max(framework::TransToPhiDataType(VT::FP32));
correct_max.Resize(phi::make_ddim({num_samples}));
correct_max.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc correct_max_desc(correct_max);
......@@ -112,7 +112,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
correct_max_desc.get(), GetBasePtr(&correct_max));
// reduce_sum
Tensor correct_sum(framework::TransToPtenDataType(VT::FP32));
Tensor correct_sum(framework::TransToPhiDataType(VT::FP32));
correct_sum.Resize(correct->dims());
correct_sum.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc correct_sum_desc(correct_sum);
......@@ -138,7 +138,7 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
// use `total` of type `float32` for calculating accuracy
Tensor total_fp32(framework::TransToPtenDataType(VT::FP32));
Tensor total_fp32(framework::TransToPhiDataType(VT::FP32));
total_fp32.Resize(total->dims());
total_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc total_fp32_desc(total_fp32);
......
......@@ -85,7 +85,7 @@ inline cnnlDataType_t ToCnnlDataType(
inline cnnlDataType_t ToCnnlDataType(
const paddle::framework::proto::VarType::Type& type) {
return ToCnnlDataType(framework::TransToPtenDataType(type));
return ToCnnlDataType(framework::TransToPhiDataType(type));
}
template <typename T>
......
......@@ -257,12 +257,12 @@ class ReduceKernel : public framework::OpKernel<T> {
std::vector<int64_t> tmp_dims(dims.begin(), dims.end());
// call new kernel
phi::Reduce<typename framework::ConvertToPtenContext<DeviceContext>::TYPE,
T, Functor>(
static_cast<const typename framework::ConvertToPtenContext<
phi::Reduce<typename framework::ConvertToPhiContext<DeviceContext>::TYPE, T,
Functor>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*input, reduce_all, tmp_dims, keep_dim,
framework::TransToPtenDataType(cast_out_dtype), output);
framework::TransToPhiDataType(cast_out_dtype), output);
}
};
template <typename DeviceContext, typename OutT, typename Functor>
......@@ -684,7 +684,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
const Tensor* input = context.Input<Tensor>("X");
Tensor* output = context.Output<Tensor>("Out");
auto out_dtype = context.Attr<int>("out_dtype");
auto pt_out_dtype = paddle::framework::TransToPtenDataType(
auto pt_out_dtype = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
std::vector<int> dims = context.Attr<std::vector<int>>("dim");
......@@ -714,7 +714,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto out_dtype = context.Attr<int>("in_dtype");
auto pt_out_dtype = framework::TransToPtenDataType(
auto pt_out_dtype = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
......@@ -735,8 +735,8 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
} else {
d_x->mutable_data(dev_ctx.GetPlace(), d_out->dtype());
}
auto pt_d_out = paddle::experimental::MakePtenDenseTensor(new_d_out);
auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
auto pt_d_out = paddle::experimental::MakePhiDenseTensor(new_d_out);
auto pt_d_x = paddle::experimental::MakePhiDenseTensor(*d_x);
if (out_dtype <= 0) {
pt_out_dtype = d_out->dtype();
}
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
// only can include the headers in paddle/phi/api dirs
#include "paddle/phi/api/lib/utils/tensor_utils.h"
......
......@@ -42,9 +42,9 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// call pten kernel
// call phi kernel
phi::ScaleKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, scale, bias, bias_after_scale, out);
}
......
......@@ -87,7 +87,7 @@ class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"If soft_label=False, axis must be -1 or"
" can be regard as last dimention in mlu kernel."));
framework::Tensor labels_int32(framework::TransToPtenDataType(VT::INT32));
framework::Tensor labels_int32(framework::TransToPhiDataType(VT::INT32));
labels_int32.Resize(labels->dims());
labels_int32.mutable_data<int32_t>(ctx.GetPlace());
......
......@@ -47,7 +47,7 @@ class TopkMLUKernel : public framework::OpKernel<T> {
const bool sorted = true;
const int axis = -1;
// cnnl only support int32/int16 type of indices
framework::Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
framework::Tensor indices_int32(framework::TransToPhiDataType(VT::INT32));
indices_int32.Resize(indices->dims());
indices_int32.mutable_data<int32_t>(place);
......
......@@ -55,7 +55,7 @@ class TopkV2MLUKernel : public framework::OpKernel<T> {
indices->mutable_data<int64_t>(place);
// cnnl only support int32/int16 type of indices
framework::Tensor indices_int32(framework::TransToPtenDataType(VT::INT32));
framework::Tensor indices_int32(framework::TransToPhiDataType(VT::INT32));
indices_int32.Resize(indices->dims());
indices_int32.mutable_data<int32_t>(place);
......
......@@ -36,7 +36,7 @@ class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::CUDADeviceContext>();
float value = static_cast<float>(0.0f);
phi::FullKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
static_cast<const typename paddle::framework::ConvertToPhiContext<
paddle::platform::CUDADeviceContext>::TYPE&>(dev_cxt),
dims, value, phi::DataType::UNDEFINED, dx);
}
......
......@@ -113,7 +113,7 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
auto fetch_dtype = fetch_info.dataType();
auto paddle_type = PopartType2VarType(fetch_dtype);
tensor->mutable_data(ctx.GetPlace(),
framework::TransToPtenDataType(paddle_type));
framework::TransToPhiDataType(paddle_type));
anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor));
popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id));
}
......
......@@ -467,7 +467,7 @@ void NpuOpRunner::TypeAdapter(
} else {
tmp_inputs[i].Resize(inputs[i].dims());
tmp_inputs[i].mutable_data(dev_ctx.GetPlace(),
framework::TransToPtenDataType(input_type[i]));
framework::TransToPhiDataType(input_type[i]));
const auto &cast_runner = NpuOpRunner(
"Cast", {inputs[i]}, {tmp_inputs[i]},
......@@ -484,7 +484,7 @@ void NpuOpRunner::TypeAdapter(
} else {
tmp_outputs[i].Resize(outputs[i].dims());
tmp_outputs[i].mutable_data(
dev_ctx.GetPlace(), framework::TransToPtenDataType(output_type[i]));
dev_ctx.GetPlace(), framework::TransToPhiDataType(output_type[i]));
}
}
......
......@@ -1056,7 +1056,7 @@ class ReorderMKLDNNHandler {
platform::Place place) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
auto dst_data = output->mutable_data(
place, framework::TransToPtenDataType(vtype_dst_), dst_md.get_size());
place, framework::TransToPhiDataType(vtype_dst_), dst_md.get_size());
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
......@@ -1065,7 +1065,7 @@ class ReorderMKLDNNHandler {
const MKLDNNMemoryFormat& fmt, platform::Place place) {
auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt);
auto dst_data = output->mutable_data(
place, framework::TransToPtenDataType(vtype_dst_), dst_md.get_size());
place, framework::TransToPhiDataType(vtype_dst_), dst_md.get_size());
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
......
......@@ -59,7 +59,7 @@ struct Transform {
BinaryOperation op);
};
// NOTE: After the pten kernel is migrated, it needs to be deleted.
// NOTE: After the phi kernel is migrated, it needs to be deleted.
template <>
struct Transform<platform::CPUDeviceContext> {
template <typename InputIter, typename OutputIter, typename UnaryOperation>
......
......@@ -75,7 +75,7 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name,
std::shared_ptr<phi::DenseTensor> dense_tensor =
std::make_shared<phi::DenseTensor>(
phi::make_intrusive<paddle::experimental::SharedStorage>(place),
phi::DenseTensorMeta(paddle::framework::TransToPtenDataType(dtype),
phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
ddims));
if (phi::product(ddims) > 0) {
dense_tensor->mutable_data(place);
......@@ -133,7 +133,7 @@ void InitTensorWithTensor(TensorObject* self,
VLOG(4) << "Same place, do ShareDataWith";
} else {
self->tensor.set_impl(
src.copy_to(phi::TransToPtenBackend(place), true).impl());
src.copy_to(phi::TransToPhiBackend(place), true).impl());
VLOG(4) << "Different place, do TensorCopy";
}
if (src.get_autograd_meta()) {
......@@ -157,7 +157,7 @@ void InitTensorWithFrameworkTensor(TensorObject* self,
auto temp =
paddle::experimental::Tensor(std::make_shared<phi::DenseTensor>(src));
self->tensor.set_impl(
temp.copy_to(phi::TransToPtenBackend(place), true).impl());
temp.copy_to(phi::TransToPhiBackend(place), true).impl());
VLOG(4) << "Different place, do TensorCopy";
}
egr::EagerUtils::autograd_meta(&(self->tensor))->SetPersistable(false);
......
......@@ -135,7 +135,7 @@ static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args,
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2);
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3);
dst = src.copy_to(phi::TransToPtenBackend(place), blocking);
dst = src.copy_to(phi::TransToPhiBackend(place), blocking);
egr::EagerUtils::autograd_meta(&dst)->SetStopGradient(
egr::EagerUtils::autograd_meta(&(src))->StopGradient());
egr::EagerUtils::autograd_meta(&dst)->SetPersistable(
......
......@@ -191,7 +191,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 1), 1);
auto cp_tensor =
self->tensor.copy_to(phi::TransToPtenBackend(place), blocking);
self->tensor.copy_to(phi::TransToPhiBackend(place), blocking);
egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
egr::EagerUtils::autograd_meta(&cp_tensor)
->SetPersistable(
......
......@@ -32,7 +32,7 @@
#endif
#include "paddle/fluid/pybind/op_function_generator.h"
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
// clang-format off
......@@ -365,9 +365,9 @@ GenerateOpFunctions() {
auto& op_type = op_proto->type();
// Skip ooerator which is not inherit form OperatorWithKernel, like while,
// since only OperatorWithKernel can run in dygraph mode.
// if the pten lib contains op kernel, we still generate ops method
// if the phi lib contains op kernel, we still generate ops method
if (!all_kernels.count(op_type) &&
!phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
!phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
continue;
}
std::string func_name = "eager_api_" + op_type;
......
......@@ -15,7 +15,7 @@
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/pybind/pybind.h" // NOLINT
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
......
......@@ -32,7 +32,7 @@
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif
// pten
// phi
#include "paddle/phi/kernels/declarations.h"
// NOTE(pangyoki): Inplace OP with duplicable input.
......@@ -400,9 +400,9 @@ GenerateOpFunctions() {
auto& op_type = op_proto->type();
// Skip operator which is not inherit form OperatorWithKernel, like while,
// since only OperatorWithKernel can run in dygraph mode.
// if the pten lib contains op kernel, we still generate ops method
// if the phi lib contains op kernel, we still generate ops method
if (!all_kernels.count(op_type) &&
!phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
!phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type)) {
continue;
}
......
......@@ -50,8 +50,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/save_load_util.h"
#include "paddle/fluid/framework/scope_pool.h"
......@@ -464,7 +464,7 @@ static void inline CreateVariableIfNotExit(
tensor_temp->Resize(phi::make_ddim(var_desc.GetShape()));
tensor_temp->mutable_data(
exe->GetPlace(),
framework::TransToPtenDataType(var_desc.GetDataType()));
framework::TransToPhiDataType(var_desc.GetDataType()));
}
}
} else {
......@@ -671,60 +671,60 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("_get_use_default_grad_op_desc_maker_ops",
[] { return OpInfoMap::Instance().GetUseDefaultGradOpDescMakerOps(); });
m.def(
"_get_all_register_op_kernels",
[](const std::string &lib) {
std::unordered_map<std::string, std::vector<std::string>>
all_kernels_info;
if (lib == "fluid" || lib == "all") {
auto &all_kernels =
paddle::framework::OperatorWithKernel::AllOpKernels();
for (auto &kernel_pair : all_kernels) {
auto op_type = kernel_pair.first;
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
paddle::framework::OpKernelType kernel_type = info_pair.first;
kernel_types.emplace_back(
paddle::framework::KernelTypeToString(kernel_type));
m.def("_get_all_register_op_kernels",
[](const std::string &lib) {
std::unordered_map<std::string, std::vector<std::string>>
all_kernels_info;
if (lib == "fluid" || lib == "all") {
auto &all_kernels =
paddle::framework::OperatorWithKernel::AllOpKernels();
for (auto &kernel_pair : all_kernels) {
auto op_type = kernel_pair.first;
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
paddle::framework::OpKernelType kernel_type = info_pair.first;
kernel_types.emplace_back(
paddle::framework::KernelTypeToString(kernel_type));
}
all_kernels_info.emplace(op_type, kernel_types);
}
all_kernels_info.emplace(op_type, kernel_types);
}
}
if (lib == "pten" || lib == "all") {
auto pten_kernels = phi::KernelFactory::Instance().kernels();
for (auto &kernel_pair : pten_kernels) {
auto op_type = phi::TransToFluidOpName(kernel_pair.first);
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
framework::OpKernelType kernel_type =
framework::TransPtenKernelKeyToOpKernelType(info_pair.first);
auto kernel_type_str = framework::KernelTypeToString(kernel_type);
if (all_kernels_info.count(op_type)) {
if (std::find(all_kernels_info[op_type].begin(),
all_kernels_info[op_type].end(),
kernel_type_str) ==
all_kernels_info[op_type].end()) {
all_kernels_info[op_type].emplace_back(kernel_type_str);
if (lib == "phi" || lib == "all") {
auto phi_kernels = phi::KernelFactory::Instance().kernels();
for (auto &kernel_pair : phi_kernels) {
auto op_type = phi::TransToFluidOpName(kernel_pair.first);
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
framework::OpKernelType kernel_type =
framework::TransPhiKernelKeyToOpKernelType(info_pair.first);
auto kernel_type_str =
framework::KernelTypeToString(kernel_type);
if (all_kernels_info.count(op_type)) {
if (std::find(all_kernels_info[op_type].begin(),
all_kernels_info[op_type].end(),
kernel_type_str) ==
all_kernels_info[op_type].end()) {
all_kernels_info[op_type].emplace_back(kernel_type_str);
}
} else {
kernel_types.emplace_back(kernel_type_str);
}
} else {
kernel_types.emplace_back(kernel_type_str);
}
}
if (!kernel_types.empty()) {
all_kernels_info.emplace(op_type, kernel_types);
if (!kernel_types.empty()) {
all_kernels_info.emplace(op_type, kernel_types);
}
}
}
}
return all_kernels_info;
},
py::arg("lib") = "all",
R"DOC(
return all_kernels_info;
},
py::arg("lib") = "all",
R"DOC(
Return the registered kernels in paddle.
Args:
lib[string]: the libarary, could be 'pten', 'fluid' and 'all'.
lib[string]: the libarary, could be 'phi', 'fluid' and 'all'.
)DOC");
// NOTE(zjl): ctest would load environment variables at the beginning even
......@@ -823,39 +823,39 @@ PYBIND11_MODULE(core_noavx, m) {
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::CPUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(
place, framework::TransToPtenDataType(type)));
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::XPUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(
place, framework::TransToPtenDataType(type)));
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::CUDAPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(
place, framework::TransToPtenDataType(type)));
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::CUDAPinnedPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(
place, framework::TransToPtenDataType(type)));
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::MLUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(
place, framework::TransToPtenDataType(type)));
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_clear", &framework::Tensor::clear)
.def("_mutable_data",
[](framework::Tensor &self, paddle::platform::NPUPlace &place,
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(
place, framework::TransToPtenDataType(type)));
return reinterpret_cast<uintptr_t>(
self.mutable_data(place, framework::TransToPhiDataType(type)));
})
.def("_copy_from", &TensorCopyFrom<paddle::platform::CPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
......
......@@ -324,7 +324,7 @@ void SetTensorFromPyArrayT(
if (zero_copy) {
auto holder = std::make_shared<details::NumpyAllocation<T>>(array);
auto type = framework::ToDataType(std::type_index(typeid(T)));
self->ResetHolderWithType(holder, framework::TransToPtenDataType(type));
self->ResetHolderWithType(holder, framework::TransToPhiDataType(type));
} else {
auto dst = self->mutable_data<T>(place);
std::memcpy(dst, array.data(), array.nbytes());
......@@ -348,7 +348,7 @@ void SetTensorFromPyArrayT(
if (zero_copy) {
auto holder = std::make_shared<details::NumpyAllocation<T>>(array);
auto type = framework::ToDataType(std::type_index(typeid(T)));
self->ResetHolderWithType(holder, framework::TransToPtenDataType(type));
self->ResetHolderWithType(holder, framework::TransToPhiDataType(type));
} else {
// IPU does not store Tensor data, Tensor will be created on CPU
if (!self->initialized()) {
......@@ -518,7 +518,7 @@ void SetUVATensorFromPyArray(
cuda_device_pointer, need_allocate_size,
platform::CUDAPlace(device_id));
self_tensor->ResetHolderWithType(holder,
framework::TransToPtenDataType(data_type));
framework::TransToPhiDataType(data_type));
#endif
}
......
......@@ -24,12 +24,12 @@ limitations under the License. */
#endif
#endif
// new pten apis
// new phi apis
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/api/include/tensor.h"
// pten common headers
// phi common headers
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
......
......@@ -49,8 +49,6 @@ namespace paddle {
namespace experimental {
class CompatiblePTenTensorUtils;
class AbstractAutogradMeta {
public:
// No AbstractAutogradMeta should be created
......@@ -59,7 +57,7 @@ class AbstractAutogradMeta {
/**
* Tensor is the API description of the basic data structure in the
* [ "Paddle Tensor Operation (pten)" Library ].
* [ "Paddle Tensor Operation (phi)" Library ].
*
* It is not limited to a simple n-dimensional array.
* It contains a smart pointer to `TensorImpl`. The data description contained
......@@ -366,7 +364,7 @@ class PADDLE_API Tensor final {
/* Part 5: Data Transform methods */
/* Alert!!!!: All copy method can only deep copy impl, autograd info only be
* copied */
/* out of pten */
/* out of phi */
/**
* @brief Copy the current Tensor data to the specified device
* and return the new Tensor. It's usually used to set the input tensor data.
......@@ -476,9 +474,6 @@ class PADDLE_API Tensor final {
/* Part 9: Auto generated Tensor methods */
private:
friend class CompatiblePTenTensorUtils;
private:
/**
* [ Why use abstract TensorImpl interface here? ]
......
......@@ -58,7 +58,7 @@ Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
*dev_ctx, *dense_x, phi::TransToPtenPlace(backend), blocking, kernel_out);
*dev_ctx, *dense_x, phi::TransToPhiPlace(backend), blocking, kernel_out);
return out;
}
......
......@@ -27,7 +27,7 @@ namespace experimental {
#endif
/**
* Now there is no module to call pten's API. When compiling, the function
* Now there is no module to call phi's API. When compiling, the function
* implementation will be optimized. Therefore, the symbol will be exposed
* manually for the time being.
*
......@@ -41,7 +41,7 @@ namespace experimental {
#define PD_DECLARE_API(name) \
extern PADDLE_API int RegisterSymbolsFor##name(); \
UNUSED static int use_pten_api_##name = RegisterSymbolsFor##name()
UNUSED static int use_phi_api_##name = RegisterSymbolsFor##name()
} // namespace experimental
} // namespace paddle
......@@ -106,7 +106,7 @@ inline paddle::optional<phi::MetaTensor> MakeMetaTensor(
inline phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
if (!out->initialized()) {
auto dense_tensor = std::make_shared<phi::DenseTensor>(
phi::make_intrusive<SharedStorage>(phi::TransToPtenPlace(backend)),
phi::make_intrusive<SharedStorage>(phi::TransToPhiPlace(backend)),
phi::DenseTensorMeta());
out->set_impl(dense_tensor);
return dense_tensor.get();
......@@ -120,7 +120,7 @@ inline std::vector<phi::DenseTensor*> SetKernelOutput(
std::vector<phi::DenseTensor*> results(out_size);
for (size_t i = 0; i < out_size; ++i) {
auto tensor_ptr = std::make_shared<phi::DenseTensor>(
phi::make_intrusive<SharedStorage>(phi::TransToPtenPlace(backend)),
phi::make_intrusive<SharedStorage>(phi::TransToPhiPlace(backend)),
phi::DenseTensorMeta());
results[i] = tensor_ptr.get();
out->emplace_back();
......
......@@ -38,7 +38,7 @@ inline bool NeedTransformPlace(const paddle::platform::Place& input,
const TransformFlag& transform_flag) {
bool ret = transform_flag.need_trans_backend() &&
target != Backend::ALL_BACKEND &&
!platform::is_same_place(input, phi::TransToPtenPlace(target));
!platform::is_same_place(input, phi::TransToPhiPlace(target));
return ret;
}
......@@ -168,10 +168,10 @@ phi::DenseTensor TransformData(const phi::DenseTensor& tensor,
out.place(), target_args_def.backend, transform_flag)) {
phi::DenseTensor result(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(target_args_def.backend)),
phi::TransToPhiPlace(target_args_def.backend)),
{out.dtype(), out.dims(), out.layout()});
framework::TransDataDevice(
out, phi::TransToPtenPlace(target_args_def.backend), &result);
out, phi::TransToPhiPlace(target_args_def.backend), &result);
out = result;
}
return out;
......
......@@ -21,7 +21,7 @@ namespace experimental {
namespace detail {
BackendSet GetTensorBackendSet(const Tensor& t) {
BackendSet backend_set(phi::TransToPtenBackend(t.inner_place()));
BackendSet backend_set(phi::TransToPhiBackend(t.inner_place()));
switch (t.layout()) {
case DataLayout::MKLDNN:
backend_set = backend_set | BackendSet(Backend::MKLDNN);
......@@ -53,7 +53,7 @@ std::size_t CountLeadingZeros(uint64_t val) {
phi::DeviceContext* GetDeviceContextByBackend(phi::Backend backend) {
auto& pool = paddle::platform::DeviceContextPool::Instance();
return pool.Get(phi::TransToPtenPlace(backend));
return pool.Get(phi::TransToPhiPlace(backend));
}
DataType ParseDataType(DataType dtype) { return dtype; }
......@@ -83,7 +83,7 @@ DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) {
Backend ParseBackend(Backend backend) { return backend; }
Backend ParseBackend(const Tensor& tensor) {
return phi::TransToPtenBackend(tensor.inner_place());
return phi::TransToPhiBackend(tensor.inner_place());
}
Backend ParseBackendWithInputOrder(Backend backend, const Tensor& tensor) {
......
......@@ -86,11 +86,11 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x,
// create empty SparseCooTensor
phi::DenseTensor non_zero_indices(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(backend)),
phi::TransToPhiPlace(backend)),
std::move(indices_meta));
phi::DenseTensor non_zero_elements(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(backend)),
phi::TransToPhiPlace(backend)),
std::move(elements_meta));
auto coo = std::make_shared<phi::SparseCooTensor>(
non_zero_indices, non_zero_elements, x.dims());
......@@ -148,15 +148,15 @@ PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) {
// create empty SparseCooTensor
phi::DenseTensor non_zero_crows(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(backend)),
phi::TransToPhiPlace(backend)),
std::move(crows_meta));
phi::DenseTensor non_zero_cols(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(backend)),
phi::TransToPhiPlace(backend)),
std::move(cols_meta));
phi::DenseTensor non_zero_elements(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(backend)),
phi::TransToPhiPlace(backend)),
std::move(elements_meta));
auto csr = std::make_shared<phi::SparseCsrTensor>(
non_zero_crows, non_zero_cols, non_zero_elements, x.dims());
......@@ -211,7 +211,7 @@ PADDLE_API Tensor to_dense(const Tensor& x, Backend backend) {
// create empty SparseCooTensor
auto dense_out = std::make_shared<phi::DenseTensor>(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPtenPlace(backend)),
phi::TransToPhiPlace(backend)),
std::move(dense_meta));
kernel_context.EmplaceBackOutput(dense_out.get());
......
......@@ -33,7 +33,7 @@ limitations under the License. */
*
* We hope to organize the basic implementation of Tensor and the logic related
* to Tensor computation into an independent library, which we call
* [Tensor Operation Library, pten], so we extract or rewrite the original
* [Tensor Operation Library, phi], so we extract or rewrite the original
* Kernels.
*
* In the future, the training library, inference library and custom operators
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部