Unverified Commit f7f1dc03 authored by YuanRisheng, committed by GitHub

[PHI]Change feed_op to phi kernel (#49116)

* change feed_op to phi kernel

* fix ci bugs

* fix build bugs

* fix ci bugs

* fix compile bugs

* fix ci bugs

* perfect code

* perfect comment code

* fix install bugs

* modify code according comment

* remove visitor in feed_op

* modify according comment

* perfect code according comment

* add infershape

* fix py3 bugs

* fix getexpected kernel type

* fix getexpected kernel type

* fix ci bugs

* add registry for custom device

* fix py3 bugs

* fix floating point error

* fix py3 test bugs
Parent b2a10916
......@@ -26,7 +26,7 @@ function(find_register FILENAME PATTERN OUTPUT)
PARENT_SCOPE)
endfunction()
function(find_phi_register FILENAME ADD_PATH)
function(find_phi_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
......@@ -36,11 +36,11 @@ function(find_phi_register FILENAME ADD_PATH)
string(
REGEX
MATCH
"PD_REGISTER_KERNEL\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
"${PATTERN}\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
register
"${CONTENT}")
if(NOT register STREQUAL "")
string(REPLACE "PD_REGISTER_KERNEL(" "" register "${register}")
string(REPLACE "${PATTERN}(" "" register "${register}")
string(REPLACE "," ";" register "${register}")
string(REGEX REPLACE "[ \\\t\r\n]+" "" register "${register}")
string(REGEX REPLACE "//cuda_only" "" register "${register}")
......@@ -401,7 +401,8 @@ function(op_library TARGET)
# pybind USE_OP_ITSELF
set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cc_src} ${pybind_file})
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cc_src} "REGISTER_OPERATOR" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
......@@ -440,7 +441,8 @@ function(op_library TARGET)
foreach(cu_src ${cu_srcs})
set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cu_src} ${pybind_file})
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
......
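For reference, the two registration forms the parameterized pattern now has to match look roughly like this (a hedged sketch; the kernel names and functors are hypothetical):

PD_REGISTER_KERNEL(my_kernel, CPU, ALL_LAYOUT, phi::MyKernel, float) {}

PD_REGISTER_GENERAL_KERNEL(
    my_general_kernel,
    CPU,
    ALL_LAYOUT,
    paddle::operators::MyGeneralKernel<phi::CPUContext>,
    ALL_DTYPE) {}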
......@@ -47,7 +47,7 @@ void SetFeedVariable(Scope* scope,
}
void SetFeedVariable(Scope* scope,
const Strings& input,
const std::vector<std::string>& input,
const std::string& var_name,
size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will
......@@ -59,7 +59,7 @@ void SetFeedVariable(Scope* scope,
feed_inputs.resize(index + 1);
}
// shared data with input tensor
feed_inputs[index] = input;
feed_inputs[index] = Strings(input);
}
FetchType& GetFetchVariable(const Scope& scope,
......
......@@ -35,7 +35,7 @@ void SetFeedVariable(Scope* scope,
size_t index);
void SetFeedVariable(Scope* scope,
const Strings& input,
const std::vector<std::string>& input,
const std::string& var_name,
size_t index);
......
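A minimal usage sketch of the std::vector<std::string> overload above (the variable name and data are hypothetical):

paddle::framework::Scope scope;
std::vector<std::string> texts = {"hello", "world"};  // hypothetical feed data
// The input is converted to framework::Strings (a PhiVector) internally.
paddle::framework::SetFeedVariable(&scope, texts, "feed", /*index=*/0);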
......@@ -19,12 +19,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/phi/core/extended_tensor.h"
namespace paddle {
namespace framework {
using FeedType =
paddle::variant<phi::DenseTensor, Strings, phi::SparseCooTensor>;
using FeedList = std::vector<FeedType>;
using FeedList = paddle::framework::PhiVector<FeedType>;
using FetchType = paddle::variant<phi::DenseTensor,
LoDTensorArray,
......
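Since FeedList is now a PhiVector<FeedType> instead of a plain std::vector, it derives from phi::TensorBase and can be passed to PHI kernels directly. A hedged sketch of populating a feed slot:

paddle::framework::FeedList feeds;  // PhiVector<FeedType>
feeds.resize(1);
feeds[0] = phi::DenseTensor();      // FeedType is a paddle::variant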
......@@ -117,6 +117,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_type == proto::VarType::SPARSE_COO;
}
bool IsSparseCooTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return std::all_of(var_types.begin(),
var_types.end(),
[](const proto::VarType::Type& type) {
return type == proto::VarType::SPARSE_COO;
});
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::SPARSE_CSR;
......
......@@ -126,6 +126,45 @@ void InferShapeUtilsTestKernel(const Context& dev_ctx,
VLOG(6) << "Come into InferShapeUtilsTestKernel";
}
void TestOutputInferMeta(const phi::MetaTensor& x, phi::MetaTensor* out) {
ASSERT_EQ(x.dtype(), phi::DataType::FLOAT32);
}
class InferShapeUtilsTestOutputOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("Out", "output of test op");
AddComment("This is test op");
}
};
class InferShapeUtilsTestOutputOp : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
phi::KernelKey GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return phi::KernelKey(proto::VarType::FP32, ctx.GetPlace());
}
};
phi::KernelSignature TestSparseOutputOpArgumentMapping(
const phi::ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorOutput("Out")) {
return phi::KernelSignature(
"test_sparse_coo_tensor_output", {"X"}, {}, {"Out"});
}
return phi::KernelSignature("test_output", {"X"}, {}, {"Out"});
}
template <typename T, typename Context>
void InferShapeUtilsTestOutputKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
phi::SparseCooTensor* out) {
VLOG(6) << "Come into InferShapeUtilsTestOutputKernel";
}
} // namespace framework
} // namespace paddle
......@@ -143,6 +182,21 @@ PD_REGISTER_KERNEL(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestKernel,
int) {}
DECLARE_INFER_SHAPE_FUNCTOR(
infer_shape_utils_test_output,
InferShapeUtilsTestOutputInferShapeFunctor,
PD_INFER_META(paddle::framework::TestOutputInferMeta));
REGISTER_OPERATOR(infer_shape_utils_test_output,
paddle::framework::InferShapeUtilsTestOutputOp,
paddle::framework::InferShapeUtilsTestOutputOpMaker,
InferShapeUtilsTestOutputInferShapeFunctor);
PD_REGISTER_KERNEL(test_sparse_coo_tensor_output,
CPU,
ALL_LAYOUT,
paddle::framework::InferShapeUtilsTestOutputKernel,
int) {}
TEST(InferShapeUtilsTest, ALL) {
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
......@@ -200,3 +254,27 @@ TEST(InferShapeUtilsTest, ALL) {
op->InferShape(block_desc);
}
TEST(InferShapeUtilsTestOutput, ALL) {
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
paddle::framework::BlockDesc block_desc(&prog, &proto_block);
auto* op = block_desc.AppendOp();
op->SetType("infer_shape_utils_test_output");
auto* x = block_desc.Var("x");
x->SetType(paddle::framework::proto::VarType::LOD_TENSOR);
x->SetDataType(paddle::framework::proto::VarType::FP32);
op->SetInput("X", {"x"});
auto* out = block_desc.Var("out");
out->SetType(paddle::framework::proto::VarType::SPARSE_COO);
op->SetOutput("Out", {"out"});
phi::OpUtilsMap::Instance().InsertArgumentMappingFn(
"infer_shape_utils_test_output",
paddle::framework::TestSparseOutputOpArgumentMapping);
op->InferShape(block_desc);
}
......@@ -1561,6 +1561,63 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->Info().infer_shape_(&infer_shape_ctx);
}
template <typename T>
bool HasSameTensorType(phi::TensorBase* phi_tensor, Variable* var) {
if (phi_tensor == nullptr && var == nullptr) {
return true;
} else if (phi_tensor != nullptr && var != nullptr) {
if (T::classof(phi_tensor) && var->IsType<T>()) {
return true;
}
}
return false;
}
// TODO(YuanRisheng): We need to collect all `need_prepare_phi_data_` checks
// into this function.
void OperatorWithKernel::CheckWhetherPreparePhiData(
const VariableNameMap& innames,
const VariableNameMap& outnames,
const Scope& scope) const {
if (run_phi_kernel_ && impl_ != nullptr) {
const auto& phi_kernel_context = impl_->getKernelContext();
size_t phi_tensor_index = 0;
// Check each tensor in the KernelContext; if any tensor's type differs from
// that of its corresponding variable, the PhiKernelContext needs to be
// reconstructed. We use kernel_signature_'s outputs to retrieve tensors,
// because tensors in phi_kernel_context are stored in the order of
// kernel_signature_'s outputs.
if (phi_kernel_context->OutputsSize() <= phi_tensor_index ||
kernel_signature_ == nullptr) {
need_prepare_phi_data_ = true;
return;
}
const auto& phi_output_names = kernel_signature_->output_names;
for (auto& phi_output_name : phi_output_names) {
const auto& iter = outnames.find(phi_output_name);
if (iter != outnames.end()) {
for (auto& var_name : iter->second) {
auto var_output = scope.FindVar(var_name);
auto phi_output =
phi_kernel_context->MutableOutputAt<phi::TensorBase>(
phi_tensor_index);
if (phi_output == nullptr) {
continue;
}
if (!(HasSameTensorType<phi::DenseTensor>(phi_output, var_output) ||
HasSameTensorType<phi::SparseCooTensor>(phi_output,
var_output) ||
HasSameTensorType<framework::Strings>(phi_output,
var_output))) {
need_prepare_phi_data_ = true;
}
phi_tensor_index++;
}
}
}
}
}
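// A hedged illustration (not part of this patch) of the mismatch detected
// above: suppose the cached PhiKernelContext slot for an output holds a
// DenseTensor, but the scope variable has since been recreated as Strings.
//   phi::TensorBase* cached =
//       phi_kernel_context->MutableOutputAt<phi::TensorBase>(0);
//   Variable* var = scope.FindVar("out");  // "out" is a hypothetical name
//   HasSameTensorType<phi::DenseTensor>(cached, var);  // false
// All three checked types fail to match, so need_prepare_phi_data_ is set
// and the kernel context is rebuilt on the next run.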
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
// To reduce the elapsed time of HasAttr, we use bool variable to record the
......@@ -1571,6 +1628,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape_ = true;
const Scope* cur_scope = &scope;
CheckWhetherPreparePhiData(Inputs(), Outputs(), scope);
if (!enable_cache_runtime_context_) {
RuntimeContext ctx(Inputs(), Outputs(), scope);
RunImpl(scope, place, &ctx);
......@@ -2993,7 +3051,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
"to the size of kernel attribute_defs (%d).",
attr_names.size(),
attr_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) {
auto it = ctx.inputs.find(input_names[i]);
......@@ -3037,6 +3094,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (var->IsType<framework::Vocab>()) {
tensor_in = &(var->Get<framework::Vocab>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::FeedList>()) {
tensor_in = &(var->Get<framework::FeedList>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.",
......@@ -3047,7 +3107,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
VLOG(4) << "Done inputs";
for (size_t i = 0; i < output_names.size(); ++i) {
auto it = ctx.outputs.find(output_names[i]);
size_t start_idx =
......@@ -3087,6 +3146,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
// Note: If the input LoDTensorArray size is 0, the output
// LoDTensorArray is also 0
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::Strings>()) {
tensor_out = var->template GetMutable<framework::Strings>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<paddle::framework::RawTensor>()) {
tensor_out = var->template GetMutable<paddle::framework::RawTensor>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
......@@ -3108,7 +3170,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
i);
}
VLOG(4) << "Done outputs";
for (size_t i = 0; i < attr_names.size(); ++i) {
VLOG(6) << "BuildPhiKernelContext: " << attr_names[i] << ": "
<< attr_defs[i].type_index;
......
......@@ -550,6 +550,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return var->IsType<phi::SparseCooTensor>();
}
bool IsSparseCooTensorOutput(const std::string& name) const override {
auto vars = ctx_.MultiOutputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
return var->IsType<phi::SparseCooTensor>();
});
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::SparseCsrTensor>();
......@@ -746,6 +753,10 @@ class OperatorWithKernel : public OperatorBase {
RuntimeContext* ctx,
const phi::Place& place) const;
void CheckWhetherPreparePhiData(const VariableNameMap& innames,
const VariableNameMap& outnames,
const Scope& scope) const;
void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars,
const Scope& exec_scope) const;
......
......@@ -23,8 +23,8 @@ namespace paddle {
namespace framework {
/// \brief Fluid Kernel and PHI Kernel will be unified in the future.
/// So, we need a class in PHI that can represent the RAW type in Fluid.
/// The RawTensor is for PHI Kernel that has RAW type arguments.
/// So, we need a class in PHI that can represent the RawTensor type in Fluid.
/// The RawTensor is for PHI kernels that have RawTensor-type arguments.
class RawTensor : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, RawTensor> {
public:
......@@ -37,13 +37,35 @@ class RawTensor : public phi::ExtendedTensor,
RawTensor& operator=(RawTensor&& other) = default;
/// \brief Destroy the RawTensor and release exclusive resources.
virtual ~RawTensor() = default;
virtual ~RawTensor() {
if (!data_.empty()) {
data_deleter_();
}
}
public:
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "RawTensor"; }
template <typename T>
T& Get() const {
PADDLE_ENFORCE_EQ(data_.empty(),
false,
platform::errors::PreconditionNotMet(
"The data in RawTensor is empty. Please set data "
"before using it."));
try {
return *(paddle::any_cast<T*>(data_));
} catch (paddle::bad_any_cast&) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Invalid data type error, expected %s, actual %s.",
typeid(T).name(),
data_type_.name()));
}
}
template <typename T>
T* GetMutable() {
if (!data_.empty()) {
......@@ -70,7 +92,7 @@ class RawTensor : public phi::ExtendedTensor,
private:
paddle::any data_;
std::function<void(void)> data_deleter_;
std::function<void(void)> data_deleter_ = []() {};
std::type_index data_type_ = std::type_index(typeid(void));
};
......
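A brief sketch of the accessors above (hedged; it assumes GetMutable<T>() default-constructs the stored value and installs data_deleter_ on first use, as the truncated body suggests):

paddle::framework::RawTensor raw;
*raw.GetMutable<int>() = 42;  // stores an int and records its deleter
int v = raw.Get<int>();       // typed access; 42
// raw.Get<float>();          // would throw InvalidArgument (type mismatch)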
......@@ -25,6 +25,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Note(YuanRisheng): Vocab is mainly used for faster_tokenizer_op, and we
// don't recommend using it widely, because faster_tokenizer_op may be removed
// in the future, and this class would be deleted along with it.
class Vocab : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, Vocab> {
public:
......@@ -94,8 +98,73 @@ class Vocab : public phi::ExtendedTensor,
std::unordered_map<std::wstring, std::int32_t> data_;
};
// Note(YuanRisheng): PhiVector is essentially a vector that is only used by
// PHI kernels. It can be used when you define a non-tensor type that needs to
// be stored in a vector as a PHI kernel argument.
template <typename T>
class PhiVector : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, PhiVector<T>> {
public:
PhiVector() = default;
explicit PhiVector(const std::vector<T>& init_data) : data_(init_data) {}
PhiVector(PhiVector&& other) = default;
PhiVector(const PhiVector& other) = default;
PhiVector& operator=(const PhiVector& other) = default;
PhiVector& operator=(const std::vector<T>& other) {
data_ = other;
return *this;
}
PhiVector& operator=(PhiVector&& other) = default;
/// \brief Destroy the PhiVector and release exclusive resources.
virtual ~PhiVector() = default;
public:
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() {
  // Cache the name in a function-local static so the returned pointer
  // remains valid after the call.
  static const std::string kName =
      std::string("PhiVector_") + std::string(typeid(T).name());
  return kName.c_str();
}
size_t size() const { return data_.size(); }
void resize(size_t size) { data_.resize(size); }
void clear() { data_.clear(); }
void emplace_back(const T& feed_data) { data_.emplace_back(feed_data); }
const T& operator[](size_t index) const { return data_[index]; }
T& operator[](size_t index) { return data_[index]; }
T& at(size_t index) { return data_.at(index); }
const T& at(size_t index) const { return data_.at(index); }
typename std::vector<T>::iterator begin() { return data_.begin(); }
typename std::vector<T>::const_iterator begin() const {
return data_.begin();
}
typename std::vector<T>::iterator end() { return data_.end(); }
typename std::vector<T>::const_iterator end() const { return data_.end(); }
private:
std::vector<T> data_;
};
using String = std::string;
using Strings = std::vector<std::string>;
using Strings = PhiVector<std::string>;
// Convert the std::string type to the std::wstring type.
bool ConvertStrToWstr(const std::string& src, std::wstring* res);
......
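Because PhiVector forwards the common std::vector operations, most existing call sites of framework::Strings keep compiling unchanged. A minimal sketch:

paddle::framework::Strings strs(std::vector<std::string>{"a", "b"});
strs.emplace_back("c");
const std::string& first = strs[0];  // "a"
for (const auto& s : strs) {         // begin()/end() keep range-for working
  // ...
}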
......@@ -221,6 +221,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
Vocab,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
RawTensor>;
template <typename T>
struct VarTypeTrait {
......
......@@ -1655,7 +1655,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
auto custom_place = place_;
auto paddleplace = static_cast<PaddlePlace>(
static_cast<size_t>(PaddlePlace::kCUSTOM) +
phi::GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
res->SetPlace(paddleplace, custom_place.GetDeviceId());
} else {
auto gpu_place = place_;
......@@ -1710,7 +1711,8 @@ std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
auto custom_place = place_;
auto paddleplace = static_cast<PaddlePlace>(
static_cast<size_t>(PaddlePlace::kCUSTOM) +
phi::GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
res->SetPlace(paddleplace, custom_place.GetDeviceId());
} else {
auto gpu_place = place_;
......
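The former free functions are now members of a singleton. A hedged sketch of the round trip used above (the device type string is hypothetical):

auto& dev_map = phi::CustomRegisteredDeviceMap::Instance();
size_t id = dev_map.GetOrRegisterGlobalDeviceTypeId("my_device");
std::string type = dev_map.GetGlobalDeviceType(id);  // "my_device"
auto paddleplace = static_cast<PaddlePlace>(
    static_cast<size_t>(PaddlePlace::kCUSTOM) + id);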
......@@ -25,13 +25,16 @@ if(WITH_ONNXRUNTIME)
cc_library(
zero_copy_tensor_dummy
SRCS zero_copy_tensor_dummy.cc
DEPS onnxruntime)
DEPS onnxruntime phi_enforce)
else()
cc_library(
zero_copy_tensor
SRCS zero_copy_tensor.cc
DEPS scope lod_tensor enforce)
cc_library(zero_copy_tensor_dummy SRCS zero_copy_tensor_dummy.cc)
cc_library(
zero_copy_tensor_dummy
SRCS zero_copy_tensor_dummy.cc
DEPS phi_enforce)
endif()
cc_test(
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
......@@ -76,7 +77,8 @@ void Tensor::ReshapeStrings(const size_t &shape) {
var,
paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
paddle::framework::Strings *tensor =
var->GetMutable<paddle::framework::Strings>();
tensor->resize(shape);
}
......@@ -261,7 +263,9 @@ void Tensor::CopyFromCpu(const T *data) {
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
paddle::platform::CustomPlace custom_place(
phi::GetGlobalDeviceType(device_type_id), device_);
phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
device_type_id),
device_);
auto *t_data = tensor->mutable_data<T>(custom_place);
auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
pool.Get(custom_place));
......@@ -354,7 +358,7 @@ void Tensor::ShareExternalData(const T *data,
}
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
EAGER_GET_TENSOR(paddle_infer::Strings);
EAGER_GET_TENSOR(paddle::framework::Strings);
PADDLE_ENFORCE_GE(tensor->size(),
0,
paddle::platform::errors::PreconditionNotMet(
......
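The public inference API is unchanged; only the variable type behind it is. A hedged end-to-end sketch (assuming an existing paddle_infer::Predictor and a hypothetical input named "text"):

auto input = predictor->GetInputHandle("text");
input->ReshapeStrings(2);  // resizes the underlying framework::Strings
std::vector<std::string> raw = {"hello", "world"};
paddle_infer::Strings data(raw);
input->CopyStringsFromCpu(&data);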
......@@ -112,6 +112,12 @@ bool PluginArgumentMappingContext::IsSparseCooTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSparseCooTensorOutput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
const std::string& name) const {
return false;
......
......@@ -56,6 +56,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsDenseTensorOutput(const std::string& name) const override;
bool IsSparseCooTensorOutput(const std::string& name) const override;
bool IsSelectedRowsOutput(const std::string& name) const override;
bool IsForInferShape() const override { return false; }
......
......@@ -124,6 +124,7 @@ TEST(ArgMappingContexTest, BasicFunction) {
EXPECT_EQ(context.IsDenseTensorOutput("Out"), false);
EXPECT_EQ(context.IsSelectedRowsOutput("Out"), false);
EXPECT_EQ(context.IsSparseCooTensorOutput("Out"), false);
EXPECT_EQ(context.IsForInferShape(), false);
}
......
......@@ -11,6 +11,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace paddle {
......@@ -28,117 +30,128 @@ class OpBase;
namespace paddle {
namespace operators {
// FeedVariableVisitor is to feed the variable data
// according to data type (phi::DenseTensor or Strings).
class FeedVariableVisitor {
public:
explicit FeedVariableVisitor(framework::Variable *out_var,
const platform::Place &place)
: out_var_(out_var), place_(place) {}
void operator()(const phi::DenseTensor &in_tensor) const {
phi::DenseTensor *out_tensor = out_var_->GetMutable<phi::DenseTensor>();
if (platform::is_same_place(in_tensor.place(), place_)) {
out_tensor->ShareDataWith(in_tensor);
#ifdef PADDLE_WITH_IPU
} else if (platform::is_ipu_place(place_)) {
// For ipu, both in_tensor and out_tensor are allocated on cpu,
// PopART will copy tensor from host automatically,
// no TensorCopy() is required here.
out_tensor->ShareDataWith(in_tensor);
#endif
} else {
platform::DeviceContext *context =
platform::DeviceContextPool::Instance().Get(place_);
framework::TensorCopy(in_tensor, place_, *context, out_tensor);
}
out_tensor->set_lod(in_tensor.lod());
const framework::FeedType& CheckAndGetFeedItem(const phi::ExtendedTensor& x,
int col) {
PADDLE_ENFORCE_GE(col,
0,
platform::errors::InvalidArgument(
"Expected the column index (the attribute 'col' of "
"operator 'Feed') of current feeding variable to be "
"no less than 0. But received column index = %d.",
col));
auto feed_list = static_cast<const paddle::framework::FeedList*>(&x);
PADDLE_ENFORCE_LT(
static_cast<size_t>(col),
feed_list->size(),
platform::errors::InvalidArgument(
"The column index of current feeding variable is expected to be "
"less than the length of feeding list. But received column index = "
"%d, the length of feeding list = %d",
col,
feed_list->size()));
return feed_list->at(static_cast<size_t>(col));
}
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out) {
PADDLE_ENFORCE_NOT_NULL(
out,
platform::errors::NotFound(
"Output cannot be found in scope for operator 'Feed'"));
const auto& feed_item = CheckAndGetFeedItem(x, col);
const auto& in_tensor = paddle::get<phi::DenseTensor>(feed_item);
const auto& place = dev_ctx.GetPlace();
if (platform::is_same_place(in_tensor.place(), place)) {
out->ShareDataWith(in_tensor);
} else {
framework::TensorCopy(in_tensor, place, dev_ctx, out);
}
void operator()(const framework::Strings &in_str) const {
framework::Strings *out_str = out_var_->GetMutable<framework::Strings>();
out_str->resize(in_str.size());
*out_str = in_str;
out->set_lod(in_tensor.lod());
}
template <typename Context>
void FeedSparseCooTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::SparseCooTensor* out) {
PADDLE_ENFORCE_NOT_NULL(
out,
platform::errors::NotFound(
"Output cannot be found in scope for operator 'Feed'"));
const auto& feed_item = CheckAndGetFeedItem(x, col);
const auto& in_tensor = paddle::get<phi::SparseCooTensor>(feed_item);
const auto& place = dev_ctx.GetPlace();
if (platform::is_same_place(in_tensor.place(), place)) {
*out = in_tensor;
} else {
phi::DenseTensor indices, values;
framework::TensorCopy(in_tensor.indices(), place, dev_ctx, &indices);
framework::TensorCopy(in_tensor.values(), place, dev_ctx, &values);
out->SetMember(indices, values, in_tensor.meta());
}
}
template <typename Context>
void FeedStringsKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::ExtendedTensor* out) {
PADDLE_ENFORCE_NOT_NULL(
out,
platform::errors::NotFound(
"Output cannot be found in scope for operator 'Feed'"));
const auto& feed_item = CheckAndGetFeedItem(x, col);
auto strs_out = static_cast<framework::Strings*>(out);
const auto& in_str = paddle::get<framework::Strings>(feed_item);
strs_out->resize(in_str.size());
*strs_out = in_str;
}
class FeedOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void operator()(const phi::SparseCooTensor &in_tensor) const {
phi::SparseCooTensor *out_tensor =
out_var_->GetMutable<phi::SparseCooTensor>();
if (platform::is_same_place(in_tensor.place(), place_)) {
*out_tensor = in_tensor;
} else {
platform::DeviceContext *context =
platform::DeviceContextPool::Instance().Get(place_);
phi::DenseTensor indices, values;
framework::TensorCopy(in_tensor.indices(), place_, *context, &indices);
framework::TensorCopy(in_tensor.values(), place_, *context, &values);
out_tensor->SetMember(indices, values, in_tensor.meta());
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "feed");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "feed");
if (ctx->IsRuntime()) {
framework::Variable* x_var =
PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
auto& x = x_var->Get<framework::FeedList>();
int col = ctx->Attrs().Get<int>("col");
auto& feed_item = x[col];
if (feed_item.index() == 0) {
const auto& feed_item = CheckAndGetFeedItem(x, col);
auto& feed_tensor = PADDLE_GET_CONST(phi::DenseTensor, feed_item);
ctx->SetOutputDim("Out", feed_tensor.dims());
} else if (feed_item.index() == 1) {
auto& feed_str = PADDLE_GET_CONST(framework::Strings, feed_item);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
out_var->GetMutable<framework::Strings>()->resize(feed_str.size());
} else {
auto& feed_sparse_tensor =
PADDLE_GET_CONST(phi::SparseCooTensor, feed_item);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
out_var->GetMutable<phi::SparseCooTensor>()->set_meta(
feed_sparse_tensor.meta());
out_var->GetMutable<phi::SparseCooTensor>()->SetCoalesced(
feed_sparse_tensor.coalesced());
out_var->GetMutable<phi::SparseCooTensor>()->SetIndicesDict(
feed_sparse_tensor.GetIndicesDict());
}
}
}
private:
framework::Variable *out_var_;
const platform::Place &place_;
};
class FeedOp : public framework::OperatorBase {
public:
FeedOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
OP_INOUT_CHECK(HasInputs("X"), "Input", "X", "Feed");
OP_INOUT_CHECK(HasOutputs("Out"), "Output", "Out", "Feed");
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);
PADDLE_ENFORCE_NOT_NULL(
feed_var,
platform::errors::NotFound(
"Input varibale(%s) cannot be found in scope for operator 'Feed'.",
feed_var_name));
auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE_NOT_NULL(
out_var,
platform::errors::NotFound(
"Output variable(%s) cannot be found in scope for operator 'Feed'",
out_name));
auto col = Attr<int>("col");
PADDLE_ENFORCE_GE(col,
0,
platform::errors::InvalidArgument(
"Expected the column index (the attribute 'col' of "
"operator 'Feed') of current feeding variable to be "
"no less than 0. But received column index = %d.",
col));
VLOG(3) << "Feed variable " << feed_var_name << "'s " << col
<< " column to variable " << out_name;
auto &feed_list = feed_var->Get<framework::FeedList>();
PADDLE_ENFORCE_LT(
static_cast<size_t>(col),
feed_list.size(),
platform::errors::InvalidArgument(
"The column index of current feeding variable is expected to be "
"less than the length of feeding list. But received column index = "
"%d, the length of feeding list = %d",
col,
feed_list.size()));
auto &feed_item = feed_list.at(static_cast<size_t>(col));
FeedVariableVisitor visitor(out_var, place);
paddle::visit(visitor, feed_item);
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
}
};
......@@ -164,9 +177,152 @@ It should not be configured by users directly.
} // namespace operators
} // namespace paddle
// TODO(YuanRisheng): Maybe we need to design a new registry macro for
// registering device-independent kernels.
REGISTER_OPERATOR(
feed,
paddle::operators::FeedOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::FeedOpInfoMaker);
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
CPU,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
CPU,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
CPU,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_MKLDNN)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
OneDNN,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::OneDNNContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
OneDNN,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::OneDNNContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
OneDNN,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::OneDNNContext>,
ALL_DTYPE) {}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
GPU,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
GPU,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
GPU,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::GPUContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_XPU)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
XPU,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::XPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
XPU,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::XPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
XPU,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::XPUContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_ASCEND_CL)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
npu,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
npu,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
npu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_MLU)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
CustomMLU,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
CustomMLU,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
CustomMLU,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
......@@ -407,8 +407,8 @@ int BertTokenizer::Encode(
void BertTokenizer::BatchEncode(
vector<unordered_map<string, vector<int64_t>>>* batch_encode_inputs,
const vector<string>& batch_text,
const vector<string>& batch_text_pair /* = vector<string>() */,
const framework::Strings& batch_text,
const framework::Strings& batch_text_pair /* = framework::Strings() */,
bool is_split_into_words /* = false */,
const size_t max_seq_len /* = 0 */,
bool pad_to_max_seq_len /* = false */) const {
......
......@@ -100,8 +100,8 @@ class BertTokenizer {
bool pad_to_max_seq_len = false) const;
void BatchEncode(
vector<unordered_map<string, vector<int64_t>>>* batch_encode_inputs,
const vector<string>& batch_text,
const vector<string>& batch_text_pair = vector<string>(),
const framework::Strings& batch_text,
const framework::Strings& batch_text_pair = framework::Strings(),
bool is_split_into_words = false,
const size_t max_seq_len = 0,
bool pad_to_max_seq_len = false) const;
......@@ -162,7 +162,7 @@ class FasterTokenizerKernel : public framework::OpKernel<T> {
} else {
tokenizer.BatchEncode(&batch_encode_inputs,
*text,
vector<string>(),
framework::Strings(),
is_split_into_words,
max_seq_len,
pad_to_max_seq_len);
......
......@@ -1503,7 +1503,7 @@ static PyObject* tensor_method_set_string_list(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
using Strings = std::vector<std::string>;
using Strings = paddle::framework::Strings;
auto strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
auto var_tensor = std::make_shared<egr::VariableCompatTensor>();
*var_tensor->GetMutable<Strings>() = strings;
......
......@@ -184,39 +184,41 @@ PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) {
value[i] = ddim[i];
}
}
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
auto default_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDefaultLayout();
bool change_dim =
(desired_layout != default_layout &&
self->tensor.layout() == desired_layout && value.size() == 4);
VLOG(6) << "eager_properties 'Shape' method, layout autotune "
<< " desired_layout: " << desired_layout
<< " default_layout: " << default_layout
<< " tensor layout: " << self->tensor.layout()
<< " tensor's shape size is : " << value.size();
std::vector<int64_t> dims = value;
if (change_dim && phi::DataLayoutToString(desired_layout) == "NCHW") {
// NCHW -> NHWC
VLOG(6) << "layout autotune get Shape from NCHW -> NHWC " << value[0] << " "
<< value[1] << " " << value[2] << " " << value[3] << " to "
<< dims[0] << " " << dims[2] << " " << dims[3] << " " << dims[1];
value[0] = dims[0];
value[1] = dims[2];
value[2] = dims[3];
value[3] = dims[1];
} else if (change_dim && phi::DataLayoutToString(desired_layout) == "NHWC") {
// NHWC -> NCHW
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW " << value[0] << " "
<< value[1] << " " << value[2] << " " << value[3] << " to "
<< dims[0] << " " << dims[3] << " " << dims[1] << " " << dims[2]
<< " " << dims[1];
value[0] = dims[0];
value[1] = dims[3];
value[2] = dims[1];
value[3] = dims[2];
if (!egr::IsVariableCompatTensor(self->tensor)) {
auto desired_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDesiredLayout();
auto default_layout =
paddle::imperative::LayoutAutoTune::Instance().GetDefaultLayout();
bool change_dim =
(desired_layout != default_layout &&
self->tensor.layout() == desired_layout && value.size() == 4);
VLOG(6) << "eager_properties 'Shape' method, layout autotune "
<< " desired_layout: " << desired_layout
<< " default_layout: " << default_layout
<< " tensor layout: " << self->tensor.layout()
<< " tensor's shape size is : " << value.size();
std::vector<int64_t> dims = value;
if (change_dim && phi::DataLayoutToString(desired_layout) == "NCHW") {
// NCHW -> NHWC
VLOG(6) << "layout autotune get Shape from NCHW -> NHWC " << value[0]
<< " " << value[1] << " " << value[2] << " " << value[3] << " to "
<< dims[0] << " " << dims[2] << " " << dims[3] << " " << dims[1];
value[0] = dims[0];
value[1] = dims[2];
value[2] = dims[3];
value[3] = dims[1];
} else if (change_dim &&
phi::DataLayoutToString(desired_layout) == "NHWC") {
// NHWC -> NCHW
VLOG(6) << "layout autotune get Shape from NHWC -> NCHW " << value[0]
<< " " << value[1] << " " << value[2] << " " << value[3] << " to "
<< dims[0] << " " << dims[3] << " " << dims[1] << " " << dims[2]
<< " " << dims[1];
value[0] = dims[0];
value[1] = dims[3];
value[2] = dims[1];
value[3] = dims[2];
}
}
return ToPyObject(value);
......
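A worked example of the NCHW -> NHWC branch above with a concrete 4-D shape:

// value = {N, C, H, W} = {2, 3, 4, 5}; dims is a copy of value
// value[0] = dims[0]; value[1] = dims[2]; value[2] = dims[3]; value[3] = dims[1];
// => value = {N, H, W, C} = {2, 4, 5, 3}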
......@@ -970,7 +970,7 @@ All parameter, weight, gradient are variables in Paddle.
}
})
.def("set_string_list",
[](Variable &self, Strings str_list) {
[](Variable &self, std::vector<std::string> str_list) {
*self.GetMutable<Strings>() = str_list;
})
.def("set_vocab",
......@@ -1926,7 +1926,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_feed_variable",
static_cast<void (*)( // NOLINT
Scope *,
const Strings &,
const std::vector<std::string> &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("get_fetch_variable",
......
......@@ -134,7 +134,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
default: {
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS);
std::string device_type = phi::GetGlobalDeviceType(device_type_id_);
std::string device_type =
phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
device_type_id_);
if (!device_type.empty()) {
os << device_type;
} else {
......@@ -178,7 +180,8 @@ inline Backend StringToBackend(const char* backend_cstr) {
return Backend::IPU;
} else {
return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) +
phi::GetOrRegisterGlobalDeviceTypeId(s));
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(s));
}
}
......@@ -207,7 +210,9 @@ inline std::string BackendToString(const Backend& backend) {
default:
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS);
std::string device_type = phi::GetGlobalDeviceType(device_type_id_);
std::string device_type =
phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
device_type_id_);
if (!device_type.empty()) {
return device_type;
} else {
......
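A hedged sketch of the custom-backend round trip these helpers implement (namespace qualifiers omitted; the backend name is hypothetical):

Backend b = StringToBackend("MyBackend");
// b == NUM_BACKENDS + GetOrRegisterGlobalDeviceTypeId("MyBackend")
std::string s = BackendToString(b);  // "MyBackend"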
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <sstream>
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/phi/api/ext/exception.h"
......@@ -54,7 +53,8 @@ std::string Place::DebugString() const {
std::ostringstream os;
os << "Place(";
if (alloc_type_ == AllocationType::CUSTOM) {
os << GetGlobalDeviceType(device_type_id_);
os << phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
device_type_id_);
} else {
os << AllocationTypeStr(alloc_type_);
}
......@@ -85,25 +85,29 @@ Place GetPinnedPlace(const Place &place) {
}
}
static std::unordered_map<std::string, size_t> global_registered_device_type_id;
static std::unordered_map<size_t, std::string> global_registered_device_type;
CustomRegisteredDeviceMap &CustomRegisteredDeviceMap::Instance() {
static CustomRegisteredDeviceMap g_custom_registered_device_map;
return g_custom_registered_device_map;
}
size_t GetOrRegisterGlobalDeviceTypeId(const std::string &device_type) {
size_t CustomRegisteredDeviceMap::GetOrRegisterGlobalDeviceTypeId(
const std::string &device_type) {
if (device_type.empty()) return 0;
if (global_registered_device_type_id.find(device_type) ==
global_registered_device_type_id.end()) {
size_t device_type_id = global_registered_device_type_id.size() + 1;
global_registered_device_type_id[device_type] = device_type_id;
global_registered_device_type[device_type_id] = device_type;
if (registered_device_type_id_.find(device_type) ==
registered_device_type_id_.end()) {
size_t device_type_id = registered_device_type_id_.size() + 1;
registered_device_type_id_[device_type] = device_type_id;
registered_device_type_[device_type_id] = device_type;
}
return global_registered_device_type_id[device_type];
return registered_device_type_id_[device_type];
}
std::string GetGlobalDeviceType(size_t device_type_id) {
if (global_registered_device_type.find(device_type_id) ==
global_registered_device_type.end())
std::string CustomRegisteredDeviceMap::GetGlobalDeviceType(
size_t device_type_id) {
if (registered_device_type_.find(device_type_id) ==
registered_device_type_.end())
return "";
return global_registered_device_type[device_type_id];
return registered_device_type_[device_type_id];
}
constexpr static int kAllocationTypeBitLength = 8;
......@@ -143,7 +147,9 @@ static int8_t GetCorrectDeviceIdByPlaceType(
Place::Place(paddle::PlaceType type)
: device(detail::GetCorrectDeviceIdByPlaceType(type)),
alloc_type_(static_cast<AllocationType>(type)),
device_type_id_(GetOrRegisterGlobalDeviceTypeId("")) {
device_type_id_(
CustomRegisteredDeviceMap::Instance().GetOrRegisterGlobalDeviceTypeId(
"")) {
LOG_FIRST_N(WARNING, 1)
<< "The `paddle::PlaceType::kCPU/kGPU` is deprecated since version "
"2.3, and will be removed in version 2.4! Please use "
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/phi/api/include/dll_decl.h"
......@@ -37,11 +38,21 @@ enum class AllocationType : int8_t {
CUSTOM = 9,
};
const char* AllocationTypeStr(AllocationType type);
class CustomRegisteredDeviceMap {
public:
static CustomRegisteredDeviceMap& Instance();
size_t GetOrRegisterGlobalDeviceTypeId(const std::string& device_type);
std::string GetGlobalDeviceType(size_t device_type_id_);
private:
CustomRegisteredDeviceMap() = default;
std::unordered_map<std::string, size_t> registered_device_type_id_;
std::unordered_map<size_t, std::string> registered_device_type_;
};
const char* AllocationTypeStr(AllocationType type);
/// \brief The place is used to specify where the data is stored.
class PADDLE_API Place {
......@@ -53,12 +64,14 @@ class PADDLE_API Place {
const std::string& dev_type = "")
: device(id),
alloc_type_(type),
device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}
device_type_id_(phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(dev_type)) {}
explicit Place(AllocationType type, const std::string& dev_type = "")
: device(0),
alloc_type_(type),
device_type_id_(GetOrRegisterGlobalDeviceTypeId(dev_type)) {}
device_type_id_(phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(dev_type)) {}
// See NOTE [ Why need to temporarily adapt to PlaceType? ]
Place(paddle::PlaceType type); // NOLINT
......@@ -69,7 +82,8 @@ class PADDLE_API Place {
alloc_type_ = type;
device = device_id;
if (!dev_type.empty()) {
device_type_id_ = GetOrRegisterGlobalDeviceTypeId(dev_type);
device_type_id_ = phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(dev_type);
}
}
......@@ -78,7 +92,8 @@ class PADDLE_API Place {
int8_t GetDeviceId() const { return device; }
std::string GetDeviceType() const {
return GetGlobalDeviceType(device_type_id_);
return phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
device_type_id_);
}
std::string DebugString() const;
......
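A hedged sketch of how a custom-device Place now resolves its type string (the device name is hypothetical):

phi::Place p(phi::AllocationType::CUSTOM, /*id=*/0, "my_device");
std::string t = p.GetDeviceType();  // "my_device", via the singleton map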
......@@ -110,6 +110,7 @@ class ArgumentMappingContext {
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInputs(const std::string& name) const = 0;
virtual bool IsSparseCooTensorInput(const std::string& name) const = 0;
virtual bool IsSparseCooTensorOutput(const std::string& name) const = 0;
virtual bool IsSparseCsrTensorInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
......
......@@ -46,7 +46,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
case AllocationType::CUSTOM:
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place));
......@@ -91,7 +92,9 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS);
std::string device_type = phi::GetGlobalDeviceType(device_type_id_);
std::string device_type =
phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
device_type_id_);
if (!device_type.empty()) {
return phi::CustomPlace(
device_type,
......
......@@ -101,6 +101,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type ==
std::type_index(typeid(const phi::ExtendedTensor&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
const std::vector<const ExtendedTensor*>&))) {
args_def->AppendInput(default_key.backend(),
......
......@@ -265,6 +265,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(ExtendedTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(ExtendedTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FeedOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput("Out")) {
return KernelSignature("feed_dense_tensor", {"X"}, {"col"}, {"Out"});
} else if (ctx.IsSparseCooTensorOutput("Out")) {
return KernelSignature("feed_sparse_coo_tensor", {"X"}, {"col"}, {"Out"});
} else {
return KernelSignature("feed_strings", {"X"}, {"col"}, {"Out"});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(feed, feed_dense_tensor);
PD_REGISTER_ARG_MAPPING_FN(feed, phi::FeedOpArgumentMapping);
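In short, one feed operator fans out to three PHI kernels based on the output variable's type; a summary of the resolved signatures:

// Out is DenseTensor      -> feed_dense_tensor(X, {col}, Out)
// Out is SparseCooTensor  -> feed_sparse_coo_tensor(X, {col}, Out)
// Out is Strings          -> feed_strings(X, {col}, Out)
// PD_REGISTER_BASE_KERNEL_NAME maps the base name "feed" to feed_dense_tensor.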
......@@ -70,8 +70,10 @@ TEST(Backend, StringToBackend) {
#else
EXPECT_EQ(phi::Backend::KPS, pexp::StringToBackend("KPS"));
#endif
EXPECT_EQ(static_cast<phi::Backend>(
static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1),
EXPECT_EQ(static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId("CustomBackend")),
pexp::StringToBackend("CustomBackend"));
}
......
......@@ -94,6 +94,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return false;
}
bool IsSparseCooTensorOutput(const std::string& name) const override {
return false;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
......