diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index b1d7059f311cd370a40e83d7b0016d5af8cdb163..dec8d1d846c86ce954233ee6db2e9199a39d6ce1 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -442,6 +442,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name, infershape_input.size())); } } + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + auto& attr = attr_reader.GetAttr(attr_name); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct InferMetaContext.", + attr_names[i])); + } } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index 513ab46e9b5eebdb39faf4401d9d8b2fc387a82f..73655bcb18500e54564936eac4400a0c7b49af62 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -13,9 +13,15 @@ // limitations under the License. #include "paddle/fluid/operators/set_value_op.h" + #include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace framework { class InferShapeContext; @@ -34,6 +40,8 @@ class CPUDeviceContext; namespace paddle { namespace operators { +using Tensor = framework::Tensor; + class SetValue : public framework::OperatorWithKernel { public: SetValue(const std::string &type, const framework::VariableNameMap &inputs, @@ -41,17 +49,6 @@ class SetValue : public framework::OperatorWithKernel { const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue"); - auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_LT( - in_dims.size(), 7, - platform::errors::InvalidArgument( - "The rank of input should be less than 7, but received %d.", - in_dims.size())); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -236,10 +233,13 @@ DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"}); namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(set_value, SetValueInferShapeFunctor, + PD_INFER_META(phi::SetValueInferMeta)); + REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker, ops::SetValueGradMaker, ops::SetValueGradMaker, - ops::SetValueOpInplaceInferer); + ops::SetValueOpInplaceInferer, SetValueInferShapeFunctor); REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index bc6ab52427727edada55991e387d9044442cda0f..8a2d718f124578dbab0164048f8daa09e9a54e8f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1090,6 +1090,16 @@ void RollInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) { + auto in_dims = x.dims(); + PADDLE_ENFORCE_LT( + in_dims.size(), + 7, + phi::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", + in_dims.size())); +} + void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) { auto in_dim = input.dims(); out->set_dims(phi::make_ddim({in_dim.size()})); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6cb9653624f4227cdf0c98617856b8b4bbbe29cb..7203a327b55698c0d4bd0271b2908cbc4a9b5ca1 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -177,6 +177,8 @@ void RollInferMeta(const MetaTensor& x, const std::vector& axis, MetaTensor* out); +void SetValueInferMeta(const MetaTensor& x, MetaTensor* out); + void ShapeInferMeta(const MetaTensor& input, MetaTensor* out); void ShardIndexInferMeta(const MetaTensor& in, diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc index 9653250bded84f8ff87f613f6e17e50e351504fa..5feff54b028ba437125d65e4a6709254704164d8 100644 --- a/paddle/phi/ops/compat/set_value_sig.cc +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -19,9 +19,9 @@ namespace phi { KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("Input")) { - if (ctx.HasInput("StartsTensorList")) { - if (ctx.HasInput("EndsTensorList")) { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("StartsTensorList") > 0) { + if (ctx.InputSize("EndsTensorList") > 0) { + if (ctx.InputSize("StepsTensorList") > 0) { if (ctx.HasInput("ValueTensor")) { return KernelSignature("set_value_with_tensor", {"Input", "ValueTensor"}, @@ -197,7 +197,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { } } } else { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("StepsTensorList") > 0) { if (ctx.HasInput("ValueTensor")) { return KernelSignature("set_value_with_tensor", {"Input", "ValueTensor"}, @@ -374,8 +374,8 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { } } } else { - if (ctx.HasInput("EndsTensorList")) { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("EndsTensorList") > 0) { + if (ctx.InputSize("StepsTensorList") > 0) { if (ctx.HasInput("ValueTensor")) { return KernelSignature("set_value_with_tensor", {"Input", "ValueTensor"}, @@ -551,7 +551,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { } } } else { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("StepsTensorList") > 0) { if (ctx.HasInput("ValueTensor")) { return KernelSignature("set_value_with_tensor", {"Input", "ValueTensor"}, @@ -734,9 +734,9 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SetValueGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - if (ctx.HasInput("StartsTensorList")) { - if (ctx.HasInput("EndsTensorList")) { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("StartsTensorList") > 0) { + if (ctx.InputSize("EndsTensorList") > 0) { + if (ctx.InputSize("StepsTensorList") > 0) { return KernelSignature( "set_value_grad", {GradVarName("Out")}, @@ -760,7 +760,7 @@ KernelSignature SetValueGradOpArgumentMapping( {GradVarName("Input"), GradVarName("ValueTensor")}); } } else { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("StepsTensorList") > 0) { return KernelSignature( "set_value_grad", {GradVarName("Out")}, @@ -785,8 +785,8 @@ KernelSignature SetValueGradOpArgumentMapping( } } } else { - if (ctx.HasInput("EndsTensorList")) { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("EndsTensorList") > 0) { + if (ctx.InputSize("StepsTensorList") > 0) { return KernelSignature( "set_value_grad", {GradVarName("Out")}, @@ -810,7 +810,7 @@ KernelSignature SetValueGradOpArgumentMapping( {GradVarName("Input"), GradVarName("ValueTensor")}); } } else { - if (ctx.HasInput("StepsTensorList")) { + if (ctx.InputSize("StepsTensorList") > 0) { return KernelSignature( "set_value_grad", {GradVarName("Out")},