diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h b/paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h
index cb0c59671827cff09cb47dfbf11e83d42b55ebe8..37063cc8217c46d639c27bf86832630a15747bda 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h
+++ b/paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h
@@ -145,6 +145,44 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
   }
 }
 
+inline DataType VarTypeToDataType(
+    ::paddle::framework::proto::VarType_Type var_type) {
+  switch (var_type) {
+    case paddle::framework::proto::VarType_Type::VarType_Type_BOOL:
+      return DataType::BOOL;
+    case paddle::framework::proto::VarType_Type::VarType_Type_INT16:
+      return DataType::INT16;
+    case paddle::framework::proto::VarType_Type::VarType_Type_INT32:
+      return DataType::INT32;
+    case paddle::framework::proto::VarType_Type::VarType_Type_INT64:
+      return DataType::INT64;
+    case paddle::framework::proto::VarType_Type::VarType_Type_FP16:
+      return DataType::FLOAT16;
+    case paddle::framework::proto::VarType_Type::VarType_Type_FP32:
+      return DataType::FLOAT32;
+    case paddle::framework::proto::VarType_Type::VarType_Type_FP64:
+      return DataType::FLOAT64;
+    case paddle::framework::proto::VarType_Type::VarType_Type_SIZE_T:
+      return DataType::UINT64;
+    case paddle::framework::proto::VarType_Type::VarType_Type_UINT8:
+      return DataType::UINT8;
+    case paddle::framework::proto::VarType_Type::VarType_Type_INT8:
+      return DataType::INT8;
+    case paddle::framework::proto::VarType_Type::VarType_Type_BF16:
+      return DataType::BFLOAT16;
+    case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX64:
+      return DataType::COMPLEX64;
+    case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX128:
+      return DataType::COMPLEX128;
+    case paddle::framework::proto::VarType_Type::VarType_Type_PSTRING:
+      return DataType::PSTRING;
+    default:
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Unsupported proto::VarType_Type `%s` when casting it into DataType.",
+          var_type));
+  }
+}
+
 VariantType GetAttributeData(const ir::Attribute& attr);
 
 bool IsLegacyOp(const std::string& name);
 
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 59887078d050b89a2b2039c79c5e3f8f9d1e6270..27a06849a0648de9b94fb379fa68323179ab652a 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -195,6 +195,20 @@ inline ir::Operation* InsertFullArrayOperationForAttributeInput(
   return full_int_array_op.operation();
 }
 
+inline ir::Operation* InsertStackOperationForTarget(
+    ir::IrContext* ctx,
+    TranslationContext* param_map,
+    ir::Program* program,
+    const std::vector<std::string>& args,
+    int axis = 0) {
+  auto* combine_op =
+      InsertCombineOperationForTarget(ctx, param_map, program, args);
+  ir::Builder builder(ctx, program->block());
+  dialect::StackOp stack_op =
+      builder.Build<dialect::StackOp>(combine_op->result(0), axis);
+  return stack_op.operation();
+}
+
 inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
                                         ir::Program* program,
                                         const OpDesc& op_desc,
@@ -1175,6 +1189,168 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber {
   }
 };
 
+struct FillConstant2FullTranscriber : public OpTranscriber {
+  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+    const auto& op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name());
+    if (!op_info) {
+      IR_THROW("Op fill_constant should have corresponding OpInfo pd.full");
+    }
+
+    return op_info;
+  }
+
+  std::vector<ir::OpResult> GenerateOperationInput(
+      ir::IrContext* ctx,
+      TranslationContext* param_map,
+      const OpDesc& op_desc,
+      const std::string& normalized_op_name,
+      const OpInputInfoList& input_infos,
+      ir::Program* program) override {
+    return {};
+  }
+
+  ir::AttributeMap TranslateOpAttribute(
+      ir::IrContext* ctx,
+      const std::string& normalized_op_name,
+      const OpAttributeInfoList& op_attr_infos,
+      const OpDesc& op_desc) override {
+    std::vector<int64_t> shape =
+        PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shape"));
+    float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
+    int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
+
+    auto attr_value = ir::FloatAttribute::get(ctx, value);
+
+    ir::AttributeMap attribute_map = {
+        {"shape",
+         paddle::dialect::IntArrayAttribute::get(ctx, phi::IntArray(shape))},
+        {"value", attr_value.dyn_cast<ir::Attribute>()},
+        {"dtype",
+         paddle::dialect::DataTypeAttribute::get(
+             ctx,
+             paddle::dialect::VarTypeToDataType(
+                 static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};
+    int place_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place_type"));
+    switch (place_type) {
+      case -1:
+        attribute_map["place"] =
+            paddle::dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
+        break;
+      case 0:
+        attribute_map["place"] =
+            paddle::dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
+        break;
+      case 1:
+        attribute_map["place"] =
+            paddle::dialect::PlaceAttribute::get(ctx, phi::GPUPlace());
+        break;
+      case 2:
+        attribute_map["place"] =
+            paddle::dialect::PlaceAttribute::get(ctx, phi::GPUPinnedPlace());
+        break;
+      case 3:
+        attribute_map["place"] =
+            paddle::dialect::PlaceAttribute::get(ctx, phi::XPUPlace());
+        break;
+    }
+    return attribute_map;
+  }
+};
+
+struct FillConstant2FullWithTensorTranscriber : public OpTranscriber {
+  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+    const auto& op_info = ctx->GetRegisteredOpInfo("pd.full_with_tensor");
+    if (!op_info) {
+      IR_THROW(
+          "Op fill_constant should have corresponding OpInfo "
+          "pd.full_with_tensor");
+    }
+
+    return op_info;
+  }
+
+  std::vector<ir::OpResult> GenerateOperationInput(
+      ir::IrContext* ctx,
+      TranslationContext* param_map,
+      const OpDesc& op_desc,
+      const std::string& normalized_op_name,
+      const OpInputInfoList& input_infos,
+      ir::Program* program) override {
+    std::vector<ir::OpResult> op_inputs;
+    if (op_desc.HasInput("ShapeTensor", true) &&
+        op_desc.Input("ShapeTensor", true).size() > 0) {
+      auto shape_tensor_vars = op_desc.Input("ShapeTensor", true);
+      auto defining_info = (*param_map)[shape_tensor_vars[0]];
+      op_inputs.push_back(defining_info.value);
+    } else if (op_desc.HasInput("ShapeTensorList", true) &&
+               op_desc.Input("ShapeTensorList", true).size() > 0) {
+      auto shape_tensor_list_vars = op_desc.Input("ShapeTensorList", true);
+      auto defining_op = InsertStackOperationForTarget(
+          ctx, param_map, program, shape_tensor_list_vars);
+      op_inputs.push_back(defining_op->result(0));
+    } else {
+      std::vector<int64_t> shape =
+          PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shape"));
+      ir::Attribute new_attr =
+          paddle::dialect::IntArrayAttribute::get(ctx, phi::IntArray(shape));
+      auto defining_op =
+          InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
+      op_inputs.push_back(defining_op->result(0));
+    }
+
+    if (op_desc.HasInput("ValueTensor", true) &&
+        op_desc.Input("ValueTensor", true).size() > 0) {
+      auto value_tensor_vars = op_desc.Input("ValueTensor", true);
+      auto defining_info = (*param_map)[value_tensor_vars[0]];
+      op_inputs.push_back(defining_info.value);
+    } else {
+      float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
+      ir::Attribute new_attr = ir::FloatAttribute::get(ctx, value);
+      auto defining_op =
+          InsertFullOperationForAttributeInput(ctx, program, new_attr);
+      op_inputs.push_back(defining_op->result(0));
+    }
+    return op_inputs;
+  }
+
+  ir::AttributeMap TranslateOpAttribute(
+      ir::IrContext* ctx,
+      const std::string& normalized_op_name,
+      const OpAttributeInfoList& op_attr_infos,
+      const OpDesc& op_desc) override {
+    int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
+
+    ir::AttributeMap attribute_map = {
+        {"dtype",
+         paddle::dialect::DataTypeAttribute::get(
+             ctx,
+             paddle::dialect::VarTypeToDataType(
+                 static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};
+    return attribute_map;
+  }
+};
+
+struct FillConstantTranscriber : public OpTranscriber {
+  ir::Operation* operator()(ir::IrContext* ctx,
+                            TranslationContext* param_map,
+                            const OpDesc& op_desc,
+                            ir::Program* program) override {
+    bool has_mutable_attribute = op_desc.HasInput("ShapeTensor", true) &&
+                                 op_desc.Input("ShapeTensor", true).size() > 0;
+    has_mutable_attribute |= op_desc.HasInput("ShapeTensorList", true) &&
+                             op_desc.Input("ShapeTensorList", true).size() > 0;
+    has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) &&
+                             op_desc.Input("ValueTensor", true).size() > 0;
+
+    if (!has_mutable_attribute) {
+      return FillConstant2FullTranscriber()(ctx, param_map, op_desc, program);
+    } else {
+      return FillConstant2FullWithTensorTranscriber()(
+          ctx, param_map, op_desc, program);
+    }
+  }
+};
+
 ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx,
                                           TranslationContext* param_map,
                                           const OpDesc& op_desc,
@@ -1474,6 +1650,7 @@ OpTranslator::OpTranslator() {
   special_handlers["split"] = SplitOpTranscriber();
   special_handlers["sum"] = AddNOpTranscriber();
   special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
+  special_handlers["fill_constant"] = FillConstantTranscriber();
 
   // special handler for elementwise ops with axis != -1
   // note(lyk): maybe we should do this by a pass, which seems more reasonable
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 4d8fd7a6b8f78cff7bf590a9fd9ee0f3078fe65b..74f040d9bbbd4771ca5990bf0a9983e0c7a4748a 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -404,6 +404,16 @@
   data_transform :
     skip_transform : x
 
+- op : full_with_tensor
+  args : (Tensor shape, Tensor value, DataType dtype=DataType::FLOAT32)
+  output: Tensor(out)
+  infer_meta :
+    func : FullWithTensorInferMeta
+    param : [shape, dtype]
+  kernel :
+    func : full_with_tensor
+    data_type : dtype
+
 - op : fused_adam_
   args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
   output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index fe7a596f69a1c3e9bba6b84d34cf36b31dc8ed12..3e987b63f8af128d8c514529e3efceb25d5d6512 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -19,12 +19,14 @@ limitations under the License. */
*/ #include "glog/logging.h" #include "paddle/phi/backends/device_memory_aligment.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/nullary.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/concat_funcs.h" @@ -4154,5 +4156,12 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, } } +void FullWithTensorInferMeta(const MetaTensor& shape, + DataType dtype, + MetaTensor* out) { + out->set_dims(make_ddim({-1})); + out->set_dtype(dtype); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 2a0942262f8b40ff11e0cac09466e318ebea336c..c427f7e8fcc2985c40caf6eb94add4ef5c36320c 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/meta_tensor.h" + namespace phi { // Common InferMeta Functions for multiary operators, The format like: @@ -819,4 +820,8 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, MetaTensor* cache_kv_out, MetaTensor* beam_cache_offset_out); +void FullWithTensorInferMeta(const MetaTensor& shape, + DataType dtype, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index 20d50d88f68a3c556b7e95c8e56753359edfe01b..54a4b781f76c7bfa176e9cad389cc8bf8a24dd91 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h" namespace phi { @@ -130,3 +131,20 @@ PD_REGISTER_KERNEL(full_like, PD_REGISTER_KERNEL( full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {} + +PD_REGISTER_KERNEL(full_with_tensor, + CPU, + ALL_LAYOUT, + phi::FullWithTensorKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h index a5c5c705c9edf3fb5d32395307efecbdbad70573..cef58433e9e04f7ae2e2b376c3cc889f6380f87c 100644 --- a/paddle/phi/kernels/full_kernel.h +++ b/paddle/phi/kernels/full_kernel.h @@ -31,6 +31,13 @@ void FullKernel(const Context& dev_ctx, DataType dtype, DenseTensor* out); +template +void FullWithTensorKernel(const Context& dev_ctx, + const DenseTensor& shape, + const DenseTensor& value, + DataType dtype, + DenseTensor* out); + template void FullLikeKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index c3d331c6dcce9234bd4ffc08bae779982e8c549f..1babbd12f831e2db30e80e8fe86048a547fa34f8 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -17,6 +17,8 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h" + namespace phi { template @@ -144,3 +146,20 @@ PD_REGISTER_KERNEL(full_like, phi::dtype::complex) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } + +PD_REGISTER_KERNEL(full_with_tensor, + GPU, + ALL_LAYOUT, + phi::FullWithTensorKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h b/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..a78af4f98c2b5c8fd3c126911041713d02dc298f --- /dev/null +++ b/paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/full_kernel.h" + +namespace phi { + +template +void FullWithTensorKernel(const Context& dev_ctx, + const DenseTensor& shape, + const DenseTensor& value, + DataType dtype, + DenseTensor* out) { + auto shape_tmp = IntArray(shape); + out->Resize(phi::make_ddim(shape_tmp.GetData())); + FullKernel(dev_ctx, shape_tmp, Scalar(value), dtype, out); +} +} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/full_kernel.cc b/paddle/phi/kernels/selected_rows/full_kernel.cc index e04139448dddc2f942886e2abd98b9c8c4431fd9..ead1fc1626f8b053895923ee2410a5c3692f38fa 100644 --- a/paddle/phi/kernels/selected_rows/full_kernel.cc +++ b/paddle/phi/kernels/selected_rows/full_kernel.cc @@ -35,6 +35,16 @@ void FullKernel(const Context& dev_ctx, phi::FullKernel(dev_ctx, shape, val, dtype, out->mutable_value()); } +template +void FullWithTensorKernel(const Context& dev_ctx, + const DenseTensor& shape, + const DenseTensor& value, + DataType dtype, + SelectedRows* out) { + phi::FullWithTensorKernel( + dev_ctx, shape, value, dtype, out->mutable_value()); +} + } // namespace sr } // namespace phi @@ -84,3 +94,50 @@ PD_REGISTER_KERNEL(full_sr, bool, phi::dtype::float16) {} #endif + +PD_REGISTER_KERNEL(full_with_tensor_sr, + CPU, + ALL_LAYOUT, + phi::sr::FullWithTensorKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(full_with_tensor_sr, + GPU, + ALL_LAYOUT, + phi::sr::FullWithTensorKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} +#endif + +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(full_with_tensor_sr, + XPU, + ALL_LAYOUT, + phi::sr::FullWithTensorKernel, + float, + uint8_t, + int16_t, + int, 
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/selected_rows/full_kernel.h b/paddle/phi/kernels/selected_rows/full_kernel.h
index d4b1859fdfcfb6ec6bdfccfcc9dfc28741b12bc5..07cfe7fd6378b14e9639e7cbc20c1f98aa46fa38 100644
--- a/paddle/phi/kernels/selected_rows/full_kernel.h
+++ b/paddle/phi/kernels/selected_rows/full_kernel.h
@@ -28,5 +28,11 @@ void FullKernel(const Context& dev_ctx,
                 DataType dtype,
                 SelectedRows* out);
 
+template <typename T, typename Context>
+void FullWithTensorKernel(const Context& dev_ctx,
+                          const DenseTensor& shape,
+                          const DenseTensor& value,
+                          DataType dtype,
+                          SelectedRows* out);
 }  // namespace sr
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc
index f1754b0631ad4f87ef729857bfe106d6ee6bfdfe..4d28fd74107672d36907a8c7ea28adac462ee04b 100644
--- a/paddle/phi/kernels/xpu/full_kernel.cc
+++ b/paddle/phi/kernels/xpu/full_kernel.cc
@@ -23,6 +23,7 @@
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h"
 
 namespace phi {
 
@@ -152,3 +153,16 @@ PD_REGISTER_KERNEL(full_batch_size_like,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
+
+PD_REGISTER_KERNEL(full_with_tensor,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FullWithTensorKernel,
+                   float,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {}
diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list
index 7bd1c73c485912bd4a2437c6be3c735d399024e9..57d6154bfc7d561d210046b100c808fdc3491dfa 100644
--- a/test/white_list/new_ir_op_test_white_list
+++ b/test/white_list/new_ir_op_test_white_list
@@ -194,3 +194,4 @@ test_warprnnt_op
 test_where_op
 test_yolo_box_op
 test_yolov3_loss_op
+test_fill_constant_op