diff --git a/paddle/fluid/framework/custom_kernel.cc b/paddle/fluid/framework/custom_kernel.cc
index a04838c9d4f8744cdeeaadb0fdd57c9506f269c5..058be98928227f74630389738939f39fee6c87e8 100644
--- a/paddle/fluid/framework/custom_kernel.cc
+++ b/paddle/fluid/framework/custom_kernel.cc
@@ -43,10 +43,19 @@ static void ParseArgs(const OpKernelInfo& op_kernel_info,
   auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
 
   for (auto& input : input_defs) {
-    args_def->AppendInput(input.backend, input.layout, input.dtype);
+    auto type_index =
+        input.is_vector
+            ? std::type_index(typeid(const std::vector&))
+            : std::type_index(typeid(const pten::DenseTensor&));
+    args_def->AppendInput(input.backend, input.layout, input.dtype, type_index);
   }
   for (auto& output : output_defs) {
-    args_def->AppendOutput(output.backend, output.layout, output.dtype);
+    auto type_index =
+        output.is_vector
+            ? std::type_index(typeid(const std::vector&))
+            : std::type_index(typeid(const pten::DenseTensor&));
+    args_def->AppendOutput(output.backend, output.layout, output.dtype,
+                           type_index);
   }
   for (auto& attr : attribute_defs) {
     args_def->AppendAttribute(attr.type_index);
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 0f558b46872a2d8e94cbe7141119b0d81e696e90..5ed86fb6b717f9c24415932c1cd2884ffb8e71bd 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1217,7 +1217,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                 pt_kernel_name, pt_cpu_kernel_key)));
 
         dev_ctx = pool.Get(platform::CPUPlace());
-
         if (pt_kernel_->IsValid()) {
           VLOG(6) << "Static mode PrepareImpl - kernel name: " << pt_kernel_name
                   << " | kernel key: " << pt_cpu_kernel_key
@@ -1919,7 +1918,12 @@ Scope* OperatorWithKernel::PreparePtenData(
   for (size_t i = 0; i < input_defs.size(); ++i) {
     auto& in_def = input_defs.at(i);
-    auto& ins_vector = ctx->inputs.at(input_names[i]);
+    auto it = ctx->inputs.find(input_names[i]);
+    if (it == ctx->inputs.end()) {
+      continue;
+    }
+
+    auto& ins_vector = it->second;
     auto& name_vec = name_map.at(input_names[i]);
     bool should_skip_input =
         no_buffer_ins && no_buffer_ins->count(input_names[i]) > 0;
@@ -2003,18 +2007,29 @@ void OperatorWithKernel::BuildPtenKernelContext(
                         attr_names.size(), attr_defs.size()));
 
   for (size_t i = 0; i < input_names.size(); ++i) {
-    auto& ins_vector = ctx.inputs.at(input_names[i]);
+    auto it = ctx.inputs.find(input_names[i]);
 
     // calcute the start and end index of the input tensors
     size_t start_idx =
         (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
-    size_t end_idx = start_idx + ins_vector.size();
+    // deal with optional here
+    if ((it == ctx.inputs.end()) &&
+        (input_defs[i].type_index ==
+         std::type_index(typeid(paddle::optional)))) {
+      pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
+      auto end_idx = start_idx + 1;
+      pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
+                                          i);
+      continue;
+    }
+    auto ins_vector = it->second;
+    size_t end_idx = start_idx + ins_vector.size();
 
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
       const pten::TensorBase* tensor_in = nullptr;
       auto* var = ins_vector[offset];
-      if (var->IsType()) {
-        tensor_in = &(var->Get());
+      if (var->IsType()) {
+        tensor_in = &(var->Get());
       } else if (var->IsType()) {
         tensor_in = &(var->Get());
       } else {
@@ -2022,23 +2037,37 @@ void OperatorWithKernel::BuildPtenKernelContext(
             "Unsupported input `%s` type when call pt kernel.",
             framework::ToTypeName(var->Type())));
       }
+      pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
     }
     pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
   }
 
   for (size_t i = 0; i < output_names.size(); ++i) {
-    auto& outs_vector = ctx.outputs.at(output_names[i]);
-
+    auto it = ctx.outputs.find(output_names[i]);
     size_t start_idx =
         (i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second);
+
+    if (it == ctx.outputs.end() || it->second.empty()) {
+      // Deal with the case that some outputs are not found or be NULL when run
+      // the kernel.
+      // For example : the outputs of matmul_grad are dx and dy,
+      // sometimes dx or dy may be NULL.
+      pt_kernel_context->EmplaceBackOutputWithoutSetRange(nullptr);
+      auto end_idx = start_idx + 1;
+      pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx),
+                                           i);
+      continue;
+    }
+    auto& outs_vector = it->second;
+    size_t end_idx = start_idx + outs_vector.size();
 
     for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
       pten::TensorBase* tensor_out = nullptr;
       auto* var = outs_vector[offset];
-      if (var->template IsType()) {
-        tensor_out = var->template GetMutable();
+      if (var->template IsType()) {
+        tensor_out = var->template GetMutable();
       } else if (var->template IsType()) {
         tensor_out = var->template GetMutable();
       } else {
@@ -2055,14 +2084,6 @@ void OperatorWithKernel::BuildPtenKernelContext(
       pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
     }
 
-    // Deal with the case that some outputs are NULL when run the kernel.
-    // For example : the outputs of matmul_grad are dx and dy,
-    // sometimes dx or dy may be NULL.
-    if (outs_vector.empty()) {
-      pt_kernel_context->EmplaceBackOutputWithoutSetRange({nullptr});
-      end_idx = start_idx + 1;
-    }
-
     pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
   }
 
@@ -2134,6 +2155,9 @@ void OperatorWithKernel::BuildPtenKernelContext(
         pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
         pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+      } else if (attr_defs[i].type_index ==
+                 std::type_index(typeid(std::string))) {
+        pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
       } else if (attr_defs[i].type_index ==
                  std::type_index(typeid(pten::DataType))) {
         auto data_type = pten::TransToPtenDataType(
@@ -2152,6 +2176,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
         }
         // TODO(YuanRisheng) Need support vector attr
+      } else if (attr_defs[i].type_index ==
+                 std::type_index(typeid(std::vector))) {
+        const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr);
+        pt_kernel_context->EmplaceBackAttr(vector_int_attr);
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "Unsupported cast op attribute `%s` when construct "
diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h
index 8775f715bfb202f1b07a4ccacb113ff97d71fada..2efe808a8a3db14bf64e34bc88f0374683d476f5 100644
--- a/paddle/fluid/imperative/prepared_operator.h
+++ b/paddle/fluid/imperative/prepared_operator.h
@@ -259,26 +259,36 @@ void BuildDygraphPtenKernelContext(
                         attr_names.size(), attr_defs.size()));
 
   for (size_t i = 0; i < input_names.size(); ++i) {
-    auto& ins_vector = ins.at(input_names[i]);
+    auto it = ins.find(input_names[i]);
     size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second);
-    size_t end_idx = start_idx + ins_vector.size();
-    for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
-      const pten::TensorBase* tensor_in = nullptr;
-      auto& var = ins_vector[offset]->Var();
-      if (var.template IsType()) {
-        tensor_in = &(var.template Get());
-      } else if (var.template IsType()) {
-        tensor_in = &(var.template Get());
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Unsupported input `%s` type when call pt kernel.",
-            framework::ToTypeName(var.Type())));
+    if ((it == ins.end()) &&
+        (input_defs[i].type_index ==
+         std::type_index(typeid(paddle::optional)))) {
+      kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
+      auto end_idx = start_idx + 1;
+      kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
+    } else {
+      auto ins_vector = it->second;
+      size_t end_idx = start_idx + ins_vector.size();
+
+      for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
+        const pten::TensorBase* tensor_in = nullptr;
+        auto& var = ins_vector[offset]->Var();
+        if (var.template IsType()) {
+          tensor_in = &(var.template Get());
+        } else if (var.template IsType()) {
+          tensor_in = &(var.template Get());
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported input `%s` type when call pt kernel.",
+              framework::ToTypeName(var.Type())));
+        }
+        kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
       }
-      kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
+      kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
     }
-    kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
   }
 
   for (size_t i = 0; i < output_names.size(); ++i) {
@@ -336,6 +346,10 @@ void BuildDygraphPtenKernelContext(
                    std::type_index(typeid(std::vector))) {
         kernel_ctx->EmplaceBackAttr(std::move(
             pten::ScalarArray(BOOST_GET_CONST(std::vector, attr))));
+      } else if (attr_defs[i].type_index ==
+                 std::type_index(typeid(std::vector))) {
+        const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr);
+        kernel_ctx->EmplaceBackAttr(vector_int_attr);
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "Unsupported cast op attribute `%s` to VectorTensor when "
@@ -398,6 +412,9 @@ void BuildDygraphPtenKernelContext(
       kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
     } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
       kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+    } else if (attr_defs[i].type_index ==
+               std::type_index(typeid(std::string))) {
+      kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
     } else if (attr_defs[i].type_index ==
                std::type_index(typeid(pten::DataType))) {
       auto data_type = pten::TransToPtenDataType(
@@ -440,6 +457,10 @@ void PreparePtenData(const pten::Kernel& pt_kernel,
   for (size_t i = 0; i < input_names.size(); ++i) {
     auto& in_def = input_defs.at(i);
+    auto it = ins.find(input_names[i]);
+    if (it == ins.end()) {
+      continue;
+    }
     auto& ins_vector = ins.at(input_names[i]);
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h
index 7d64429d3f2d04a239fb18a5ea554b4f3fe36ca6..296195ea4e6ceff3a21c22cf3f13369fe0c6d568 100644
--- a/paddle/pten/core/kernel_factory.h
+++ b/paddle/pten/core/kernel_factory.h
@@ -95,9 +95,16 @@ struct TensorArgDef {
   Backend backend;
   DataLayout layout;
   DataType dtype;
+  std::type_index type_index;
 
-  TensorArgDef(Backend in_backend, DataLayout in_layout, DataType in_dtype)
-      : backend(in_backend), layout(in_layout), dtype(in_dtype) {}
+  TensorArgDef(Backend in_backend,
+               DataLayout in_layout,
+               DataType in_dtype,
+               std::type_index in_type_index)
+      : backend(in_backend),
+        layout(in_layout),
+        dtype(in_dtype),
+        type_index(in_type_index) {}
 
   TensorArgDef& SetBackend(Backend in_backend) {
     backend = in_backend;
@@ -126,12 +133,18 @@ class KernelArgsDef {
  public:
   KernelArgsDef() = default;
 
-  void AppendInput(Backend backend, DataLayout layout, DataType dtype) {
-    input_defs_.emplace_back(TensorArgDef(backend, layout, dtype));
+  void AppendInput(Backend backend,
+                   DataLayout layout,
+                   DataType dtype,
+                   std::type_index type_index) {
+    input_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
   }
 
-  void AppendOutput(Backend backend, DataLayout layout, DataType dtype) {
-    output_defs_.emplace_back(TensorArgDef(backend, layout, dtype));
+  void AppendOutput(Backend backend,
+                    DataLayout layout,
+                    DataType dtype,
+                    std::type_index type_index) {
+    output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
   }
 
   void AppendAttribute(std::type_index type_index) {
diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h
index a0ff340b00074f8debace70a59df822ebe2a7fc1..9e8778bdb392a1c92c6f48a19399438ca85ba2de 100644
--- a/paddle/pten/core/kernel_registry.h
+++ b/paddle/pten/core/kernel_registry.h
@@ -64,29 +64,43 @@ struct KernelArgsParseFunctor {
 #endif
         // do nothing, skip context arg now
       } else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
-        args_def->AppendInput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type == std::type_index(typeid(
                                  paddle::optional))) {
-        args_def->AppendInput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type == std::type_index(typeid(const std::vector&))) {
-        args_def->AppendInput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
-        args_def->AppendInput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
      } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
-        args_def->AppendOutput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendOutput(default_key.backend(),
+                               default_tensor_layout,
+                               default_key.dtype(),
+                               arg_type);
       } else if (arg_type == std::type_index(typeid(std::vector))) {
-        args_def->AppendOutput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendOutput(default_key.backend(),
+                               default_tensor_layout,
+                               default_key.dtype(),
+                               arg_type);
       } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
-        args_def->AppendOutput(
-            default_key.backend(), default_tensor_layout, default_key.dtype());
+        args_def->AppendOutput(default_key.backend(),
+                               default_tensor_layout,
+                               default_key.dtype(),
+                               arg_type);
       } else {
         // Attribute deal with
         // TODO(chenweihang): now here allow any types of attribute, maybe
diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h
index d05e3c2887306fa83a4ddf32230790f392697b55..7c611d7eccd11ec0e9958e8577697d192bfad43a 100644
--- a/paddle/pten/core/kernel_utils.h
+++ b/paddle/pten/core/kernel_utils.h
@@ -238,6 +238,7 @@ struct KernelImpl {
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
 
   /* Output Helpers */
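
The note below is not part of the patch. It is a minimal standalone C++ sketch of the mechanism the patch relies on: each tensor argument definition now records a std::type_index, and the kernel-context builders compare that index against an optional-input type so a missing input can be filled with a null placeholder instead of raising an error. The types FakeDenseTensor and the local optional template are hypothetical stand-ins for pten::DenseTensor and paddle::optional; none of this is Paddle source.

// Illustrative sketch only; assumes nothing about Paddle internals beyond the
// idea of dispatching on a recorded std::type_index per argument.
#include <iostream>
#include <typeindex>
#include <vector>

struct FakeDenseTensor {};

template <typename T>
struct optional {};  // stand-in for paddle::optional, unused body

struct TensorArgDef {
  std::type_index type_index;  // mirrors the new field added in the patch
};

bool IsOptionalInput(const TensorArgDef& def) {
  return def.type_index ==
         std::type_index(typeid(optional<const FakeDenseTensor&>));
}

int main() {
  // One required input and one optional input, as a kernel-args parser might
  // record them from the kernel signature.
  std::vector<TensorArgDef> input_defs = {
      {std::type_index(typeid(const FakeDenseTensor&))},
      {std::type_index(typeid(optional<const FakeDenseTensor&>))}};

  for (size_t i = 0; i < input_defs.size(); ++i) {
    if (IsOptionalInput(input_defs[i])) {
      // A missing optional input gets a nullptr placeholder and a range of
      // length 1, which is what the early-continue branches in the patch do.
      std::cout << "input " << i << ": optional, may be filled with nullptr\n";
    } else {
      std::cout << "input " << i << ": required\n";
    }
  }
  return 0;
}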