Unverified commit c335288d, authored by zyfncg, committed by GitHub

move infershape of set_value to phi (#40636)

Parent ed8a9370
paddle/fluid/framework/infershape_utils.cc
@@ -442,6 +442,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                 attr_name, infershape_input.size()));
           }
         }
+      } else if (attr_defs[i].type_index ==
+                 std::type_index(typeid(std::vector<phi::Scalar>))) {
+        auto& attr = attr_reader.GetAttr(attr_name);
+        if (std::type_index(attr.type()) ==
+            std::type_index(typeid(std::vector<int32_t>))) {
+          const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
+          std::vector<phi::Scalar> scalar_list;
+          scalar_list.reserve(vec.size());
+          for (const auto& val : vec) {
+            scalar_list.emplace_back(val);
+          }
+          infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
+        } else if (std::type_index(attr.type()) ==
+                   std::type_index(typeid(std::vector<int64_t>))) {
+          const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
+          std::vector<phi::Scalar> scalar_list;
+          scalar_list.reserve(vec.size());
+          for (const auto& val : vec) {
+            scalar_list.emplace_back(val);
+          }
+          infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
+        } else if (std::type_index(attr.type()) ==
+                   std::type_index(typeid(std::vector<float>))) {
+          const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
+          std::vector<phi::Scalar> scalar_list;
+          scalar_list.reserve(vec.size());
+          for (const auto& val : vec) {
+            scalar_list.emplace_back(val);
+          }
+          infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
+        } else if (std::type_index(attr.type()) ==
+                   std::type_index(typeid(std::vector<double>))) {
+          const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
+          std::vector<phi::Scalar> scalar_list;
+          scalar_list.reserve(vec.size());
+          for (const auto& val : vec) {
+            scalar_list.emplace_back(val);
+          }
+          infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to vector<Scalar> when "
+              "construct InferMetaContext.",
+              attr_names[i]));
+        }
       } else if (ctx->HasAttr(attr_name)) {
         // Emplace Back Attr according to the type of attr.
         auto& attr = attr_reader.GetAttr(attr_name);
......
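A side note on the new vector<phi::Scalar> branch above: its four type cases differ only in the attribute's element type. A minimal sketch of how they could collapse into one helper template (the helper name EmplaceScalarListAttr is illustrative, not part of the patch; it reuses the BOOST_GET_CONST and EmplaceBackAttr calls shown above):

// Hypothetical helper: convert a typed attribute vector into
// std::vector<phi::Scalar> and push it into the InferMetaContext.
template <typename T>
void EmplaceScalarListAttr(const framework::Attribute& attr,
                           phi::InferMetaContext* infer_meta_context) {
  const auto& vec = BOOST_GET_CONST(std::vector<T>, attr);
  std::vector<phi::Scalar> scalar_list;
  scalar_list.reserve(vec.size());
  for (const auto& val : vec) {
    // phi::Scalar has converting constructors for arithmetic types.
    scalar_list.emplace_back(val);
  }
  infer_meta_context->EmplaceBackAttr(std::move(scalar_list));
}

Each else-if branch would then reduce to a single call such as EmplaceScalarListAttr<int32_t>(attr, &infer_meta_context).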
paddle/fluid/operators/set_value_op.cc
@@ -13,9 +13,15 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/set_value_op.h"
 
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace framework {
 class InferShapeContext;
@@ -34,6 +40,8 @@ class CPUDeviceContext;
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 class SetValue : public framework::OperatorWithKernel {
  public:
   SetValue(const std::string &type, const framework::VariableNameMap &inputs,
@@ -41,17 +49,6 @@ class SetValue : public framework::OperatorWithKernel {
            const framework::AttributeMap &attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue");
-    auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_LT(
-        in_dims.size(), 7,
-        platform::errors::InvalidArgument(
-            "The rank of input should be less than 7, but received %d.",
-            in_dims.size()));
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -236,10 +233,13 @@ DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
+DECLARE_INFER_SHAPE_FUNCTOR(set_value, SetValueInferShapeFunctor,
+                            PD_INFER_META(phi::SetValueInferMeta));
+
 REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
                   ops::SetValueGradMaker<paddle::framework::OpDesc>,
                   ops::SetValueGradMaker<paddle::imperative::OpBase>,
-                  ops::SetValueOpInplaceInferer);
+                  ops::SetValueOpInplaceInferer, SetValueInferShapeFunctor);
 
 REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);
......
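The registration change above follows the standard fluid-to-phi migration pattern: the virtual InferShape override is deleted, and REGISTER_OPERATOR instead receives an infer-shape functor generated from the phi InferMeta function. A hedged sketch of that pattern with hypothetical names (my_op, MyOpInferMeta), assuming the macros provided by paddle/fluid/framework/infershape_utils.h:

// Generate a functor that adapts phi::MyOpInferMeta to the fluid
// InferShape interface, then attach it when registering the operator.
DECLARE_INFER_SHAPE_FUNCTOR(my_op, MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyOpInferMeta));
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker, MyOpInferShapeFunctor);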
paddle/phi/infermeta/unary.cc
@@ -1090,6 +1090,16 @@ void RollInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
+  auto in_dims = x.dims();
+  PADDLE_ENFORCE_LT(
+      in_dims.size(),
+      7,
+      phi::errors::InvalidArgument(
+          "The rank of input should be less than 7, but received %d.",
+          in_dims.size()));
+}
+
 void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
   auto in_dim = input.dims();
   out->set_dims(phi::make_ddim({in_dim.size()}));
......
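Note that SetValueInferMeta only validates the input rank and sets no output metadata; set_value is registered as an inplace op ({"Input", "Out"}), so the output shares the input tensor and its dims. A small usage sketch (assumption: phi::MetaTensor can be constructed directly from a DenseTensor pointer, which simplifies how the generated functor actually drives it):

// Hypothetical direct call: a rank-2 input passes the rank < 7 check.
phi::DenseTensor t;
t.Resize(phi::make_ddim({2, 3}));  // rank 2
phi::MetaTensor x(&t);
phi::MetaTensor out(&t);           // inplace: same underlying tensor
phi::SetValueInferMeta(x, &out);   // throws InvalidArgument only for rank >= 7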
paddle/phi/infermeta/unary.h
@@ -177,6 +177,8 @@ void RollInferMeta(const MetaTensor& x,
                    const std::vector<int64_t>& axis,
                    MetaTensor* out);
 
+void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);
+
 void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
 
 void ShardIndexInferMeta(const MetaTensor& in,
......
paddle/phi/ops/compat/set_value_sig.cc
@@ -19,9 +19,9 @@ namespace phi {
 
 KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.IsDenseTensorInput("Input")) {
-    if (ctx.HasInput("StartsTensorList")) {
-      if (ctx.HasInput("EndsTensorList")) {
-        if (ctx.HasInput("StepsTensorList")) {
+    if (ctx.InputSize("StartsTensorList") > 0) {
+      if (ctx.InputSize("EndsTensorList") > 0) {
+        if (ctx.InputSize("StepsTensorList") > 0) {
           if (ctx.HasInput("ValueTensor")) {
             return KernelSignature("set_value_with_tensor",
                                    {"Input", "ValueTensor"},
@@ -197,7 +197,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
           }
         }
       } else {
-        if (ctx.HasInput("StepsTensorList")) {
+        if (ctx.InputSize("StepsTensorList") > 0) {
           if (ctx.HasInput("ValueTensor")) {
             return KernelSignature("set_value_with_tensor",
                                    {"Input", "ValueTensor"},
@@ -374,8 +374,8 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
         }
       }
     } else {
-      if (ctx.HasInput("EndsTensorList")) {
-        if (ctx.HasInput("StepsTensorList")) {
+      if (ctx.InputSize("EndsTensorList") > 0) {
+        if (ctx.InputSize("StepsTensorList") > 0) {
           if (ctx.HasInput("ValueTensor")) {
             return KernelSignature("set_value_with_tensor",
                                    {"Input", "ValueTensor"},
@@ -551,7 +551,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
          }
        }
      } else {
-       if (ctx.HasInput("StepsTensorList")) {
+       if (ctx.InputSize("StepsTensorList") > 0) {
          if (ctx.HasInput("ValueTensor")) {
            return KernelSignature("set_value_with_tensor",
                                   {"Input", "ValueTensor"},
@@ -734,9 +734,9 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
 
 KernelSignature SetValueGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  if (ctx.HasInput("StartsTensorList")) {
-    if (ctx.HasInput("EndsTensorList")) {
-      if (ctx.HasInput("StepsTensorList")) {
+  if (ctx.InputSize("StartsTensorList") > 0) {
+    if (ctx.InputSize("EndsTensorList") > 0) {
+      if (ctx.InputSize("StepsTensorList") > 0) {
         return KernelSignature(
             "set_value_grad",
             {GradVarName("Out")},
@@ -760,7 +760,7 @@ KernelSignature SetValueGradOpArgumentMapping(
             {GradVarName("Input"), GradVarName("ValueTensor")});
       }
     } else {
-      if (ctx.HasInput("StepsTensorList")) {
+      if (ctx.InputSize("StepsTensorList") > 0) {
        return KernelSignature(
            "set_value_grad",
            {GradVarName("Out")},
@@ -785,8 +785,8 @@ KernelSignature SetValueGradOpArgumentMapping(
       }
     }
   } else {
-    if (ctx.HasInput("EndsTensorList")) {
-      if (ctx.HasInput("StepsTensorList")) {
+    if (ctx.InputSize("EndsTensorList") > 0) {
+      if (ctx.InputSize("StepsTensorList") > 0) {
         return KernelSignature(
             "set_value_grad",
             {GradVarName("Out")},
@@ -810,7 +810,7 @@ KernelSignature SetValueGradOpArgumentMapping(
             {GradVarName("Input"), GradVarName("ValueTensor")});
       }
     } else {
-      if (ctx.HasInput("StepsTensorList")) {
+      if (ctx.InputSize("StepsTensorList") > 0) {
         return KernelSignature(
             "set_value_grad",
             {GradVarName("Out")},
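The uniform HasInput to InputSize(...) > 0 change in both mapping functions reflects that StartsTensorList, EndsTensorList, and StepsTensorList are duplicable (list) inputs: what the signature selection needs is whether each list is non-empty, which InputSize states directly. A condensed sketch of the choice the nested branches encode (illustrative only; the real code returns a complete KernelSignature for every combination):

// For each of starts/ends/steps, prefer the tensor-list input when it
// is non-empty, otherwise fall back to the plain attribute.
const char* starts =
    ctx.InputSize("StartsTensorList") > 0 ? "StartsTensorList" : "starts";
const char* ends =
    ctx.InputSize("EndsTensorList") > 0 ? "EndsTensorList" : "ends";
const char* steps =
    ctx.InputSize("StepsTensorList") > 0 ? "StepsTensorList" : "steps";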