未验证 提交 b12e7a17 编写于 作者: H hong 提交者: GitHub

update basic infrastructure (#39383)

* update basic infrastructure; support string,  suport vecotr<int>, add tensor args type index; test=develop

* remove useless code; test=develop

* fix bug; test=develop

* polish code; test=develop
上级 eaa3fd45
...@@ -43,10 +43,19 @@ static void ParseArgs(const OpKernelInfo& op_kernel_info, ...@@ -43,10 +43,19 @@ static void ParseArgs(const OpKernelInfo& op_kernel_info,
auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info); auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
for (auto& input : input_defs) { for (auto& input : input_defs) {
args_def->AppendInput(input.backend, input.layout, input.dtype); auto type_index =
input.is_vector
? std::type_index(typeid(const std::vector<pten::DenseTensor>&))
: std::type_index(typeid(const pten::DenseTensor&));
args_def->AppendInput(input.backend, input.layout, input.dtype, type_index);
} }
for (auto& output : output_defs) { for (auto& output : output_defs) {
args_def->AppendOutput(output.backend, output.layout, output.dtype); auto type_index =
output.is_vector
? std::type_index(typeid(const std::vector<pten::DenseTensor>&))
: std::type_index(typeid(const pten::DenseTensor&));
args_def->AppendOutput(output.backend, output.layout, output.dtype,
type_index);
} }
for (auto& attr : attribute_defs) { for (auto& attr : attribute_defs) {
args_def->AppendAttribute(attr.type_index); args_def->AppendAttribute(attr.type_index);
......
...@@ -1217,7 +1217,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1217,7 +1217,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
pt_kernel_name, pt_cpu_kernel_key))); pt_kernel_name, pt_cpu_kernel_key)));
dev_ctx = pool.Get(platform::CPUPlace()); dev_ctx = pool.Get(platform::CPUPlace());
if (pt_kernel_->IsValid()) { if (pt_kernel_->IsValid()) {
VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
...@@ -1919,7 +1918,12 @@ Scope* OperatorWithKernel::PreparePtenData( ...@@ -1919,7 +1918,12 @@ Scope* OperatorWithKernel::PreparePtenData(
for (size_t i = 0; i < input_defs.size(); ++i) { for (size_t i = 0; i < input_defs.size(); ++i) {
auto& in_def = input_defs.at(i); auto& in_def = input_defs.at(i);
auto& ins_vector = ctx->inputs.at(input_names[i]); auto it = ctx->inputs.find(input_names[i]);
if (it == ctx->inputs.end()) {
continue;
}
auto& ins_vector = it->second;
auto& name_vec = name_map.at(input_names[i]); auto& name_vec = name_map.at(input_names[i]);
bool should_skip_input = bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(input_names[i]) > 0; no_buffer_ins && no_buffer_ins->count(input_names[i]) > 0;
...@@ -2003,18 +2007,29 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2003,18 +2007,29 @@ void OperatorWithKernel::BuildPtenKernelContext(
attr_names.size(), attr_defs.size())); attr_names.size(), attr_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
auto& ins_vector = ctx.inputs.at(input_names[i]); auto it = ctx.inputs.find(input_names[i]);
// calcute the start and end index of the input tensors // calcute the start and end index of the input tensors
size_t start_idx = size_t start_idx =
(i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second); (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size();
// deal with optional here
if ((it == ctx.inputs.end()) &&
(input_defs[i].type_index ==
std::type_index(typeid(paddle::optional<const pten::DenseTensor&>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
i);
continue;
}
auto ins_vector = it->second;
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
const pten::TensorBase* tensor_in = nullptr; const pten::TensorBase* tensor_in = nullptr;
auto* var = ins_vector[offset]; auto* var = ins_vector[offset];
if (var->IsType<pten::DenseTensor>()) { if (var->IsType<framework::LoDTensor>()) {
tensor_in = &(var->Get<pten::DenseTensor>()); tensor_in = &(var->Get<framework::LoDTensor>());
} else if (var->IsType<pten::SelectedRows>()) { } else if (var->IsType<pten::SelectedRows>()) {
tensor_in = &(var->Get<pten::SelectedRows>()); tensor_in = &(var->Get<pten::SelectedRows>());
} else { } else {
...@@ -2022,23 +2037,37 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2022,23 +2037,37 @@ void OperatorWithKernel::BuildPtenKernelContext(
"Unsupported input `%s` type when call pt kernel.", "Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var->Type()))); framework::ToTypeName(var->Type())));
} }
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} }
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
auto& outs_vector = ctx.outputs.at(output_names[i]); auto it = ctx.outputs.find(output_names[i]);
size_t start_idx = size_t start_idx =
(i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second); (i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second);
if (it == ctx.outputs.end() || it->second.empty()) {
// Deal with the case that some outputs are not found or be NULL when run
// the kernel.
// For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
pt_kernel_context->EmplaceBackOutputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1;
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx),
i);
continue;
}
auto& outs_vector = it->second;
size_t end_idx = start_idx + outs_vector.size(); size_t end_idx = start_idx + outs_vector.size();
for (size_t offset = 0; offset < outs_vector.size(); ++offset) { for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
pten::TensorBase* tensor_out = nullptr; pten::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset]; auto* var = outs_vector[offset];
if (var->template IsType<pten::DenseTensor>()) { if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<pten::DenseTensor>(); tensor_out = var->template GetMutable<framework::LoDTensor>();
} else if (var->template IsType<pten::SelectedRows>()) { } else if (var->template IsType<pten::SelectedRows>()) {
tensor_out = var->template GetMutable<pten::SelectedRows>(); tensor_out = var->template GetMutable<pten::SelectedRows>();
} else { } else {
...@@ -2055,14 +2084,6 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2055,14 +2084,6 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} }
// Deal with the case that some outputs are NULL when run the kernel.
// For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
if (outs_vector.empty()) {
pt_kernel_context->EmplaceBackOutputWithoutSetRange({nullptr});
end_idx = start_idx + 1;
}
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i); pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
} }
...@@ -2134,6 +2155,9 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2134,6 +2155,9 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::DataType))) { std::type_index(typeid(pten::DataType))) {
auto data_type = pten::TransToPtenDataType( auto data_type = pten::TransToPtenDataType(
...@@ -2152,6 +2176,10 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2152,6 +2176,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
} }
// TODO(YuanRisheng) Need support vector<int64_t> attr // TODO(YuanRisheng) Need support vector<int64_t> attr
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct " "Unsupported cast op attribute `%s` when construct "
......
...@@ -259,26 +259,36 @@ void BuildDygraphPtenKernelContext( ...@@ -259,26 +259,36 @@ void BuildDygraphPtenKernelContext(
attr_names.size(), attr_defs.size())); attr_names.size(), attr_defs.size()));
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
auto& ins_vector = ins.at(input_names[i]); auto it = ins.find(input_names[i]);
size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second); size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { if ((it == ins.end()) &&
const pten::TensorBase* tensor_in = nullptr; (input_defs[i].type_index ==
auto& var = ins_vector[offset]->Var(); std::type_index(typeid(paddle::optional<const pten::DenseTensor&>)))) {
if (var.template IsType<pten::DenseTensor>()) { kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
tensor_in = &(var.template Get<pten::DenseTensor>()); auto end_idx = start_idx + 1;
} else if (var.template IsType<pten::SelectedRows>()) { kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
tensor_in = &(var.template Get<pten::SelectedRows>()); } else {
} else { auto ins_vector = it->second;
PADDLE_THROW(platform::errors::Unimplemented( size_t end_idx = start_idx + ins_vector.size();
"Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var.Type()))); for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
const pten::TensorBase* tensor_in = nullptr;
auto& var = ins_vector[offset]->Var();
if (var.template IsType<pten::DenseTensor>()) {
tensor_in = &(var.template Get<pten::DenseTensor>());
} else if (var.template IsType<pten::SelectedRows>()) {
tensor_in = &(var.template Get<pten::SelectedRows>());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var.Type())));
}
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} }
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
...@@ -336,6 +346,10 @@ void BuildDygraphPtenKernelContext( ...@@ -336,6 +346,10 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) { std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
kernel_ctx->EmplaceBackAttr(vector_int_attr);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to VectorTensor when " "Unsupported cast op attribute `%s` to VectorTensor when "
...@@ -398,6 +412,9 @@ void BuildDygraphPtenKernelContext( ...@@ -398,6 +412,9 @@ void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::DataType))) { std::type_index(typeid(pten::DataType))) {
auto data_type = pten::TransToPtenDataType( auto data_type = pten::TransToPtenDataType(
...@@ -440,6 +457,10 @@ void PreparePtenData(const pten::Kernel& pt_kernel, ...@@ -440,6 +457,10 @@ void PreparePtenData(const pten::Kernel& pt_kernel,
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
auto& in_def = input_defs.at(i); auto& in_def = input_defs.at(i);
auto it = ins.find(input_names[i]);
if (it == ins.end()) {
continue;
}
auto& ins_vector = ins.at(input_names[i]); auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
......
...@@ -95,9 +95,16 @@ struct TensorArgDef { ...@@ -95,9 +95,16 @@ struct TensorArgDef {
Backend backend; Backend backend;
DataLayout layout; DataLayout layout;
DataType dtype; DataType dtype;
std::type_index type_index;
TensorArgDef(Backend in_backend, DataLayout in_layout, DataType in_dtype) TensorArgDef(Backend in_backend,
: backend(in_backend), layout(in_layout), dtype(in_dtype) {} DataLayout in_layout,
DataType in_dtype,
std::type_index in_type_index)
: backend(in_backend),
layout(in_layout),
dtype(in_dtype),
type_index(in_type_index) {}
TensorArgDef& SetBackend(Backend in_backend) { TensorArgDef& SetBackend(Backend in_backend) {
backend = in_backend; backend = in_backend;
...@@ -126,12 +133,18 @@ class KernelArgsDef { ...@@ -126,12 +133,18 @@ class KernelArgsDef {
public: public:
KernelArgsDef() = default; KernelArgsDef() = default;
void AppendInput(Backend backend, DataLayout layout, DataType dtype) { void AppendInput(Backend backend,
input_defs_.emplace_back(TensorArgDef(backend, layout, dtype)); DataLayout layout,
DataType dtype,
std::type_index type_index) {
input_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
} }
void AppendOutput(Backend backend, DataLayout layout, DataType dtype) { void AppendOutput(Backend backend,
output_defs_.emplace_back(TensorArgDef(backend, layout, dtype)); DataLayout layout,
DataType dtype,
std::type_index type_index) {
output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
} }
void AppendAttribute(std::type_index type_index) { void AppendAttribute(std::type_index type_index) {
......
...@@ -64,29 +64,43 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -64,29 +64,43 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
#endif #endif
// do nothing, skip context arg now // do nothing, skip context arg now
} else if (arg_type == std::type_index(typeid(const DenseTensor&))) { } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
args_def->AppendInput( args_def->AppendInput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid( } else if (arg_type == std::type_index(typeid(
paddle::optional<const DenseTensor&>))) { paddle::optional<const DenseTensor&>))) {
args_def->AppendInput( args_def->AppendInput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == } else if (arg_type ==
std::type_index(typeid(const std::vector<DenseTensor>&))) { std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput( args_def->AppendInput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) { } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput( args_def->AppendInput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(DenseTensor*))) { } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
args_def->AppendOutput( args_def->AppendOutput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == } else if (arg_type ==
std::type_index(typeid(std::vector<DenseTensor*>))) { std::type_index(typeid(std::vector<DenseTensor*>))) {
args_def->AppendOutput( args_def->AppendOutput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(SelectedRows*))) { } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
args_def->AppendOutput( args_def->AppendOutput(default_key.backend(),
default_key.backend(), default_tensor_layout, default_key.dtype()); default_tensor_layout,
default_key.dtype(),
arg_type);
} else { } else {
// Attribute deal with // Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe // TODO(chenweihang): now here allow any types of attribute, maybe
......
...@@ -238,6 +238,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -238,6 +238,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
/* Output Helpers */ /* Output Helpers */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册