未验证 提交 2e736531 编写于 作者: C caozhou 提交者: GitHub

[Phi] Migrate InferShape of multiplex, qr, tril_triu (#40102)

* migrate infershape

* fix tril_triu infershape error

* fix qr_op infershape

* add parse qr mode func

* move order
上级 f51a5791
...@@ -14,8 +14,13 @@ limitations under the License. */ ...@@ -14,8 +14,13 @@ limitations under the License. */
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -25,44 +30,6 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -25,44 +30,6 @@ class MultiplexOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "Multiplex");
PADDLE_ENFORCE_NE(
ctx->Inputs("X").empty(), true,
platform::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multiplex");
auto ids_dim = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(
ids_dim.size(), 2,
platform::errors::PreconditionNotMet(
"The index tensor must be a vector with 2 dimensions"));
PADDLE_ENFORCE_EQ(
ids_dim[1], 1,
platform::errors::PreconditionNotMet(
"The index tensor must be a vector with batchSize x 1."));
auto ins_dims = ctx->GetInputsDim("X");
auto num_ins = ins_dims.size();
PADDLE_ENFORCE_GT(num_ins, 1,
platform::errors::InvalidArgument(
"multiplex operator should have more than "
"one candidate input tensors."));
auto in_dim = ins_dims[0];
PADDLE_ENFORCE_GE(
in_dim.size(), 2,
platform::errors::InvalidArgument(
"The rank of candidate tensors must be not less than 2."));
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins_dims[i];
PADDLE_ENFORCE_EQ(
in_dim, dim,
platform::errors::PreconditionNotMet(
"All the candidate tensors must have the same size."));
}
ctx->SetOutputDim("Out", in_dim);
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -164,8 +131,11 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -164,8 +131,11 @@ class MultiplexGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(multiplex, MultiplexInferShapeFunctor,
PD_INFER_META(phi::MultiplexInferMeta));
REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
ops::MultiplexGradMaker<paddle::framework::OpDesc>, ops::MultiplexGradMaker<paddle::framework::OpDesc>,
ops::MultiplexGradMaker<paddle::imperative::OpBase>); ops::MultiplexGradMaker<paddle::imperative::OpBase>,
MultiplexInferShapeFunctor);
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp); REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
...@@ -21,6 +21,10 @@ ...@@ -21,6 +21,10 @@
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,43 +33,6 @@ using DDim = framework::DDim; ...@@ -29,43 +33,6 @@ using DDim = framework::DDim;
class QrOp : public framework::OperatorWithKernel { class QrOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "qr");
OP_INOUT_CHECK(ctx->HasOutput("Q"), "Output", "Q", "qr");
OP_INOUT_CHECK(ctx->HasOutput("R"), "Output", "R", "qr");
auto x_dims = ctx->GetInputDim("X");
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
bool compute_q;
bool reduced_mode;
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::string mode = ctx->Attrs().Get<std::string>("mode");
std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
if (compute_q) {
int k = reduced_mode ? min_mn : m;
auto q_dims_vec = phi::vectorize(x_dims);
q_dims_vec[q_dims_vec.size() - 1] = k;
ctx->SetOutputDim("Q", phi::make_ddim(q_dims_vec));
} else {
ctx->SetOutputDim("Q", phi::make_ddim({0}));
}
int k = reduced_mode ? min_mn : m;
auto r_dims_vec = phi::vectorize(x_dims);
r_dims_vec[r_dims_vec.size() - 2] = k;
r_dims_vec[r_dims_vec.size() - 1] = n;
ctx->SetOutputDim("R", phi::make_ddim(r_dims_vec));
ctx->ShareLoD("X", /*->*/ "Q");
ctx->ShareLoD("X", /*->*/ "R");
}
}; };
class QrOpMaker : public framework::OpProtoAndCheckerMaker { class QrOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -83,10 +50,8 @@ class QrOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,10 +50,8 @@ class QrOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("reduced"); .SetDefault("reduced");
AddComment(R"DOC( AddComment(R"DOC(
Qr Operator. Qr Operator.
This operator is used to perform QR operation for batched matrics $X$. This operator is used to perform QR operation for batched matrics $X$.
$$Q, R = qr(X)$$ $$Q, R = qr(X)$$
)DOC"); )DOC");
} }
}; };
...@@ -138,10 +103,13 @@ class QrGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -138,10 +103,13 @@ class QrGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(qr, QrInferShapeFunctor,
PD_INFER_META(phi::QrInferMeta));
REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker, REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker,
ops::QrGradMaker<paddle::framework::OpDesc>, ops::QrGradMaker<paddle::framework::OpDesc>,
ops::QrGradMaker<paddle::imperative::OpBase>); ops::QrGradMaker<paddle::imperative::OpBase>,
QrInferShapeFunctor);
REGISTER_OPERATOR(qr_grad, ops::QrGradOp); REGISTER_OPERATOR(qr_grad, ops::QrGradOp);
......
...@@ -13,29 +13,18 @@ See the License for the specific language governing permissions and ...@@ -13,29 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class TrilTriuOp : public framework::OperatorWithKernel { class TrilTriuOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of TrilTriuOp is not found."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of TrilTriuOp is not found."));
const auto& x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker { class TrilTriuOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -100,7 +89,10 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -100,7 +89,10 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(tril_triu, TrilTriuInferShapeFunctor,
PD_INFER_META(phi::TrilTriuInferMeta));
REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>, ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>); ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>,
TrilTriuInferShapeFunctor);
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp); REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
...@@ -832,6 +832,50 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) { ...@@ -832,6 +832,50 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
out->share_lod(*x.at(0)); out->share_lod(*x.at(0));
} }
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out) {
PADDLE_ENFORCE_NE(
ins.empty(),
true,
phi::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
auto ids_dim = ids.dims();
PADDLE_ENFORCE_EQ(ids_dim.size(),
2,
phi::errors::PreconditionNotMet(
"The index tensor must be a vector with 2 dimensions"));
PADDLE_ENFORCE_EQ(
ids_dim[1],
1,
phi::errors::PreconditionNotMet(
"The index tensor must be a vector with batchSize x 1."));
auto ins_dims = GetMetaTensorsDim(ins);
auto num_ins = ins_dims.size();
PADDLE_ENFORCE_GT(
num_ins,
1,
phi::errors::InvalidArgument("multiplex operator should have more than "
"one candidate input tensors."));
auto in_dim = ins_dims[0];
PADDLE_ENFORCE_GE(
in_dim.size(),
2,
phi::errors::InvalidArgument(
"The rank of candidate tensors must be not less than 2."));
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins_dims[i];
PADDLE_ENFORCE_EQ(
in_dim,
dim,
phi::errors::PreconditionNotMet(
"All the candidate tensors must have the same size."));
}
out->set_dims(in_dim);
out->set_dtype(ins[0]->dtype());
}
void PsroiPoolInferMeta(const MetaTensor& x, void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois, const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num, paddle::optional<const MetaTensor&> rois_num,
......
...@@ -152,6 +152,10 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x, ...@@ -152,6 +152,10 @@ void HierarchicalSigmoidInferMeta(const MetaTensor& x,
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out); void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);
void PsroiPoolInferMeta(const MetaTensor& x, void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois, const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num, paddle::optional<const MetaTensor&> rois_num,
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/common/type_traits.h" #include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h" #include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/funcs/unsqueeze.h"
...@@ -1129,6 +1130,44 @@ void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) { ...@@ -1129,6 +1130,44 @@ void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
MetaTensor* r) {
auto x_dims = x.dims();
int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument("the rank of input must greater than 2"));
bool compute_q;
bool reduced_mode;
int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int min_mn = std::min(m, n);
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
if (compute_q) {
int k = reduced_mode ? min_mn : m;
auto q_dims_vec = phi::vectorize(x_dims);
q_dims_vec[q_dims_vec.size() - 1] = k;
q->set_dims(phi::make_ddim(q_dims_vec));
} else {
q->set_dims(phi::make_ddim({0}));
}
int k = reduced_mode ? min_mn : m;
auto r_dims_vec = phi::vectorize(x_dims);
r_dims_vec[r_dims_vec.size() - 2] = k;
r_dims_vec[r_dims_vec.size() - 1] = n;
r->set_dims(phi::make_ddim(r_dims_vec));
q->share_lod(x);
r->share_lod(x);
q->set_dtype(x.dtype());
r->set_dtype(x.dtype());
}
DDim ReduceInferDim(const MetaTensor& x, DDim ReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
...@@ -1847,6 +1886,20 @@ void UnbindInferMeta(const MetaTensor& x, ...@@ -1847,6 +1886,20 @@ void UnbindInferMeta(const MetaTensor& x,
} }
} }
void TrilTriuInferMeta(const MetaTensor& x,
int diagonal,
bool lower,
MetaTensor* out) {
const auto& x_dims = x.dims();
PADDLE_ENFORCE_GE(x_dims.size(),
2,
phi::errors::InvalidArgument(
"Input(X)'s rank must be at least 2 in TrilTriuOp."));
out->set_dims(x.dims());
out->share_lod(x);
out->set_dtype(x.dtype());
}
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) { void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
out->share_meta(x); out->share_meta(x);
} }
......
...@@ -180,6 +180,11 @@ void PoolInferMeta(const MetaTensor& x, ...@@ -180,6 +180,11 @@ void PoolInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
MetaTensor* r);
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out); void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x, void ReduceInferMeta(const MetaTensor& x,
...@@ -282,6 +287,11 @@ void TransposeGradInferMeta(const MetaTensor& x, ...@@ -282,6 +287,11 @@ void TransposeGradInferMeta(const MetaTensor& x,
const std::vector<int>& axis, const std::vector<int>& axis,
MetaTensor* out); MetaTensor* out);
void TrilTriuInferMeta(const MetaTensor& x,
int diagonal,
bool lower,
MetaTensor* out);
void UnbindInferMeta(const MetaTensor& x, void UnbindInferMeta(const MetaTensor& x,
int axis, int axis,
std::vector<MetaTensor>* outs); std::vector<MetaTensor>* outs);
......
...@@ -19,30 +19,10 @@ ...@@ -19,30 +19,10 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
namespace phi { namespace phi {
static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
template <typename T, typename Context> template <typename T, typename Context>
void QrKernel(const Context& ctx, void QrKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -51,7 +31,7 @@ void QrKernel(const Context& ctx, ...@@ -51,7 +31,7 @@ void QrKernel(const Context& ctx,
DenseTensor* r) { DenseTensor* r) {
bool compute_q; bool compute_q;
bool reduced_mode; bool reduced_mode;
std::tie(compute_q, reduced_mode) = ParseQrMode(mode); std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
auto numel = x.numel(); auto numel = x.numel();
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty.")); numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
bool compute_q;
bool reduced;
if (mode == "reduced") {
compute_q = true;
reduced = true;
} else if (mode == "complete") {
compute_q = true;
reduced = false;
} else if (mode == "r") {
compute_q = false;
reduced = true;
} else {
PADDLE_THROW(errors::InvalidArgument(
"QR received unrecognized mode '%s'"
" but expected one of 'reduced' (default), 'r', or 'complete'",
mode));
}
return std::make_tuple(compute_q, reduced);
}
} // namespace funcs
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册