Unverified commit e914f7fc, authored by: W wanghuancoder, committed by: GitHub

[IR] Ir fill constant (#56520)

* support ir fill constant
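This change lowers the legacy fill_constant op into the new IR. When shape and value are plain attributes, the op is rewritten as pd.full; when any of ShapeTensor, ShapeTensorList, or ValueTensor is present, it is rewritten as a new pd.full_with_tensor op, for which this commit also adds the op definition, an infer-meta function, and dense and selected-rows kernels for CPU, GPU, and XPU.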
Parent ae84c603
@@ -145,6 +145,44 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
}
}
inline DataType VarTypeToDataType(
::paddle::framework::proto::VarType_Type var_type) {
switch (var_type) {
case paddle::framework::proto::VarType_Type::VarType_Type_BOOL:
return DataType::BOOL;
case paddle::framework::proto::VarType_Type::VarType_Type_INT16:
return DataType::INT16;
case paddle::framework::proto::VarType_Type::VarType_Type_INT32:
return DataType::INT32;
case paddle::framework::proto::VarType_Type::VarType_Type_INT64:
return DataType::INT64;
case paddle::framework::proto::VarType_Type::VarType_Type_FP16:
return DataType::FLOAT16;
case paddle::framework::proto::VarType_Type::VarType_Type_FP32:
return DataType::FLOAT32;
case paddle::framework::proto::VarType_Type::VarType_Type_FP64:
return DataType::FLOAT64;
case paddle::framework::proto::VarType_Type::VarType_Type_SIZE_T:
return DataType::UINT64;
case paddle::framework::proto::VarType_Type::VarType_Type_UINT8:
return DataType::UINT8;
case paddle::framework::proto::VarType_Type::VarType_Type_INT8:
return DataType::INT8;
case paddle::framework::proto::VarType_Type::VarType_Type_BF16:
return DataType::BFLOAT16;
case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX64:
return DataType::COMPLEX64;
case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX128:
return DataType::COMPLEX128;
case paddle::framework::proto::VarType_Type::VarType_Type_PSTRING:
return DataType::PSTRING;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported proto::VarType_Type `%s` when casting it into DataType.",
var_type));
}
}
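// A minimal usage sketch (not part of this change): the translator calls this
// helper when lowering a legacy integer `dtype` attribute, e.g.
//   auto dt = VarTypeToDataType(
//       paddle::framework::proto::VarType_Type::VarType_Type_FP32);
//   // dt == DataType::FLOAT32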
VariantType GetAttributeData(const ir::Attribute& attr);
bool IsLegacyOp(const std::string& name);
......
@@ -195,6 +195,20 @@ inline ir::Operation* InsertFullArrayOperationForAttributeInput(
return full_int_array_op.operation();
}
inline ir::Operation* InsertStackOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const std::vector<std::string>& args,
int axis = 0) {
auto* combine_op =
InsertCombineOperationForTarget(ctx, param_map, program, args);
ir::Builder builder(ctx, program->block());
dialect::StackOp stack_op =
builder.Build<dialect::StackOp>(combine_op->result(0), axis);
return stack_op.operation();
}
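// Usage sketch (variable names hypothetical): given a legacy ShapeTensorList
// input, the entries are first combined into a vector of tensors and then
// stacked along axis 0 to form a single 1-D shape tensor, e.g.
//   auto* stack_op = InsertStackOperationForTarget(
//       ctx, param_map, program, op_desc.Input("ShapeTensorList", true));
//   ir::OpResult shape = stack_op->result(0);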
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
@@ -1175,6 +1189,168 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber {
}
};
struct FillConstant2FullTranscriber : public OpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
const auto& op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name());
if (!op_info) {
IR_THROW("Op fill_constant should have corresponding OpInfo pd.full");
}
return op_info;
}
std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) override {
return {};
}
ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) override {
std::vector<int64_t> shape =
PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shape"));
float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
auto attr_value = ir::FloatAttribute::get(ctx, value);
ir::AttributeMap attribute_map = {
{"shape",
paddle::dialect::IntArrayAttribute::get(ctx, phi::IntArray(shape))},
{"value", attr_value.dyn_cast<paddle::dialect::ScalarAttribute>()},
{"dtype",
paddle::dialect::DataTypeAttribute::get(
ctx,
paddle::dialect::VarTypeToDataType(
static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};
int place_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place_type"));
switch (place_type) {
  case -1:  // place not determined at translation time; default to CPU
  case 0:
    attribute_map["place"] =
        paddle::dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
    break;
case 1:
attribute_map["place"] =
paddle::dialect::PlaceAttribute::get(ctx, phi::GPUPlace());
break;
case 2:
attribute_map["place"] =
paddle::dialect::PlaceAttribute::get(ctx, phi::GPUPinnedPlace());
break;
case 3:
attribute_map["place"] =
paddle::dialect::PlaceAttribute::get(ctx, phi::XPUPlace());
break;
}
return attribute_map;
}
};
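// Sketch of the attribute-only lowering this transcriber performs, assuming a
// legacy op such as
//   fill_constant(shape=[2, 3], value=1.0, dtype=FP32, place_type=0)
// which becomes, roughly,
//   %out = pd.full(shape=[2, 3], value=1.0, dtype=float32, place=CPUPlace)
// with no operands, since every input is carried as an attribute.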
struct FillConstant2FullWithTensorTranscriber : public OpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
const auto& op_info = ctx->GetRegisteredOpInfo("pd.full_with_tensor");
if (!op_info) {
IR_THROW(
    "Op fill_constant should have a corresponding OpInfo "
    "pd.full_with_tensor");
}
return op_info;
}
std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) override {
std::vector<ir::OpResult> op_inputs;
if (op_desc.HasInput("ShapeTensor", true) &&
op_desc.Input("ShapeTensor", true).size() > 0) {
auto shape_tensor_vars = op_desc.Input("ShapeTensor", true);
auto defining_info = (*param_map)[shape_tensor_vars[0]];
op_inputs.push_back(defining_info.value);
} else if (op_desc.HasInput("ShapeTensorList", true) &&
op_desc.Input("ShapeTensorList", true).size() > 0) {
auto shape_tensor_list_vars = op_desc.Input("ShapeTensorList", true);
auto defining_op = InsertStackOperationForTarget(
ctx, param_map, program, shape_tensor_list_vars);
op_inputs.push_back(defining_op->result(0));
} else {
std::vector<int64_t> shape =
PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shape"));
ir::Attribute new_attr =
paddle::dialect::IntArrayAttribute::get(ctx, phi::IntArray(shape));
auto defining_op =
InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
op_inputs.push_back(defining_op->result(0));
}
if (op_desc.HasInput("ValueTensor", true) &&
op_desc.Input("ValueTensor", true).size() > 0) {
auto value_tensor_vars = op_desc.Input("ValueTensor", true);
auto defining_info = (*param_map)[value_tensor_vars[0]];
op_inputs.push_back(defining_info.value);
} else {
float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
ir::Attribute new_attr = ir::FloatAttribute::get(ctx, value);
auto defining_op =
InsertFullOperationForAttributeInput(ctx, program, new_attr);
op_inputs.push_back(defining_op->result(0));
}
return op_inputs;
}
ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) override {
int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
ir::AttributeMap attribute_map = {
{"dtype",
paddle::dialect::DataTypeAttribute::get(
ctx,
paddle::dialect::VarTypeToDataType(
static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};
return attribute_map;
}
};
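// Sketch of the tensor-input lowering: shape and value become operands rather
// than attributes, so the translated program looks roughly like
//   %shape = ...  // ShapeTensor, a pd.stack over ShapeTensorList, or a
//                 // pd.full_int_array built from the `shape` attribute
//   %value = ...  // ValueTensor, or a pd.full holding the `value` attribute
//   %out   = pd.full_with_tensor(%shape, %value, dtype=...)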
struct FillConstantTranscriber : public OpTranscriber {
ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program) override {
bool has_mutable_attribute = op_desc.HasInput("ShapeTensor", true) &&
op_desc.Input("ShapeTensor", true).size() > 0;
has_mutable_attribute |= op_desc.HasInput("ShapeTensorList", true) &&
op_desc.Input("ShapeTensorList", true).size() > 0;
has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) &&
op_desc.Input("ValueTensor", true).size() > 0;
if (!has_mutable_attribute) {
return FillConstant2FullTranscriber()(ctx, param_map, op_desc, program);
} else {
return FillConstant2FullWithTensorTranscriber()(
ctx, param_map, op_desc, program);
}
}
};
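// Dispatch rule: fill_constant lowers to pd.full when shape and value are
// plain attributes, and to pd.full_with_tensor as soon as any of ShapeTensor,
// ShapeTensorList, or ValueTensor is wired up as an input.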
ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
@@ -1474,6 +1650,7 @@ OpTranslator::OpTranslator() {
special_handlers["split"] = SplitOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber();
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
special_handlers["fill_constant"] = FillConstantTranscriber();
// special handler for elementwise ops with axis != -1
// note(lyk): maybe we should do this by a pass, which seems more reasonable
......
@@ -404,6 +404,16 @@
data_transform :
skip_transform : x
- op : full_with_tensor
args : (Tensor shape, Tensor value, DataType dtype=DataType::FLOAT32)
output : Tensor(out)
infer_meta :
func : FullWithTensorInferMeta
param : [shape, dtype]
kernel :
func : full_with_tensor
data_type : dtype
- op : fused_adam_
args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow)
output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()}
......
@@ -19,12 +19,14 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/backends/device_memory_aligment.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
@@ -4154,5 +4156,12 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
}
}
void FullWithTensorInferMeta(const MetaTensor& shape,
DataType dtype,
MetaTensor* out) {
out->set_dims(make_ddim({-1}));
out->set_dtype(dtype);
}
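// Note: because `shape` is a runtime tensor, the output extent cannot be
// inferred at compile time; the dims are marked as {-1} (dynamic) and only
// the dtype is propagated. The kernel resizes the output once the shape
// tensor's contents are known.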
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/meta_tensor.h"
namespace phi {
// Common InferMeta Functions for multiary operators, The format like:
@@ -819,4 +820,8 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
MetaTensor* cache_kv_out,
MetaTensor* beam_cache_offset_out);
void FullWithTensorInferMeta(const MetaTensor& shape,
DataType dtype,
MetaTensor* out);
} // namespace phi
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h"
namespace phi {
@@ -130,3 +131,20 @@ PD_REGISTER_KERNEL(full_like,
PD_REGISTER_KERNEL(
full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {}
PD_REGISTER_KERNEL(full_with_tensor,
CPU,
ALL_LAYOUT,
phi::FullWithTensorKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
@@ -31,6 +31,13 @@ void FullKernel(const Context& dev_ctx,
DataType dtype,
DenseTensor* out);
template <typename T, typename Context>
void FullWithTensorKernel(const Context& dev_ctx,
const DenseTensor& shape,
const DenseTensor& value,
DataType dtype,
DenseTensor* out);
template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
......
@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h"
namespace phi {
template <typename InT, typename OutT = InT>
@@ -144,3 +146,20 @@ PD_REGISTER_KERNEL(full_like,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(full_with_tensor,
GPU,
ALL_LAYOUT,
phi::FullWithTensorKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/full_kernel.h"
namespace phi {
template <typename T, typename Context>
void FullWithTensorKernel(const Context& dev_ctx,
const DenseTensor& shape,
const DenseTensor& value,
DataType dtype,
DenseTensor* out) {
auto shape_tmp = IntArray(shape);
out->Resize(phi::make_ddim(shape_tmp.GetData()));
FullKernel<T, Context>(dev_ctx, shape_tmp, Scalar(value), dtype, out);
}
} // namespace phi
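// A hedged usage sketch (names hypothetical): with a shape tensor holding
// {2, 3} and a one-element value tensor, the kernel reads the shape into an
// IntArray, resizes `out` to [2, 3], and delegates the fill to FullKernel:
//   phi::FullWithTensorKernel<float, phi::CPUContext>(
//       dev_ctx, shape_tensor, value_tensor, phi::DataType::FLOAT32, &out);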
@@ -35,6 +35,16 @@ void FullKernel(const Context& dev_ctx,
phi::FullKernel<T>(dev_ctx, shape, val, dtype, out->mutable_value());
}
template <typename T, typename Context>
void FullWithTensorKernel(const Context& dev_ctx,
const DenseTensor& shape,
const DenseTensor& value,
DataType dtype,
SelectedRows* out) {
phi::FullWithTensorKernel<T>(
dev_ctx, shape, value, dtype, out->mutable_value());
}
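// The SelectedRows overload only fills the underlying dense value tensor via
// out->mutable_value(); row metadata is left untouched.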
} // namespace sr
} // namespace phi
@@ -84,3 +94,50 @@ PD_REGISTER_KERNEL(full_sr,
bool,
phi::dtype::float16) {}
#endif
PD_REGISTER_KERNEL(full_with_tensor_sr,
CPU,
ALL_LAYOUT,
phi::sr::FullWithTensorKernel,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(full_with_tensor_sr,
GPU,
ALL_LAYOUT,
phi::sr::FullWithTensorKernel,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(full_with_tensor_sr,
XPU,
ALL_LAYOUT,
phi::sr::FullWithTensorKernel,
float,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
#endif
@@ -28,5 +28,11 @@ void FullKernel(const Context& dev_ctx,
DataType dtype,
SelectedRows* out);
template <typename T, typename Context>
void FullWithTensorKernel(const Context& dev_ctx,
const DenseTensor& shape,
const DenseTensor& value,
DataType dtype,
SelectedRows* out);
} // namespace sr
} // namespace phi
@@ -23,6 +23,7 @@
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/impl/full_whit_tensor_kernel_impl.h"
namespace phi {
@@ -152,3 +153,16 @@ PD_REGISTER_KERNEL(full_batch_size_like,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(full_with_tensor,
XPU,
ALL_LAYOUT,
phi::FullWithTensorKernel,
float,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
@@ -194,3 +194,4 @@ test_warprnnt_op
test_where_op
test_yolo_box_op
test_yolov3_loss_op
test_fill_constant_op