diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index e3d9e2228360080d31bd1336659d3fb0ddba08d4..9a7da7aa831c15ce884380b27dde94f5f795574d 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -178,6 +178,8 @@ void HandleForSpecialOp(ir::Operation* op,
     auto var = CreateVar(out_value, name, scope, local_scope);
     auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
 
+    // clear tensor array
+    tensor_array->clear();
     for (size_t i = 0; i < input_num; ++i) {
       auto value = op->operand(i);
 
@@ -203,8 +205,10 @@ void HandleForSpecialOp(ir::Operation* op,
     // change opreand name to param_name
     auto orig_name = name_map->at(in_ptr);
 
+    if (scope->FindVar(param_name) == nullptr) {
+      scope->Rename(orig_name, param_name);
+    }
     (*name_map)[in_ptr] = param_name;
-    scope->Rename(orig_name, param_name);
   }
 
   if (op_name == "builtin.get_parameter") {
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 32ffa663144d8d5bf6a783025f1b34ffee3a6c87..77dbb98daa3f9faf1e3879c83f667620518f64b1 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -33,6 +33,7 @@
 #include "paddle/phi/core/kernel_context.h"
 
 #include "paddle/fluid/ir/dialect/kernel_attribute.h"
+#include "paddle/fluid/ir/dialect/kernel_type.h"
 #include "paddle/fluid/ir/dialect/pd_attribute.h"
 #include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
 #include "paddle/phi/core/infermeta_utils.h"
@@ -109,6 +110,7 @@ void BuildPhiContext(
       ctx->EmplaceBackInput(in_ptr);
       continue;
     }
+
     auto in_var_name = name_map.at(ptr);
     VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
@@ -154,9 +156,27 @@ void BuildPhiContext(
     auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
     VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
     if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
-      phi::Attribute r1 = phi::TensorRef(
-          &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
-      ctx->EmplaceBackAttr(r1);
+      if (ptr.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
+        phi::Attribute r1 = phi::TensorRef(
+            &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
+        ctx->EmplaceBackAttr(r1);
+      } else if (ptr.type().isa<ir::VectorType>()) {
+        auto& tensor_array = inner_scope->FindVar(in_var_name)
+                                 ->Get<paddle::framework::TensorRefArray>();
+        if (tensor_array.size() == 1) {
+          ctx->EmplaceBackAttr(phi::TensorRef(tensor_array[0]));
+        } else {
+          std::vector<phi::TensorRef> vec_ref;
+          for (size_t i = 0; i < tensor_array.size(); ++i) {
+            vec_ref.emplace_back(phi::TensorRef(tensor_array[i]));
+          }
+          ctx->EmplaceBackAttr(vec_ref);
+        }
+      } else {
+        PADDLE_THROW(phi::errors::Unimplemented(
+            " [%s] only support dense tensor and vector type ",
+            tensor_attr_type));
+      }
     } else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
       phi::Attribute r1 = phi::TensorRef(
           &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
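Note on the `@@ -154,9 +156,27 @@` hunk in `phi_kernel_util.h`: an `IntArrayAttribute` fed by a mutable input can now come either from a single dense tensor or from a combined vector of tensors, and the vector case degrades to a single `TensorRef` when it holds exactly one element. The sketch below is a minimal standalone model of just that packing rule; `Tensor`, `TensorRef`, and `Attribute` here are stand-ins, not the real phi types.

```cpp
#include <cstdio>
#include <variant>
#include <vector>

struct Tensor { int value; };             // stand-in for phi::DenseTensor
struct TensorRef { const Tensor* ptr; };  // stand-in for phi::TensorRef
using Attribute = std::variant<TensorRef, std::vector<TensorRef>>;

// Mirrors the "size() == 1 -> single ref, otherwise vector of refs" rule.
Attribute PackIntArrayAttr(const std::vector<const Tensor*>& tensors) {
  if (tensors.size() == 1) return TensorRef{tensors[0]};
  std::vector<TensorRef> refs;
  refs.reserve(tensors.size());
  for (const Tensor* t : tensors) refs.push_back(TensorRef{t});
  return refs;
}

int main() {
  Tensor a{2}, b{3};
  Attribute single = PackIntArrayAttr({&a});
  Attribute list = PackIntArrayAttr({&a, &b});
  std::printf("single holds vector? %d\n",
              std::holds_alternative<std::vector<TensorRef>>(single));
  std::printf("list size: %zu\n",
              std::get<std::vector<TensorRef>>(list).size());
}
```

Collapsing the one-element case to a plain `TensorRef` presumably keeps the attribute layout identical to the pre-existing single-tensor path, so kernels already consuming a single mutable `IntArray` input see no change.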
diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.h b/paddle/fluid/ir_adaptor/translator/op_compat_info.h
index a56bcfa3392ae5b2f40b320038e0a3e52a09ae99..5feb2c6c76b0791c1d69af1abf78fe38051e834a 100644
--- a/paddle/fluid/ir_adaptor/translator/op_compat_info.h
+++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.h
@@ -81,6 +81,7 @@ class OpNameNormalizer {
                                   const std::string& arg_name) {
     bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
     bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);
+
     if (is_grad_op && is_grad_arg) {
       std::string target = kPhiGradSuffix;
       std::string data = kFluidVarGradSuffix;
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index cef8ab1a88bb78343fb2683c8beb8d3420d2bd69..ddd8feac29af0f69d175d8e332ec0f4eee0fe472 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -622,6 +622,7 @@ void OpTranscriber::RecordOpResultMapping(TranslationContext* param_map,
         generated_by_vector = false;
       }
     }
+
     (*param_map)[arg_name] = VariableDefiningInfo(
         value, generated_by_vector, generated_by_vector ? idx_in_vector : -1);
     idx_in_vector++;
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index d670039c7a9142471d4b088c7231384e109f777c..e964afc0405b81a88f6a4d0e0a955a98e942683f 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -2530,6 +2530,7 @@
   int_array:
     sections :
       data_type : int
+      tensor_name : AxesTensor
 
 - op : sqrt
   backward : sqrt_grad, sqrt_double_grad (sqrt_grad_grad)
diff --git a/paddle/phi/common/int_array.cc b/paddle/phi/common/int_array.cc
index abbe205bc915c54b05d006a82c5d5145a723aa81..a39638aaa6bd4ec85e578516bb1c76f94e838e60 100644
--- a/paddle/phi/common/int_array.cc
+++ b/paddle/phi/common/int_array.cc
@@ -43,6 +43,56 @@ IntArrayBase<phi::DenseTensor>::IntArrayBase(
   }
 }
 
+template <>
+IntArrayBase<phi::DenseTensor>::IntArrayBase(
+    const std::vector<phi::TensorRef>& tensor_ref_list) {
+  is_from_tensor_ = true;
+  for (size_t i = 0; i < tensor_ref_list.size(); ++i) {
+    DataType data_type = tensor_ref_list[i].Get()->dtype();
+    switch (data_type) {
+      case DataType::INT32:
+        if (tensor_ref_list[i].Get()->place().GetType() ==
+            AllocationType::CPU) {
+          array_.push_back(*tensor_ref_list[i].Get()->template data<int32_t>());
+        } else {
+          phi::DenseTensor tensor_tmp;
+          phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+          auto dev_ctx = pool.Get(tensor_ref_list[i].Get()->place());
+          phi::Copy(*dev_ctx,
+                    *(tensor_ref_list[i].Get()),
+                    CPUPlace(),
+                    true,
+                    &tensor_tmp);
+          array_.push_back(*tensor_tmp.template data<int32_t>());
+        }
+        break;
+      case DataType::INT64:
+        if (tensor_ref_list[i].Get()->place().GetType() ==
+            AllocationType::CPU) {
+          array_.push_back(*tensor_ref_list[i].Get()->template data<int64_t>());
+        } else {
+          phi::DenseTensor tensor_tmp;
+          phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+          auto dev_ctx = pool.Get(tensor_ref_list[i].Get()->place());
+          phi::Copy(*dev_ctx,
+                    *(tensor_ref_list[i].Get()),
+                    CPUPlace(),
+                    true,
+                    &tensor_tmp);
+          array_.push_back(*tensor_tmp.template data<int64_t>());
+        }
+        break;
+      default:
+        PD_THROW(
+            "Data type error. Currently, The data type of IntArrayBase "
+            "only supports Tensor with int32 and int64, "
+            "but now received `",
+            data_type,
+            "`.");
+    }
+  }
+}
+
 template <>
 IntArrayBase<phi::DenseTensor>::IntArrayBase(
     const std::vector<phi::DenseTensor>& tensor_list) {
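Note on the new `IntArrayBase` constructor above: each entry of `tensor_ref_list` is expected to hold a single integer value; CPU tensors are read in place, while device tensors are first copied back with a blocking `phi::Copy`. Below is a minimal host-only sketch of the dtype dispatch; `HostScalarTensor` is a hypothetical stand-in, and the device-to-CPU copy branch is elided.

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

enum class DataType { INT32, INT64 };

// Stand-in for a single-element tensor already on the host; the real code
// additionally copies device tensors to CPU before reading.
struct HostScalarTensor {
  DataType dtype;
  int32_t i32;
  int64_t i64;
};

// Mirrors the dtype switch in the new constructor: each element tensor
// contributes exactly one integer to the resulting array.
std::vector<int64_t> ToIntArray(const std::vector<HostScalarTensor>& refs) {
  std::vector<int64_t> array;
  array.reserve(refs.size());
  for (const auto& t : refs) {
    switch (t.dtype) {
      case DataType::INT32: array.push_back(t.i32); break;
      case DataType::INT64: array.push_back(t.i64); break;
      default: throw std::runtime_error("only int32/int64 supported");
    }
  }
  return array;
}

int main() {
  std::vector<HostScalarTensor> refs = {{DataType::INT32, 4, 0},
                                        {DataType::INT64, 0, 9}};
  for (int64_t v : ToIntArray(refs)) std::printf("%lld\n", (long long)v);
}
```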
diff --git a/paddle/phi/common/int_array.h b/paddle/phi/common/int_array.h
index 50f88ddb6fbe5267dfdc66c5d740ce54a8d6e741..0c4b3d4c8ca5b790a380df9da88d58eef69a2a95 100644
--- a/paddle/phi/common/int_array.h
+++ b/paddle/phi/common/int_array.h
@@ -18,6 +18,7 @@ limitations under the License. */
 
 #include "paddle/phi/api/ext/exception.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/tensor_ref.h"
 
 namespace phi {
 class DDim;
@@ -62,6 +63,8 @@ class IntArrayBase {
   // The Tensor in vec must have only one element
   IntArrayBase(const std::vector<T>& tensor_list);  // NOLINT
 
+  explicit IntArrayBase(const std::vector<phi::TensorRef>& tensor_ref_list);
+
   template <typename OtherT>
   IntArrayBase(const IntArrayBase<OtherT>& other) : array_(other.GetData()) {}
 
@@ -87,6 +90,7 @@ class IntArrayBase {
   void AssignDataFromTensor(const T& tensor) {
     size_t n = tensor.numel();
+    array_.reserve(n);
     switch (tensor.dtype()) {
       case DataType::INT32:
diff --git a/paddle/phi/common/tensor_ref.h b/paddle/phi/common/tensor_ref.h
index 471b3cd1482a80b8023ade8307bd79634b0e4fc6..b2ff0665e1b5932e0f1f196e20c458d51c7b99aa 100644
--- a/paddle/phi/common/tensor_ref.h
+++ b/paddle/phi/common/tensor_ref.h
@@ -21,10 +21,10 @@
 
 #include "paddle/phi/api/ext/exception.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 
 namespace phi {
-class TensorBase;
 
 // In static model pre analysis, we can't get the data from tensor
 class TensorRef {
diff --git a/paddle/phi/core/attribute.h b/paddle/phi/core/attribute.h
index b7c57eda016ab7afbf95a43dc9239d8de4a6647e..40c66a669c9e89c451654a7c7ec6af40366a23a3 100644
--- a/paddle/phi/core/attribute.h
+++ b/paddle/phi/core/attribute.h
@@ -48,7 +48,8 @@ using Attribute = paddle::variant<bool,
                                   DataType,
                                   DataLayout,
                                   Place,
-                                  TensorRef>;
+                                  TensorRef,
+                                  std::vector<TensorRef>>;
 
 using AttributeMap = paddle::flat_hash_map<std::string, Attribute>;
diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc
index 6adda5748c994824ea2075d1491316210a7287db..bab872f63fae456801ddb78261cb156110999a70 100644
--- a/paddle/phi/core/infermeta_utils.cc
+++ b/paddle/phi/core/infermeta_utils.cc
@@ -159,6 +159,8 @@ template const DataType& InferMetaContext::AttrAt(size_t idx) const;
 template const DataLayout& InferMetaContext::AttrAt(size_t idx) const;
 template const Place& InferMetaContext::AttrAt(size_t idx) const;
 template const TensorRef& InferMetaContext::AttrAt(size_t idx) const;
+template const std::vector<TensorRef>& InferMetaContext::AttrAt(
+    size_t idx) const;
 
 MetaFnFactory& MetaFnFactory::Instance() {
   static MetaFnFactory g_meta_fn_map;
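Note on the `AttrAt` instantiation added above (and the matching line in `kernel_context.cc` further down): `AttrAt` is a member template whose definition lives in the `.cc` file, so every attribute type callers extract needs an explicit instantiation there; `std::vector<TensorRef>` is a new variant alternative, hence the new line. A self-contained sketch of the pattern, with stand-in `Context`/`Attribute` types rather than phi's:

```cpp
#include <cstddef>
#include <cstdio>
#include <variant>
#include <vector>

struct TensorRef { const void* p = nullptr; };
using Attribute = std::variant<int, TensorRef, std::vector<TensorRef>>;

class Context {
 public:
  explicit Context(std::vector<Attribute> attrs) : attrs_(std::move(attrs)) {}
  template <typename AttrType>
  const AttrType& AttrAt(std::size_t idx) const;  // defined below, as in a .cc
 private:
  std::vector<Attribute> attrs_;
};

template <typename AttrType>
const AttrType& Context::AttrAt(std::size_t idx) const {
  return std::get<AttrType>(attrs_.at(idx));
}

// Explicit instantiations: in Paddle these live in infermeta_utils.cc /
// kernel_context.cc so that other translation units can link against them.
// The template argument is deduced from the declared return type.
template const int& Context::AttrAt(std::size_t idx) const;
template const std::vector<TensorRef>& Context::AttrAt(std::size_t idx) const;

int main() {
  Context ctx({Attribute{3}, Attribute{std::vector<TensorRef>{TensorRef{}}}});
  std::printf("%d %zu\n", ctx.AttrAt<int>(0),
              ctx.AttrAt<std::vector<TensorRef>>(1).size());
}
```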
diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h
index ff4f9c93ca5075ce0a72919be4ec9a8cd88dcdee..bc6ef528d3ba93095d629944b7668005d3cb4863 100644
--- a/paddle/phi/core/infermeta_utils.h
+++ b/paddle/phi/core/infermeta_utils.h
@@ -118,8 +118,7 @@ class InferMetaContext {
     }                                                                        \
   }
 
-#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR_INTARRAY( \
-    attr_type)                                                          \
+#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR(attr_type)    \
   template <typename... Tail>                                               \
   struct InferMetaFnCallHelper<const attr_type&, Tail...> {                 \
     template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
@@ -141,6 +140,32 @@ class InferMetaContext {
     }                                                                        \
   }
 
+#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_INTARRAY(attr_type)  \
+  template <typename... Tail>                                               \
+  struct InferMetaFnCallHelper<const attr_type&, Tail...> {                 \
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
+    static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {       \
+      static_assert(out_idx == 0,                                           \
+                    "InferMeta's Attributes should appear before Outputs."); \
+      const Attribute& t = ctx->AttrAt(attr_idx);                           \
+      static Attribute cmp_t = phi::TensorRef(nullptr);                     \
+      static Attribute vec_ref =                                            \
+          std::vector<phi::TensorRef>({phi::TensorRef(nullptr)});           \
+      attr_type attr1;                                                      \
+      if (cmp_t.index() == t.index()) {                                     \
+        attr1 = attr_type((*paddle::get<phi::TensorRef>(t).Get()));         \
+      } else if (vec_ref.index() == t.index()) {                            \
+        attr1 = attr_type(paddle::get<std::vector<phi::TensorRef>>(t));     \
+      } else {                                                              \
+        attr1 = paddle::get<attr_type>(t);                                  \
+      }                                                                     \
+      InferMetaFnCallHelper<                                                \
+          Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx,       \
+                                                                 pargs..., \
+                                                                 attr1);   \
+    }                                                                       \
+  }
+
 template <typename T>
 struct InferMetaTypeTag {};
 
@@ -222,8 +247,8 @@ struct InferMetaFnImpl<Return (*)(Args...), infermeta_fn> {
   PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
   PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
   PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
-  PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR_INTARRAY(Scalar);
-  PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR_INTARRAY(IntArray);
+  PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_SCALAR(Scalar);
+  PD_SPECIALIZE_InferMetaFnCallHelper_FOR_TENSOR_INTARRAY(IntArray);
   PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
       std::vector<bool>);
   PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
diff --git a/paddle/phi/core/kernel_context.cc b/paddle/phi/core/kernel_context.cc
index 35fa5b8312df2d878e0a8d023e7ba5defadba0d1..2d4863957b4252aeb3f6c7627d7bb97eb1c9cfc5 100644
--- a/paddle/phi/core/kernel_context.cc
+++ b/paddle/phi/core/kernel_context.cc
@@ -147,5 +147,6 @@ template const DataType& KernelContext::AttrAt(size_t idx) const;
 template const DataLayout& KernelContext::AttrAt(size_t idx) const;
 template const Place& KernelContext::AttrAt(size_t idx) const;
 template const TensorRef& KernelContext::AttrAt(size_t idx) const;
+template const std::vector<TensorRef>& KernelContext::AttrAt(size_t idx) const;
 
 }  // namespace phi
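Note on the new `*_FOR_TENSOR_INTARRAY` macros (the InferMeta one above, the kernel one below): they pick the `IntArray` constructor by comparing the runtime attribute's variant index against static sentinel `Attribute`s, one holding `TensorRef(nullptr)` and one holding a one-element `std::vector<TensorRef>`. A minimal standalone model of that three-way dispatch using `std::variant`, with illustrative names only:

```cpp
#include <cstdio>
#include <variant>
#include <vector>

struct TensorRef { int value = 0; };
using Attribute = std::variant<int, TensorRef, std::vector<TensorRef>>;

struct IntArray {  // stand-in for phi::IntArray
  std::vector<int> data;
  explicit IntArray(int v) : data{v} {}
  explicit IntArray(const TensorRef& r) : data{r.value} {}
  explicit IntArray(const std::vector<TensorRef>& rs) {
    for (const auto& r : rs) data.push_back(r.value);
  }
};

// Mirrors the macro body: static sentinels capture the variant index of each
// alternative of interest, and the runtime index() selects the constructor.
IntArray ToIntArray(const Attribute& t) {
  static const Attribute cmp_t = TensorRef{};
  static const Attribute vec_ref = std::vector<TensorRef>{TensorRef{}};
  if (t.index() == cmp_t.index()) return IntArray(std::get<TensorRef>(t));
  if (t.index() == vec_ref.index())
    return IntArray(std::get<std::vector<TensorRef>>(t));
  return IntArray(std::get<int>(t));
}

int main() {
  Attribute attrs[] = {Attribute{7}, Attribute{TensorRef{8}},
                       Attribute{std::vector<TensorRef>{{1}, {2}, {3}}}};
  for (const Attribute& a : attrs)
    std::printf("size=%zu\n", ToIntArray(a).data.size());
}
```

Comparing `index()` against a sentinel keeps the fallback branch (`paddle::get<attr_type>(t)`) working even when the conversion target, `IntArray` itself, is another alternative of the same variant.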
diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h
index f4dc4636bdde3f6d0c06b6a008fb3c944b7f7c42..38949f2816d969520943a28e4eaf6d581de91fbb 100644
--- a/paddle/phi/core/kernel_utils.h
+++ b/paddle/phi/core/kernel_utils.h
@@ -221,31 +221,59 @@ namespace phi {
     }                                                                        \
   }
 
-#define PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR_INTARRAY(attr_type) \
-  template <typename... Tail>                                                \
-  struct KernelCallHelper<const attr_type&, Tail...> {                       \
-    template <int dev_ctx_idx,                                               \
-              int in_idx,                                                    \
-              int attr_idx,                                                  \
-              int out_idx,                                                   \
-              typename... PreviousArgs>                                      \
-    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {        \
-      static_assert(out_idx == 0,                                            \
-                    "Kernel's Attributes should appear before Outputs.");    \
-      const Attribute& t = ctx->AttrAt(attr_idx);                            \
-      static Attribute cmp_t = phi::TensorRef(nullptr);                      \
-      attr_type attr1;                                                       \
-      if (cmp_t.index() == t.index()) {                                      \
-        attr1 = attr_type(*paddle::get<phi::TensorRef>(t).Get());            \
-      } else {                                                               \
-        attr1 = paddle::get<attr_type>(t);                                   \
-      }                                                                      \
-      KernelCallHelper<Tail...>::                                            \
-          template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>(      \
-              ctx, pargs..., attr1);                                         \
-    }                                                                        \
+#define PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR(attr_type)        \
+  template <typename... Tail>                                              \
+  struct KernelCallHelper<const attr_type&, Tail...> {                     \
+    template <int dev_ctx_idx,                                             \
+              int in_idx,                                                  \
+              int attr_idx,                                                \
+              int out_idx,                                                 \
+              typename... PreviousArgs>                                    \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
+      static_assert(out_idx == 0,                                          \
+                    "Kernel's Attributes should appear before Outputs.");  \
+      const Attribute& t = ctx->AttrAt(attr_idx);                          \
+      static Attribute cmp_t = phi::TensorRef(nullptr);                    \
+      attr_type attr1;                                                     \
+      if (cmp_t.index() == t.index()) {                                    \
+        attr1 = attr_type(*paddle::get<phi::TensorRef>(t).Get());          \
+      } else {                                                             \
+        attr1 = paddle::get<attr_type>(t);                                 \
+      }                                                                    \
+      KernelCallHelper<Tail...>::                                          \
+          template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>(    \
+              ctx, pargs..., attr1);                                       \
+    }                                                                      \
   }
 
+#define PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_INTARRAY(attr_type)      \
+  template <typename... Tail>                                              \
+  struct KernelCallHelper<const attr_type&, Tail...> {                     \
+    template <int dev_ctx_idx,                                             \
+              int in_idx,                                                  \
+              int attr_idx,                                                \
+              int out_idx,                                                 \
+              typename... PreviousArgs>                                    \
+    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {      \
+      static_assert(out_idx == 0,                                          \
+                    "Kernel's Attributes should appear before Outputs.");  \
+      const Attribute& t = ctx->AttrAt(attr_idx);                          \
+      static Attribute cmp_t = phi::TensorRef(nullptr);                    \
+      static Attribute vec_ref =                                           \
+          std::vector<phi::TensorRef>({phi::TensorRef(nullptr)});          \
+      attr_type attr1;                                                     \
+      if (cmp_t.index() == t.index()) {                                    \
+        attr1 = attr_type(*paddle::get<phi::TensorRef>(t).Get());          \
+      } else if (vec_ref.index() == t.index()) {                           \
+        attr1 = attr_type(paddle::get<std::vector<phi::TensorRef>>(t));    \
+      } else {                                                             \
+        attr1 = paddle::get<attr_type>(t);                                 \
+      }                                                                    \
+      KernelCallHelper<Tail...>::                                          \
+          template Compute<dev_ctx_idx, in_idx, attr_idx + 1, out_idx>(    \
+              ctx, pargs..., attr1);                                       \
+    }                                                                      \
+  }
 
 template <typename T>
 struct TypeTag {};
 
@@ -325,8 +353,8 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
   PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place);
   PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
-  PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR_INTARRAY(Scalar);
-  PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR_INTARRAY(IntArray);
+  PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR(Scalar);
+  PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_INTARRAY(IntArray);
   PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<bool>);
   PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
   PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int64_t>);
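For orientation, the macros above plug into a recursive `KernelCallHelper` chain: each specialization consumes one kernel parameter, converts the attribute at `attr_idx`, and recurses with the value appended to the argument pack until the terminal case invokes the kernel. A toy standalone model of that recursion follows; it is illustrative scaffolding only, not phi's real machinery.

```cpp
#include <cstddef>
#include <cstdio>
#include <utility>
#include <variant>
#include <vector>

// Attribute storage stand-in: a list of variants indexed by attr_idx.
struct Context {
  std::vector<std::variant<int, float>> attrs;
  template <typename T>
  const T& AttrAt(std::size_t i) const { return std::get<T>(attrs.at(i)); }
};

template <typename Fn, typename... Tail>
struct Helper;

// One step per remaining kernel parameter: convert, append, recurse.
template <typename Fn, typename Head, typename... Tail>
struct Helper<Fn, Head, Tail...> {
  template <std::size_t attr_idx, typename... Prev>
  static void Compute(Fn fn, const Context& ctx, Prev&&... prev) {
    const Head& attr = ctx.AttrAt<Head>(attr_idx);
    Helper<Fn, Tail...>::template Compute<attr_idx + 1>(
        fn, ctx, std::forward<Prev>(prev)..., attr);
  }
};

// Terminal case: all attributes consumed, call the kernel.
template <typename Fn>
struct Helper<Fn> {
  template <std::size_t attr_idx, typename... Prev>
  static void Compute(Fn fn, const Context& ctx, Prev&&... prev) {
    (void)ctx;
    fn(std::forward<Prev>(prev)...);
  }
};

void MyKernel(int n, float eps) { std::printf("n=%d eps=%f\n", n, eps); }

int main() {
  Context ctx{{std::variant<int, float>{3}, std::variant<int, float>{0.5f}}};
  Helper<decltype(&MyKernel), int, float>::Compute<0>(&MyKernel, ctx);
}
```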