未验证 提交 2e736531 编写于 作者: C caozhou 提交者: GitHub

[Phi] Migrate InferShape of multiplex, qr, tril_triu (#40102)

* migrate infershape

* fix tril_triu infershape error

* fix qr_op infershape

* add parse qr mode func

* move order
上级 f51a5791
......@@ -14,8 +14,13 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -25,44 +30,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "Multiplex");
PADDLE_ENFORCE_NE(
ctx->Inputs("X").empty(), true,
platform::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multiplex");
auto ids_dim = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(
ids_dim.size(), 2,
platform::errors::PreconditionNotMet(
"The index tensor must be a vector with 2 dimensions"));
PADDLE_ENFORCE_EQ(
ids_dim[1], 1,
platform::errors::PreconditionNotMet(
"The index tensor must be a vector with batchSize x 1."));
auto ins_dims = ctx->GetInputsDim("X");
auto num_ins = ins_dims.size();
PADDLE_ENFORCE_GT(num_ins, 1,
platform::errors::InvalidArgument(
"multiplex operator should have more than "
"one candidate input tensors."));
auto in_dim = ins_dims[0];
PADDLE_ENFORCE_GE(
in_dim.size(), 2,
platform::errors::InvalidArgument(
"The rank of candidate tensors must be not less than 2."));
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins_dims[i];
PADDLE_ENFORCE_EQ(
in_dim, dim,
platform::errors::PreconditionNotMet(
"All the candidate tensors must have the same size."));
}
ctx->SetOutputDim("Out", in_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -164,8 +131,11 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor,
PD_INFER_META(phi::MultiplexInferMeta));
REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
ops::MultiplexGradMaker<paddle::framework::OpDesc>,
ops::MultiplexGradMaker<paddle::imperative::OpBase>);
ops::MultiplexGradMaker<paddle::imperative::OpBase>,
MultiplexInferShapeFunctor);
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
......@@ -21,6 +21,10 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -29,43 +33,6 @@ using DDim = framework::DDim;
class QrOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "qr");
OP_INOUT_CHECK(ctx->HasOutput("Q"), "Output", "Q", "qr");
OP_INOUT_CHECK(ctx->HasOutput("R"), "Output", "R", "qr");
auto x_dims = ctx->GetInputDim("X");
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
bool compute_q;
bool reduced_mode;
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::string mode = ctx->Attrs().Get<std::string>("mode");
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
if (compute_q) {
int k = reduced_mode ? min_mn : m;
auto q_dims_vec = phi::vectorize(x_dims);
q_dims_vec[q_dims_vec.size() - 1] = k;
ctx->SetOutputDim("Q", phi::make_ddim(q_dims_vec));
} else {
ctx->SetOutputDim("Q", phi::make_ddim({0}));
}
int k = reduced_mode ? min_mn : m;
auto r_dims_vec = phi::vectorize(x_dims);
r_dims_vec[r_dims_vec.size() - 2] = k;
r_dims_vec[r_dims_vec.size() - 1] = n;
ctx->SetOutputDim("R", phi::make_ddim(r_dims_vec));
ctx->ShareLoD("X", /*->*/ "Q");
ctx->ShareLoD("X", /*->*/ "R");
}
};
class QrOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -83,10 +50,8 @@ class QrOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("reduced");
AddComment(R"DOC(
Qr Operator.
This operator is used to perform QR operation for batched matrics $X$.
$$Q, R = qr(X)$$
)DOC");
}
};
......@@ -138,10 +103,13 @@ class QrGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor,
PD_INFER_META(phi::QrInferMeta));
REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker,
ops::QrGradMaker<paddle::framework::OpDesc>,
ops::QrGradMaker<paddle::imperative::OpBase>);
ops::QrGradMaker<paddle::imperative::OpBase>,
QrInferShapeFunctor);
REGISTER_OPERATOR(qr_grad, ops::QrGradOp);
......
......@@ -13,29 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class TrilTriuOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of TrilTriuOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of TrilTriuOp is not found."));
const auto& x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -100,7 +89,10 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor,
PD_INFER_META(phi::TrilTriuInferMeta));
REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>,
TrilTriuInferShapeFunctor);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
......@@ -832,6 +832,50 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
out->share_lod(*x.at(0));
}
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out) {
PADDLE_ENFORCE_NE(
ins.empty(),
true,
phi::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
auto ids_dim = ids.dims();
PADDLE_ENFORCE_EQ(ids_dim.size(),
2,
phi::errors::PreconditionNotMet(
"The index tensor must be a vector with 2 dimensions"));
PADDLE_ENFORCE_EQ(
ids_dim[1],
1,
phi::errors::PreconditionNotMet(
"The index tensor must be a vector with batchSize x 1."));
auto ins_dims = GetMetaTensorsDim(ins);
auto num_ins = ins_dims.size();
PADDLE_ENFORCE_GT(
num_ins,
1,
phi::errors::InvalidArgument("multiplex operator should have more than "
"one candidate input tensors."));
auto in_dim = ins_dims[0];
PADDLE_ENFORCE_GE(
in_dim.size(),
2,
phi::errors::InvalidArgument(
"The rank of candidate tensors must be not less than 2."));
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins_dims[i];
PADDLE_ENFORCE_EQ(
in_dim,
dim,
phi::errors::PreconditionNotMet(
"All the candidate tensors must have the same size."));
}
out->set_dims(in_dim);
out->set_dtype(ins[0]->dtype());
}
void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
......
......@@ -152,6 +152,10 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);
void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num,
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
......@@ -1129,6 +1130,44 @@ void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_layout(x.layout());
}
void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
MetaTensor* r) {
auto x_dims = x.dims();
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument("the rank of input must greater than 2"));
bool compute_q;
bool reduced_mode;
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
if (compute_q) {
int k = reduced_mode ? min_mn : m;
auto q_dims_vec = phi::vectorize(x_dims);
q_dims_vec[q_dims_vec.size() - 1] = k;
q->set_dims(phi::make_ddim(q_dims_vec));
} else {
q->set_dims(phi::make_ddim({0}));
}
int k = reduced_mode ? min_mn : m;
auto r_dims_vec = phi::vectorize(x_dims);
r_dims_vec[r_dims_vec.size() - 2] = k;
r_dims_vec[r_dims_vec.size() - 1] = n;
r->set_dims(phi::make_ddim(r_dims_vec));
q->share_lod(x);
r->share_lod(x);
q->set_dtype(x.dtype());
r->set_dtype(x.dtype());
}
DDim ReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
......@@ -1847,6 +1886,20 @@ void UnbindInferMeta(const MetaTensor& x,
}
}
void TrilTriuInferMeta(const MetaTensor& x,
int diagonal,
bool lower,
MetaTensor* out) {
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(x_dims.size(),
2,
phi::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
out->set_dims(x.dims());
out->share_lod(x);
out->set_dtype(x.dtype());
}
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
out->share_meta(x);
}
......
......@@ -180,6 +180,11 @@ void PoolInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
MetaTensor* r);
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x,
......@@ -282,6 +287,11 @@ void TransposeGradInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
void TrilTriuInferMeta(const MetaTensor& x,
int diagonal,
bool lower,
MetaTensor* out);
void UnbindInferMeta(const MetaTensor& x,
int axis,
std::vector<MetaTensor>* outs);
......
......@@ -19,30 +19,10 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
namespace phi {
static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
......@@ -51,7 +31,7 @@ void QrKernel(const Context& ctx,
DenseTensor* r) {
bool compute_q;
bool reduced_mode;
std::tie(compute_q, reduced_mode) = ParseQrMode(mode);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
} // namespace funcs
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册