未验证 提交 3c557e2f 编写于 作者: Y YuanRisheng 提交者: GitHub

[BugFix]Fix bugs when compile with OneDNN (#50096)

* fix bugs

* fix ci bugs
上级 e48c882f
...@@ -26,6 +26,11 @@ namespace framework { ...@@ -26,6 +26,11 @@ namespace framework {
using FeedType = using FeedType =
paddle::variant<phi::DenseTensor, Strings, phi::SparseCooTensor>; paddle::variant<phi::DenseTensor, Strings, phi::SparseCooTensor>;
// Specialization of PhiVectorType for FeedType so that
// PhiVector<FeedType>::name() can return a stable type-trait string
// ("PhiVectorFeedType") instead of a runtime-built name.
template <>
struct PhiVectorType<FeedType> {
  // static constexpr: the name is a compile-time constant with static
  // storage duration, so no object state is needed to read it and the
  // returned pointer never dangles. Still accessible through
  // PhiVectorType<FeedType>().type_name, so existing callers compile
  // unchanged.
  static constexpr const char *type_name = "PhiVectorFeedType";
};
using FeedList = paddle::framework::PhiVector<FeedType>; using FeedList = paddle::framework::PhiVector<FeedType>;
using FetchType = paddle::variant<phi::DenseTensor, using FetchType = paddle::variant<phi::DenseTensor,
......
...@@ -102,6 +102,14 @@ class Vocab : public phi::ExtendedTensor, ...@@ -102,6 +102,14 @@ class Vocab : public phi::ExtendedTensor,
// Kernel. It can be used when you define a non-tensor type that needs to be // Kernel. It can be used when you define a non-tensor type that needs to be
// stored in a vector as PHI kernel argument. // stored in a vector as PHI kernel argument.
// Maps an element type T to the registered type name of PhiVector<T>.
// The primary template is intentionally left undefined: using
// PhiVector<T> with a T that has no explicit specialization becomes a
// compile-time error rather than a silent runtime problem.
template <typename T>
struct PhiVectorType;

template <>
struct PhiVectorType<std::string> {
  // static constexpr: a compile-time constant with static storage
  // duration, so PhiVector<std::string>::name() can safely return this
  // pointer (unlike the previous typeid-based name(), which returned
  // .c_str() of a temporary std::string). Remains readable through
  // PhiVectorType<std::string>().type_name, so callers are unaffected.
  static constexpr const char* type_name = "PhiVectorString";
};
template <typename T> template <typename T>
class PhiVector : public phi::ExtendedTensor, class PhiVector : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, PhiVector<T>> { public phi::TypeInfoTraits<phi::TensorBase, PhiVector<T>> {
...@@ -129,9 +137,7 @@ class PhiVector : public phi::ExtendedTensor, ...@@ -129,9 +137,7 @@ class PhiVector : public phi::ExtendedTensor,
public: public:
/// \brief Returns the name of the class for type traits. /// \brief Returns the name of the class for type traits.
/// \return The name of the class. /// \return The name of the class.
static const char* name() { static const char* name() { return PhiVectorType<T>().type_name; }
return (std::string("PhiVector_") + std::string(typeid(T).name())).c_str();
}
size_t size() const { return data_.size(); } size_t size() const { return data_.size(); }
......
...@@ -267,44 +267,6 @@ PD_REGISTER_GENERAL_KERNEL( ...@@ -267,44 +267,6 @@ PD_REGISTER_GENERAL_KERNEL(
ALL_LAYOUT, ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::XPUContext>, paddle::operators::FeedStringsKernel<phi::XPUContext>,
ALL_DTYPE) {} ALL_DTYPE) {}
#elif defined(PADDLE_WITH_ASCEND_CL)
// Ascend NPU build: the three feed kernels (dense tensor, sparse COO
// tensor, strings) are registered for the "npu" backend. The device is
// a plug-in device, so the kernels are instantiated with
// phi::CustomContext rather than a native device context.
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
npu,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
npu,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
npu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_MLU)
// Cambricon MLU build: the same three feed kernels registered for the
// "CustomMLU" backend, also via phi::CustomContext (plug-in device).
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
CustomMLU,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
CustomMLU,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
CustomMLU,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
namespace paddle { namespace paddle {
......
...@@ -27,7 +27,7 @@ math_library(sequence_scale) ...@@ -27,7 +27,7 @@ math_library(sequence_scale)
cc_library( cc_library(
phi_data_layout_transform phi_data_layout_transform
SRCS data_layout_transform.cc SRCS data_layout_transform.cc
DEPS tensor) DEPS tensor blas)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
if(MKL_FOUND AND WITH_ONEMKL) if(MKL_FOUND AND WITH_ONEMKL)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册