未验证 提交 c335288d 编写于 作者: Z zyfncg 提交者: GitHub

move infershape of set_value to phi (#40636)

上级 ed8a9370
......@@ -442,6 +442,51 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name, infershape_input.size()));
}
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct InferMetaContext.",
attr_names[i]));
}
} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
......
......@@ -13,9 +13,15 @@
// limitations under the License.
#include "paddle/fluid/operators/set_value_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class InferShapeContext;
......@@ -34,6 +40,8 @@ class CPUDeviceContext;
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class SetValue : public framework::OperatorWithKernel {
public:
SetValue(const std::string &type, const framework::VariableNameMap &inputs,
......@@ -41,17 +49,6 @@ class SetValue : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -236,10 +233,13 @@ DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(set_value, SetValueInferShapeFunctor,
PD_INFER_META(phi::SetValueInferMeta));
REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
ops::SetValueGradMaker<paddle::framework::OpDesc>,
ops::SetValueGradMaker<paddle::imperative::OpBase>,
ops::SetValueOpInplaceInferer);
ops::SetValueOpInplaceInferer, SetValueInferShapeFunctor);
REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);
......
......@@ -1090,6 +1090,16 @@ void RollInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
}
void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
auto in_dim = input.dims();
out->set_dims(phi::make_ddim({in_dim.size()}));
......
......@@ -177,6 +177,8 @@ void RollInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
MetaTensor* out);
void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);
void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
void ShardIndexInferMeta(const MetaTensor& in,
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("Input")) {
if (ctx.HasInput("StartsTensorList")) {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
......@@ -197,7 +197,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
......@@ -374,8 +374,8 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
} else {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
......@@ -551,7 +551,7 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("set_value_with_tensor",
{"Input", "ValueTensor"},
......@@ -734,9 +734,9 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature SetValueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("StartsTensorList")) {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
......@@ -760,7 +760,7 @@ KernelSignature SetValueGradOpArgumentMapping(
{GradVarName("Input"), GradVarName("ValueTensor")});
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
......@@ -785,8 +785,8 @@ KernelSignature SetValueGradOpArgumentMapping(
}
}
} else {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("EndsTensorList") > 0) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
......@@ -810,7 +810,7 @@ KernelSignature SetValueGradOpArgumentMapping(
{GradVarName("Input"), GradVarName("ValueTensor")});
}
} else {
if (ctx.HasInput("StepsTensorList")) {
if (ctx.InputSize("StepsTensorList") > 0) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册