未验证 提交 ecf892f0 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Add new Tensor type and migrate save_combine kernel (#47856)

* add new tensor

* fix windows compile bugs

* fix ci bugs

* fix ci bugs

* fix ci bugs

* perfect according comment

* fix ci compile bugs

* add raw tensor

* fix ci bugs

* modify code by comment

* delete String
上级 737fbdba
...@@ -26,6 +26,35 @@ function(find_register FILENAME PATTERN OUTPUT) ...@@ -26,6 +26,35 @@ function(find_register FILENAME PATTERN OUTPUT)
PARENT_SCOPE) PARENT_SCOPE)
endfunction() endfunction()
function(find_phi_register FILENAME ADD_PATH)
# set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
set(multiValueArgs "")
file(READ ${FILENAME} CONTENT)
string(
REGEX
MATCH
"PD_REGISTER_KERNEL\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
register
"${CONTENT}")
if(NOT register STREQUAL "")
string(REPLACE "PD_REGISTER_KERNEL(" "" register "${register}")
string(REPLACE "," ";" register "${register}")
string(REGEX REPLACE "[ \\\t\r\n]+" "" register "${register}")
string(REGEX REPLACE "//cuda_only" "" register "${register}")
list(GET register 0 kernel_name)
list(GET register 1 kernel_backend)
list(GET register 2 kernel_layout)
file(
APPEND ${ADD_PATH}
"PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});\n"
)
endif()
endfunction()
function(op_library TARGET) function(op_library TARGET)
# op_library is a function to create op library. The interface is same as # op_library is a function to create op library. The interface is same as
# cc_library. But it handle split GPU/CPU code and link some common library # cc_library. But it handle split GPU/CPU code and link some common library
...@@ -371,6 +400,8 @@ function(op_library TARGET) ...@@ -371,6 +400,8 @@ function(op_library TARGET)
foreach(cc_src ${cc_srcs}) foreach(cc_src ${cc_srcs})
# pybind USE_OP_ITSELF # pybind USE_OP_ITSELF
set(op_name "") set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cc_src} ${pybind_file})
find_register(${cc_src} "REGISTER_OPERATOR" op_name) find_register(${cc_src} "REGISTER_OPERATOR" op_name)
if(NOT ${op_name} EQUAL "") if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
...@@ -408,6 +439,8 @@ function(op_library TARGET) ...@@ -408,6 +439,8 @@ function(op_library TARGET)
# message("cu_srcs ${cu_srcs}") # message("cu_srcs ${cu_srcs}")
foreach(cu_src ${cu_srcs}) foreach(cu_src ${cu_srcs})
set(op_name "") set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cu_src} ${pybind_file})
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name) find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "") if(NOT ${op_name} EQUAL "")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")
......
...@@ -115,7 +115,7 @@ proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto ...@@ -115,7 +115,7 @@ proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto
cc_library( cc_library(
string_array string_array
SRCS string_array.cc SRCS string_array.cc
DEPS utf8proc) DEPS utf8proc phi_enforce)
cc_library( cc_library(
data_type data_type
...@@ -233,7 +233,8 @@ cc_test( ...@@ -233,7 +233,8 @@ cc_test(
cc_library( cc_library(
var_type_traits var_type_traits
SRCS var_type_traits.cc SRCS var_type_traits.cc
DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor) DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor
extended_tensor)
if(WITH_GPU) if(WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda) target_link_libraries(var_type_traits dynload_cuda)
endif() endif()
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/phi_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/unused_var_check.h"
...@@ -3008,6 +3009,9 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -3008,6 +3009,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
need_prepare_phi_data_ = true; need_prepare_phi_data_ = true;
tensor_in = &(var->Get<framework::LoDTensorArray>()); tensor_in = &(var->Get<framework::LoDTensorArray>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::Vocab>()) {
tensor_in = &(var->Get<framework::Vocab>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.", "Unsupported input `%s` type when call pt kernel.",
...@@ -3057,6 +3061,13 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -3057,6 +3061,13 @@ void OperatorWithKernel::BuildPhiKernelContext(
// Note: If the input LoDTensorArray size is 0, the output // Note: If the input LoDTensorArray size is 0, the output
// LoDTensorArray is also 0 // LoDTensorArray is also 0
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<paddle::framework::RawTensor>()) {
tensor_out = var->template GetMutable<paddle::framework::RawTensor>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (!var->IsInitialized()) {
// The following is for RAW type of var
tensor_out = var->template GetMutable<paddle::framework::RawTensor>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.", "Unsupported output `%s` type when call pt kernel.",
...@@ -3156,6 +3167,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -3156,6 +3167,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
} }
} }
break; break;
case phi::AttributeType::SCALARS: { case phi::AttributeType::SCALARS: {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
attr_iter, attr_iter,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include "paddle/phi/core/extended_tensor.h"
#include "paddle/utils/any.h"
namespace paddle {
namespace framework {
/// \brief Fluid Kernel and PHI Kernel will be unified in the future.
/// So, we need a class in PHI that can represent the RAW type in Fluid.
/// The RawTensor is for PHI Kernel that has RAW type arguments.
class RawTensor : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, RawTensor> {
public:
RawTensor() = default;
RawTensor(RawTensor&& other) = default;
RawTensor(const RawTensor& other) = default;
RawTensor& operator=(RawTensor&& other) = default;
/// \brief Destroy the RawTensor and release exclusive resources.
virtual ~RawTensor() = default;
public:
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "RawTensor"; }
template <typename T>
T* GetMutable() {
if (!data_.empty()) {
try {
return paddle::any_cast<T*>(data_);
} catch (paddle::bad_any_cast&) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Invalid data type error, expected %s, actual %s.",
typeid(T).name(),
data_type_.name()));
}
}
T* created_data = new T();
data_ = created_data;
data_deleter_ = [created_data]() { delete created_data; };
data_type_ = std::type_index(typeid(T));
return created_data;
}
template <typename T>
bool IsType() const {
return std::type_index(typeid(T)) == data_type_;
}
private:
paddle::any data_;
std::function<void(void)> data_deleter_;
std::type_index data_type_ = std::type_index(typeid(void));
};
} // namespace framework
} // namespace paddle
...@@ -20,13 +20,82 @@ limitations under the License. */ ...@@ -20,13 +20,82 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/phi/core/extended_tensor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Vocab : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, Vocab> {
public:
Vocab() = default;
Vocab(Vocab&& other) = default;
Vocab(const Vocab& other) = default;
Vocab& operator=(const Vocab& other) = default;
Vocab& operator=(Vocab&& other) = default;
Vocab& operator=(
const std::unordered_map<std::wstring, std::int32_t>& other) {
this->data_ = other;
return *this;
}
/// \brief Destroy the Vocab and release exclusive resources.
virtual ~Vocab() = default;
public:
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "Vocab"; }
size_t size() const { return data_.size(); }
void clear() { data_.clear(); }
void emplace(const std::wstring& key, std::int32_t value) {
data_.emplace(key, value);
}
std::int32_t at(const std::wstring& key) { return data_.at(key); }
std::int32_t at(const std::wstring& key) const { return data_.at(key); }
std::unordered_map<std::wstring, std::int32_t>::iterator find(
const std::wstring& key) {
return data_.find(key);
}
std::unordered_map<std::wstring, std::int32_t>::const_iterator find(
const std::wstring& key) const {
return data_.find(key);
}
std::unordered_map<std::wstring, std::int32_t>::iterator begin() {
return data_.begin();
}
std::unordered_map<std::wstring, std::int32_t>::const_iterator begin() const {
return data_.begin();
}
std::unordered_map<std::wstring, std::int32_t>::iterator end() {
return data_.end();
}
std::unordered_map<std::wstring, std::int32_t>::const_iterator end() const {
return data_.end();
}
private:
std::unordered_map<std::wstring, std::int32_t> data_;
};
using String = std::string; using String = std::string;
using Strings = std::vector<std::string>; using Strings = std::vector<std::string>;
using Vocab = std::unordered_map<std::wstring, std::int32_t>;
// Convert the std::string type to the std::string type. // Convert the std::string type to the std::string type.
bool ConvertStrToWstr(const std::string& src, std::wstring* res); bool ConvertStrToWstr(const std::string& src, std::wstring* res);
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif #endif
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/operators/cuda_graph_with_in_out.h" #include "paddle/fluid/operators/cuda_graph_with_in_out.h"
namespace paddle { namespace paddle {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -219,7 +220,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< ...@@ -219,7 +220,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
float, float,
Vocab, Vocab,
std::vector<int>, std::vector<int>,
std::vector<float>>; std::vector<float>,
RawTensor>;
template <typename T> template <typename T>
struct VarTypeTrait { struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type"); static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#if defined(PADDLE_WITH_XPU_BKCL) #if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#endif #endif
#include "paddle/fluid/framework/raw_tensor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -22,10 +22,10 @@ namespace framework { ...@@ -22,10 +22,10 @@ namespace framework {
TEST(Variable, GetMutable) { TEST(Variable, GetMutable) {
std::unique_ptr<Variable> v(new Variable()); std::unique_ptr<Variable> v(new Variable());
auto* t = v->GetMutable<std::string>(); auto* t = v->GetMutable<String>();
*t = "1234"; *t = "1234";
const auto& tt = v->Get<std::string>(); const auto& tt = v->Get<String>();
EXPECT_EQ("1234", tt); EXPECT_EQ("1234", tt);
try { try {
......
...@@ -5,7 +5,7 @@ cc_library( ...@@ -5,7 +5,7 @@ cc_library(
cc_library( cc_library(
var_helper var_helper
SRCS var_helper.cc SRCS var_helper.cc
DEPS tensor selected_rows) DEPS tensor selected_rows extended_tensor)
if(WITH_XPU) if(WITH_XPU)
cc_library( cc_library(
prepared_operator prepared_operator
......
...@@ -89,7 +89,7 @@ std::vector<std::string> Layer::FunctionNames() const { ...@@ -89,7 +89,7 @@ std::vector<std::string> Layer::FunctionNames() const {
PD_SPECIALZE_ATTRIBUTE_TYPE(int) PD_SPECIALZE_ATTRIBUTE_TYPE(int)
PD_SPECIALZE_ATTRIBUTE_TYPE(float) PD_SPECIALZE_ATTRIBUTE_TYPE(float)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::string) PD_SPECIALZE_ATTRIBUTE_TYPE(framework::String)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<int>) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<int>)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<float>) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<float>)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<std::string>) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<std::string>)
......
...@@ -86,7 +86,7 @@ TEST(CpuLayerTest, Construct) { ...@@ -86,7 +86,7 @@ TEST(CpuLayerTest, Construct) {
int ds = layer.Attribute<int>("down_sampling"); int ds = layer.Attribute<int>("down_sampling");
EXPECT_EQ(ds, 4); EXPECT_EQ(ds, 4);
std::string fstr = layer.Attribute<std::string>("fstr"); std::string fstr = layer.Attribute<framework::String>("fstr");
EXPECT_STREQ(fstr.c_str(), "save str property"); EXPECT_STREQ(fstr.c_str(), "save str property");
std::vector<int> ints = layer.Attribute<std::vector<int>>("ints"); std::vector<int> ints = layer.Attribute<std::vector<int>>("ints");
......
...@@ -97,7 +97,7 @@ std::unordered_map<std::string, std::shared_ptr<Variable>> Property::Values() { ...@@ -97,7 +97,7 @@ std::unordered_map<std::string, std::shared_ptr<Variable>> Property::Values() {
*var->GetMutable<int>() = static_cast<int>(GetInt64(n)); *var->GetMutable<int>() = static_cast<int>(GetInt64(n));
break; break;
case ValueProto::STRING: case ValueProto::STRING:
*var->GetMutable<std::string>() = GetString(n); *var->GetMutable<paddle::framework::String>() = GetString(n);
break; break;
case ValueProto::FLOATS: case ValueProto::FLOATS:
*var->GetMutable<std::vector<float>>() = GetFloats(n); *var->GetMutable<std::vector<float>>() = GetFloats(n);
......
...@@ -12,7 +12,7 @@ unset(OP_LIBRARY CACHE) ...@@ -12,7 +12,7 @@ unset(OP_LIBRARY CACHE)
set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.tmp CACHE INTERNAL "pybind.h file") set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.tmp CACHE INTERNAL "pybind.h file")
set(pybind_file_prune ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.prune CACHE INTERNAL "pybind.h file") set(pybind_file_prune ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.prune CACHE INTERNAL "pybind.h file")
set(pybind_file_final ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h) set(pybind_file_final ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h)
file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operators/CMakeLists.txt. DO NOT EDIT!\n\n") file(WRITE ${pybind_file} "#include \"paddle/phi/core/kernel_registry.h\" // Generated by the paddle/fluid/operators/CMakeLists.txt. DO NOT EDIT!\n\n")
add_subdirectory(math) add_subdirectory(math)
add_subdirectory(controlflow) add_subdirectory(controlflow)
...@@ -109,7 +109,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin ...@@ -109,7 +109,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
target_link_libraries(run_program_op cuda_graph_with_memory_pool) target_link_libraries(run_program_op cuda_graph_with_memory_pool)
op_library(quantize_linear_op DEPS phi) op_library(quantize_linear_op DEPS phi)
op_library(save_combine_op DEPS string_array) op_library(save_combine_op DEPS string_array phi)
op_library(load_combine_op DEPS string_array) op_library(load_combine_op DEPS string_array)
if (WITH_GPU OR WITH_ROCM) if (WITH_GPU OR WITH_ROCM)
......
...@@ -16,6 +16,10 @@ limitations under the License. */ ...@@ -16,6 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -100,10 +104,22 @@ REGISTER_OPERATOR(save_combine, ...@@ -100,10 +104,22 @@ REGISTER_OPERATOR(save_combine,
ops::SaveCombineOpProtoMaker, ops::SaveCombineOpProtoMaker,
ops::SaveCombineOpInferVarType); ops::SaveCombineOpInferVarType);
REGISTER_OP_CPU_KERNEL( PD_REGISTER_KERNEL(save_combine_tensor,
save_combine, CPU,
ops::SaveCombineOpKernel<phi::CPUContext, float>, ALL_LAYOUT,
ops::SaveCombineOpKernel<phi::CPUContext, double>, paddle::operators::SaveCombineTensorKernel,
ops::SaveCombineOpKernel<phi::CPUContext, paddle::platform::bfloat16>, int,
ops::SaveCombineOpKernel<phi::CPUContext, int>, int64_t,
ops::SaveCombineOpKernel<phi::CPUContext, int64_t>); float,
double,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(save_combine_vocab,
CPU,
ALL_LAYOUT,
paddle::operators::SaveCombineVocabKernel,
int,
int64_t,
float,
double,
phi::dtype::bfloat16) {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/fluid/operators/save_combine_op.h" #include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace ops = paddle::operators; PD_REGISTER_KERNEL(save_combine_tensor,
GPU,
ALL_LAYOUT,
paddle::operators::SaveCombineTensorKernel,
int,
int64_t,
float,
double) {}
REGISTER_OP_CUDA_KERNEL(save_combine, PD_REGISTER_KERNEL(save_combine_vocab,
ops::SaveCombineOpKernel<phi::GPUContext, float>, GPU,
ops::SaveCombineOpKernel<phi::GPUContext, double>, ALL_LAYOUT,
ops::SaveCombineOpKernel<phi::GPUContext, int>, paddle::operators::SaveCombineVocabKernel,
ops::SaveCombineOpKernel<phi::GPUContext, int64_t>); int,
int64_t,
float,
double) {}
...@@ -27,35 +27,161 @@ limitations under the License. */ ...@@ -27,35 +27,161 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/serialization.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T>
class SaveCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
auto save_to_memory = ctx.Attr<bool>("save_to_memory");
auto output = ctx.Output<std::string>("Y");
bool is_present = FileExists(filename); inline void SaveToMemory(const std::string& file_path,
const std::ostringstream& ss,
bool save_to_memory,
std::string* output) {
if (save_to_memory) {
PADDLE_ENFORCE_NE(output,
nullptr,
phi::errors::InvalidArgument(
"Cannot find variable Y for save_combine_op"));
*output = ss.str();
} else {
MkDirRecursively(DirName(file_path).c_str());
std::ofstream fout(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
true,
phi::errors::Unavailable(
"Cannot open %s to save variables.", file_path));
fout << ss.str();
fout.close();
}
}
template <typename T, typename Context>
void SaveCombineTensorKernel(const Context& dev_ctx,
const std::vector<const phi::DenseTensor*>& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16,
bool save_to_memory,
phi::ExtendedTensor* out) {
std::string* y = nullptr;
if (out != nullptr) {
auto raw_out = static_cast<paddle::framework::RawTensor*>(out);
y = raw_out->GetMutable<std::string>();
}
bool is_present = FileExists(file_path);
if (is_present && !overwrite) { if (is_present && !overwrite) {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(phi::errors::PreconditionNotMet(
"%s exists! Cannot save_combine to it when overwrite is set to " "%s exists! Cannot save_combine to it when overwrite is set to "
"false.", "false.",
filename, file_path,
overwrite)); overwrite));
} }
std::ostringstream ss; std::ostringstream ss;
PADDLE_ENFORCE_GT(x.size(),
0UL,
phi::errors::InvalidArgument(
"The number of variables to be saved is %d, expect "
"it to be greater than 0.",
x.size()));
for (size_t i = 0; i < x.size(); i++) {
auto& tensor = *(x[i]);
PADDLE_ENFORCE_EQ(
tensor.IsInitialized(),
true,
phi::errors::InvalidArgument(
"The Tensor with Index (%d) to be saved is not initialized.", i));
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = framework::TransToProtoVarType(tensor.dtype());
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto place = dev_ctx.GetPlace();
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
phi::DenseTensor out;
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
phi::SerializeToStream(ss, out, dev_ctx);
} else {
phi::SerializeToStream(ss, tensor, dev_ctx);
}
}
SaveToMemory(file_path, ss, save_to_memory, y);
}
template <typename T, typename Context>
void SaveCombineVocabKernel(
const Context& dev_ctx,
const std::vector<const phi::ExtendedTensor*>& inputs,
const std::string& file_path,
bool overwrite,
bool save_as_fp16,
bool save_to_memory,
phi::ExtendedTensor* out) {
std::string* y = nullptr;
if (out != nullptr) {
auto raw_out = static_cast<paddle::framework::RawTensor*>(out);
y = raw_out->GetMutable<std::string>();
}
std::vector<const framework::Vocab*> x;
x.reserve(inputs.size());
for (auto input : inputs) {
x.push_back(static_cast<const framework::Vocab*>(input));
}
bool is_present = FileExists(file_path);
if (is_present && !overwrite) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"%s exists! Cannot save_combine to it when overwrite is set to "
"false.",
file_path,
overwrite));
}
std::ostringstream ss;
PADDLE_ENFORCE_GT(x.size(),
0UL,
phi::errors::InvalidArgument(
"The number of variables to be saved is %d, expect "
"it to be greater than 0.",
x.size()));
for (size_t i = 0; i < x.size(); i++) {
auto& tensor = *(x[i]);
std::unordered_map<std::string, std::int32_t> data;
for (auto it = tensor.begin(); it != tensor.end(); ++it) {
std::string t;
paddle::framework::ConvertWstrToStr(it->first, &t);
data.emplace(t, it->second);
}
paddle::framework::StringMapToStream(ss, data);
}
SaveToMemory(file_path, ss, save_to_memory, y);
}
template <typename DeviceContext, typename T>
class SaveCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
auto save_to_memory = ctx.Attr<bool>("save_to_memory");
auto output = ctx.Output<framework::RawTensor>("Y");
auto inp_var_names = ctx.InputNames("X"); auto inp_var_names = ctx.InputNames("X");
auto &inp_vars = ctx.MultiInputVar("X"); auto& inp_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_GT(inp_var_names.size(), PADDLE_ENFORCE_GT(inp_var_names.size(),
0UL, 0UL,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -64,8 +190,8 @@ class SaveCombineOpKernel : public framework::OpKernel<T> { ...@@ -64,8 +190,8 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
inp_var_names.size())); inp_var_names.size()));
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto& dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) { for (size_t i = 0; i < inp_var_names.size(); i++) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -81,58 +207,31 @@ class SaveCombineOpKernel : public framework::OpKernel<T> { ...@@ -81,58 +207,31 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
"phi::DenseTensor or Vocab variable, %s has wrong type.", "phi::DenseTensor or Vocab variable, %s has wrong type.",
inp_var_names[i])); inp_var_names[i]));
if (inp_vars[i]->IsType<phi::DenseTensor>()) { if (inp_vars.size() > 0 && inp_vars[0]->IsType<phi::DenseTensor>()) {
auto &tensor = inp_vars[i]->Get<phi::DenseTensor>(); std::vector<const phi::DenseTensor*> x(inp_vars.size());
PADDLE_ENFORCE_EQ( for (auto inp_var : inp_vars) {
tensor.IsInitialized(), x.push_back(&(inp_var->Get<phi::DenseTensor>()));
true,
platform::errors::InvalidArgument(
"The Tensor of Variable(%s) to be saved is not initialized.",
inp_var_names[i]));
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = framework::TransToProtoVarType(tensor.dtype());
auto out_dtype =
save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
phi::DenseTensor out;
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::TransDataType(
in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(ss, out, dev_ctx);
} else {
framework::SerializeToStream(ss, tensor, dev_ctx);
} }
SaveCombineTensorKernel<T>(dev_ctx,
x,
filename,
overwrite,
save_as_fp16,
save_to_memory,
output);
} else { } else {
auto &tensor = inp_vars[i]->Get<framework::Vocab>(); std::vector<const phi::ExtendedTensor*> x(inp_vars.size());
std::unordered_map<std::string, std::int32_t> data; for (auto inp_var : inp_vars) {
for (auto it = tensor.begin(); it != tensor.end(); ++it) { x.push_back(&(inp_var->Get<framework::Vocab>()));
std::string t;
framework::ConvertWstrToStr(it->first, &t);
data.emplace(t, it->second);
}
framework::StringMapToStream(ss, data);
} }
SaveCombineVocabKernel<T>(dev_ctx,
x,
filename,
overwrite,
save_as_fp16,
save_to_memory,
output);
} }
if (save_to_memory) {
PADDLE_ENFORCE_NE(output,
nullptr,
platform::errors::InvalidArgument(
"Cannot find variable Y for save_combine_op"));
*output = ss.str();
} else {
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
true,
platform::errors::Unavailable(
"Cannot open %s to save variables.", filename));
fout << ss.str();
fout.close();
} }
} }
}; };
......
...@@ -20,8 +20,10 @@ limitations under the License. */ ...@@ -20,8 +20,10 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/kernel_registry.h"
USE_CPU_ONLY_OP(save_combine); USE_OP_ITSELF(save_combine);
PD_DECLARE_KERNEL(save_combine_tensor, CPU, ALL_LAYOUT);
USE_CPU_ONLY_OP(load_combine); USE_CPU_ONLY_OP(load_combine);
template <typename T, typename U> template <typename T, typename U>
......
...@@ -29,6 +29,7 @@ typedef SSIZE_T ssize_t; ...@@ -29,6 +29,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/eager/hooks.h" #include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -1493,7 +1494,7 @@ static PyObject* tensor_method_set_vocab(TensorObject* self, ...@@ -1493,7 +1494,7 @@ static PyObject* tensor_method_set_vocab(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
using Vocab = std::unordered_map<std::wstring, int>; using Vocab = paddle::framework::Vocab;
auto vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0); auto vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
auto var_tensor = std::make_shared<egr::VariableCompatTensor>(); auto var_tensor = std::make_shared<egr::VariableCompatTensor>();
*var_tensor->GetMutable<Vocab>() = vocab; *var_tensor->GetMutable<Vocab>() = vocab;
...@@ -1524,7 +1525,7 @@ static PyObject* tensor_method_get_map_tensor(TensorObject* self, ...@@ -1524,7 +1525,7 @@ static PyObject* tensor_method_get_map_tensor(TensorObject* self,
true, true,
paddle::platform::errors::Fatal( paddle::platform::errors::Fatal(
"this method is only effective for VariableCompatTensor")); "this method is only effective for VariableCompatTensor"));
using Vocab = std::unordered_map<std::wstring, int>; using Vocab = paddle::framework::Vocab;
auto* var_tensor = auto* var_tensor =
static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get()); static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
return ToPyObject(var_tensor->Get<Vocab>()); return ToPyObject(var_tensor->Get<Vocab>());
......
...@@ -590,11 +590,12 @@ paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ...@@ -590,11 +590,12 @@ paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
return dtype; return dtype;
} }
std::unordered_map<std::wstring, int> CastPyArg2Vocab(PyObject* obj, paddle::framework::Vocab CastPyArg2Vocab(PyObject* obj, ssize_t arg_pos) {
ssize_t arg_pos) {
if (PyDict_Check(obj)) { if (PyDict_Check(obj)) {
return ::pybind11::handle(obj) paddle::framework::Vocab vocab;
.cast<std::unordered_map<std::wstring, int>>(); vocab = ::pybind11::handle(obj)
.cast<std::unordered_map<std::wstring, std::int32_t>>();
return vocab;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be dict, but got %s", "argument (position %d) must be dict, but got %s",
...@@ -887,7 +888,7 @@ PyObject* ToPyObject( ...@@ -887,7 +888,7 @@ PyObject* ToPyObject(
return dict; return dict;
} }
PyObject* ToPyObject(const std::unordered_map<std::wstring, int>& value) { PyObject* ToPyObject(const paddle::framework::Vocab& value) {
PyObject* dict = PyDict_New(); PyObject* dict = PyDict_New();
for (const auto& map_iter : value) { for (const auto& map_iter : value) {
// Convert Key // Convert Key
......
...@@ -20,6 +20,7 @@ typedef SSIZE_T ssize_t; ...@@ -20,6 +20,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/eager/hooks.h" #include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/jit/function.h" #include "paddle/fluid/jit/function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -74,8 +75,7 @@ std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t( ...@@ -74,8 +75,7 @@ std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t(
PyObject* obj, size_t arg_pos); PyObject* obj, size_t arg_pos);
framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
ssize_t arg_pos); ssize_t arg_pos);
std::unordered_map<std::wstring, int> CastPyArg2Vocab(PyObject* obj, paddle::framework::Vocab CastPyArg2Vocab(PyObject* obj, ssize_t arg_pos);
ssize_t arg_pos);
std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj, std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
ssize_t arg_pos); ssize_t arg_pos);
std::shared_ptr<jit::Function> CastPyArg2JitFunction(PyObject* obj, std::shared_ptr<jit::Function> CastPyArg2JitFunction(PyObject* obj,
...@@ -116,7 +116,7 @@ PyObject* ToPyObject(const paddle::framework::proto::VarType& type); ...@@ -116,7 +116,7 @@ PyObject* ToPyObject(const paddle::framework::proto::VarType& type);
PyObject* ToPyObject(const void* value); PyObject* ToPyObject(const void* value);
PyObject* ToPyObject( PyObject* ToPyObject(
const std::unordered_map<std::string, std::vector<std::string>>& value); const std::unordered_map<std::string, std::vector<std::string>>& value);
PyObject* ToPyObject(const std::unordered_map<std::wstring, int>& value); PyObject* ToPyObject(const paddle::framework::Vocab& value);
class PyTensorHook : public egr::TensorHook { class PyTensorHook : public egr::TensorHook {
public: public:
......
...@@ -55,6 +55,7 @@ limitations under the License. */ ...@@ -55,6 +55,7 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/phi_utils.h" #include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope_pool.h" #include "paddle/fluid/framework/scope_pool.h"
#include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/selected_rows_utils.h"
...@@ -942,14 +943,20 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -942,14 +943,20 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference) py::return_value_policy::reference)
.def("get_bytes", .def("get_bytes",
[](Variable &self) { [](Variable &self) {
return py::bytes(*self.GetMutable<std::string>()); if (self.IsType<String>()) {
return py::bytes(*(self.GetMutable<String>()));
} else {
return py::bytes(
*(self.GetMutable<RawTensor>()->GetMutable<std::string>()));
}
}) })
.def("set_string_list", .def("set_string_list",
[](Variable &self, Strings str_list) { [](Variable &self, Strings str_list) {
*self.GetMutable<Strings>() = str_list; *self.GetMutable<Strings>() = str_list;
}) })
.def("set_vocab", .def("set_vocab",
[](Variable &self, Vocab vocab) { [](Variable &self,
const std::unordered_map<std::wstring, std::int32_t> &vocab) {
*self.GetMutable<Vocab>() = vocab; *self.GetMutable<Vocab>() = vocab;
}) })
.def( .def(
......
...@@ -38,7 +38,8 @@ set(PHI_DEPS ...@@ -38,7 +38,8 @@ set(PHI_DEPS
sparse_coo_tensor sparse_coo_tensor
string_tensor string_tensor
api_scalar api_scalar
api_int_array) api_int_array
extended_tensor)
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
set(PHI_DEPS ${PHI_DEPS} ${phi_kernels}) set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
......
...@@ -82,6 +82,11 @@ cc_library( ...@@ -82,6 +82,11 @@ cc_library(
SRCS tensor_array.cc SRCS tensor_array.cc
DEPS dense_tensor tensor_base) DEPS dense_tensor tensor_base)
cc_library(
extended_tensor
SRCS extended_tensor.cc
DEPS tensor_base)
cc_library( cc_library(
meta_tensor meta_tensor
SRCS meta_tensor.cc SRCS meta_tensor.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/extended_tensor.h"
namespace phi {
int64_t ExtendedTensor::numel() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `numel` method."));
}
const DDim& ExtendedTensor::dims() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `dims` method."));
}
const Place& ExtendedTensor::place() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `place` method."));
}
DataType ExtendedTensor::dtype() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `dtype` method."));
}
DataLayout ExtendedTensor::layout() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `dtype` method."));
}
bool ExtendedTensor::valid() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `valid` method."));
}
bool ExtendedTensor::initialized() const {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `initialized` method."));
}
void* ExtendedTensor::AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size) {
PADDLE_THROW(phi::errors::Unavailable(
"ExtendedTensor does not support `AllocateFrom` method."));
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
namespace phi {
/// \brief The ExtendedTensor is a interface for custom designed class.
/// If you want to pass some self-designed data as input/output to kernels,
/// you can inherit from this class to store your self-designed data.
class ExtendedTensor : public TensorBase {
public:
ExtendedTensor() = default;
virtual ~ExtendedTensor() = default;
public:
/// \brief Returns the name of the class for type traits.
/// \return The name of the class.
static const char* name() { return "ExtendedTensor"; }
int64_t numel() const override;
const DDim& dims() const override;
const Place& place() const override;
DataType dtype() const override;
DataLayout layout() const override;
bool valid() const override;
bool initialized() const override;
void* AllocateFrom(Allocator* allocator,
DataType dtype,
size_t requested_size = 0) override;
};
} // namespace phi
...@@ -141,7 +141,7 @@ enum class AttributeType { ...@@ -141,7 +141,7 @@ enum class AttributeType {
INT_ARRAY, INT_ARRAY,
DATA_TYPE, DATA_TYPE,
DATA_LAYOUT, DATA_LAYOUT,
PLACE, PLACE
}; };
struct AttributeArgDef { struct AttributeArgDef {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/phi/core/custom_kernel.h" #include "paddle/phi/core/custom_kernel.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/extended_tensor.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_utils.h" #include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
...@@ -100,6 +101,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -100,6 +101,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(
const std::vector<const ExtendedTensor*>&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid( } else if (arg_type == std::type_index(typeid(
const std::vector<const SelectedRows*>&))) { const std::vector<const SelectedRows*>&))) {
args_def->AppendInput(default_key.backend(), args_def->AppendInput(default_key.backend(),
...@@ -191,6 +198,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -191,6 +198,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(ExtendedTensor*))) {
args_def->AppendOutput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(bool))) { } else if (arg_type == std::type_index(typeid(bool))) {
args_def->AppendAttribute(AttributeType::BOOL); args_def->AppendAttribute(AttributeType::BOOL);
} else if (arg_type == std::type_index(typeid(int))) { } else if (arg_type == std::type_index(typeid(int))) {
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/extended_tensor.h"
#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
...@@ -264,6 +265,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -264,6 +265,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(ExtendedTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
...@@ -323,6 +325,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -323,6 +325,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray);
PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(ExtendedTensor);
/* End case */ /* End case */
template <typename T> template <typename T>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SaveCombineOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInputs("X")) {
return KernelSignature(
"save_combine_tensor",
{"X"},
{"file_path", "overwrite", "save_as_fp16", "save_to_memory"},
{"Y"});
} else {
return KernelSignature(
"save_combine_vocab",
{"X"},
{"file_path", "overwrite", "save_as_fp16", "save_to_memory"},
{"Y"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册