Unverified · Commit 9ee03302 · Author: Zhang Zheng · Committer: GitHub

[Phi]Move infershape of top_k/expand_as/kron/searchsorted to phi (#40632)

* [Phi]Move infershape of top_k/expand_as/kron/searchsorted to phi

* add set_dtype

* fix order
Parent c142e37d
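The change follows one pattern for all four operators: the hand-written InferShape override on the fluid operator is deleted, the logic is re-expressed as a phi InferMeta function over MetaTensors, and DECLARE_INFER_SHAPE_FUNCTOR together with PD_INFER_META adapts that free function back into the operator registration. A minimal toy sketch of the adapter idea (standalone illustrative code with hypothetical names, not Paddle's actual machinery):

// Toy sketch: adapting a free-standing meta function to a stored functor,
// roughly the role DECLARE_INFER_SHAPE_FUNCTOR + PD_INFER_META play here.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Stand-in for phi::MetaTensor: carries only dims in this sketch.
struct MetaTensor {
  std::vector<int64_t> dims;
};

// A phi-style meta function: pure shape logic, no framework context.
void ToyInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->dims = x.dims;  // identity shape, purely for illustration
}

// Stand-in for the generated infer-shape functor the framework stores
// alongside the operator and invokes during shape inference.
using InferShapeFn = std::function<void(const MetaTensor&, MetaTensor*)>;

int main() {
  InferShapeFn functor = ToyInferMeta;  // "registration"
  MetaTensor x{{2, 3}};
  MetaTensor out;
  functor(x, &out);  // the framework calls the functor later
  std::cout << out.dims[0] << " x " << out.dims[1] << "\n";  // prints: 2 x 3
}

In the diff below, phi::ExpandAsInferMeta, phi::KronInferMeta, phi::SearchsortedInferMeta, and phi::TopKInferMeta play the role of ToyInferMeta.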
@@ -12,7 +12,9 @@ limitations under the License. */
#include "paddle/fluid/operators/expand_as_v2_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -22,27 +24,6 @@ using framework::Tensor;
class ExpandAsV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandAsV2");
auto x_dims = ctx->GetInputDim("X");
auto target_shape = ctx->Attrs().Get<std::vector<int>>("target_shape");
PADDLE_ENFORCE_GE(
target_shape.size(), static_cast<size_t>(x_dims.size()),
platform::errors::InvalidArgument(
"The rank of target_shape must be greater than or equal "
"to the rank of Input(X). But received Input(X): input "
"rank %u; received target_shape: rank %u.",
x_dims.size(), target_shape.size()));
PADDLE_ENFORCE_LE(target_shape.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of target_shape must be less than or equal "
"to %d. But received: rank %u.",
MAX_RANK_SUPPORTED, target_shape.size()));
ctx->SetOutputDim("Out", phi::make_ddim(target_shape));
}
};
class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
@@ -116,9 +97,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandAsV2GradNoNeedBufVarsInferer, "X");
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(expand_as_v2, ExpandAsInferShapeFunctor,
PD_INFER_META(phi::ExpandAsInferMeta));
REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker,
ops::ExpandAsV2GradOpMaker<paddle::framework::OpDesc>,
ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>);
ops::ExpandAsV2GradOpMaker<paddle::imperative::OpBase>,
ExpandAsInferShapeFunctor);
REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp,
ops::ExpandAsV2GradNoNeedBufVarsInferer);
@@ -17,7 +17,9 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -26,27 +28,6 @@ class KronOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kron");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "kron");
auto dim_x = ctx->GetInputDim("X");
auto dim_y = ctx->GetInputDim("Y");
auto rank_x = dim_x.size();
auto rank_y = dim_y.size();
auto rank = (rank_x > rank_y) ? rank_x : rank_y;
std::vector<int64_t> dim_out;
dim_out.reserve(rank);
for (int i = 0; i < rank; i++) {
int64_t dim_xi = (i < rank - rank_x) ? 1 : dim_x.at(i - (rank - rank_x));
int64_t dim_yi = (i < rank - rank_y) ? 1 : dim_y.at(i - (rank - rank_y));
dim_out.push_back(dim_xi == -1 || dim_yi == -1 ? -1 : dim_xi * dim_yi);
}
ctx->SetOutputDim("Out", phi::make_ddim(dim_out));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -173,7 +154,10 @@ class KronGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(kron, KronInferShapeFunctor,
PD_INFER_META(phi::KronInferMeta));
REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker,
ops::KronGradOpMaker<paddle::framework::OpDesc>,
ops::KronGradOpMaker<paddle::imperative::OpBase>);
ops::KronGradOpMaker<paddle::imperative::OpBase>,
KronInferShapeFunctor);
REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -21,60 +23,6 @@ namespace operators {
class SearchSortedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
static bool SearchsortedDimsMatchedBeforeLastDim(
const framework::DDim& sequences_dims,
const framework::DDim& values_dims) {
if (sequences_dims.size() != values_dims.size()) {
return false;
}
const auto& sequences_dims_size = sequences_dims.size();
for (int64_t dim = 0; dim < sequences_dims_size - 1; ++dim) {
if (sequences_dims[dim] != values_dims[dim]) {
return false;
}
}
return true;
}
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("SortedSequence"), "Input", "SortedSequence",
"searchsorted");
OP_INOUT_CHECK(ctx->HasInput("Values"), "Input", "Values", "searchsorted");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "searchsorted");
auto sequences_dims = ctx->GetInputDim("SortedSequence");
auto values_dims = ctx->GetInputDim("Values");
auto out_int32 = ctx->Attrs().Get<bool>("out_int32");
if (sequences_dims.size() != 1) {
PADDLE_ENFORCE_EQ(
SearchsortedDimsMatchedBeforeLastDim(sequences_dims, values_dims),
true,
platform::errors::Unavailable(
"The dimensions of sorted_sequence tensor ( %s ) and values "
"tensor ( %s ) can not match. Because the input sorted_sequence "
"tensor must be 1 dimension or the first N-1 dimensions of "
"sorted_sequence tensor and input values tensor must match. "
"Please input appropriate sorted_sequence and values again! ",
sequences_dims, values_dims));
}
if (out_int32) {
PADDLE_ENFORCE_LT(
sequences_dims[sequences_dims.size() - 1],
std::numeric_limits<int>::max(),
platform::errors::Unavailable(
"The size of sorted_sequence %d exceed the maximum limit d%. "
"Because the size of sorted_sequence should be less than the "
"output maximum value for int32 bit. Please set appropriate "
"sorted_sequence to meet this requirement! ",
sequences_dims[sequences_dims.size() - 1],
std::numeric_limits<int>::max()));
}
ctx->SetOutputDim("Out", values_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
@@ -115,4 +63,7 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker);
DECLARE_INFER_SHAPE_FUNCTOR(searchsorted, SearchsortedInferShapeFunctor,
PD_INFER_META(phi::SearchsortedInferMeta));
REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker,
SearchsortedInferShapeFunctor);
@@ -14,7 +14,9 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -23,56 +25,6 @@ class TopkV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "topk_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "topk_v2");
OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "topk_v2");
auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size();
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_EQ(
(axis < dim_size) && (axis >= (-1 * dim_size)), true,
paddle::platform::errors::InvalidArgument(
"the axis of topk must be [-%d, %d), but you set axis is %d",
dim_size, dim_size, axis));
if (axis < 0) axis += dim_size;
int k;
auto k_is_tensor = ctx->HasInput("K");
if (k_is_tensor) {
k = -1;
} else {
k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_EQ(k >= 1, true,
paddle::platform::errors::InvalidArgument(
"the attribute of k in the topk must >= 1 or be a "
"Tensor, but received %d .",
k));
}
PADDLE_ENFORCE_GE(input_dims.size(), 1,
paddle::platform::errors::InvalidArgument(
"input of topk must have >= 1d shape"));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
input_dims[axis], k,
paddle::platform::errors::InvalidArgument(
"input of topk op must have >= %d columns in axis of %d", k,
axis));
}
framework::DDim dims = input_dims;
dims[axis] = k;
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -169,8 +121,11 @@ class TopkV2GradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(top_k_v2, TopKInferShapeFunctor,
PD_INFER_META(phi::TopKInferMeta));
REGISTER_OPERATOR(top_k_v2, ops::TopkV2Op, ops::TopkV2OpMaker,
ops::TopkV2GradOpMaker<paddle::framework::OpDesc>,
ops::TopkV2GradOpMaker<paddle::imperative::OpBase>);
ops::TopkV2GradOpMaker<paddle::imperative::OpBase>,
TopKInferShapeFunctor);
REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad);
@@ -476,6 +476,33 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
out->share_lod(x);
}
void ExpandAsInferMeta(const MetaTensor& x,
paddle::optional<const MetaTensor&> y,
const std::vector<int>& target_shape,
MetaTensor* out) {
#define MAX_RANK_SUPPORTED 6
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(
target_shape.size(),
static_cast<size_t>(x_dims.size()),
phi::errors::InvalidArgument(
"The rank of target_shape must be greater than or equal "
"to the rank of Input(X). But received Input(X): input "
"rank %u; received target_shape: rank %u.",
x_dims.size(),
target_shape.size()));
PADDLE_ENFORCE_LE(target_shape.size(),
MAX_RANK_SUPPORTED,
phi::errors::InvalidArgument(
"The rank of target_shape must be less than or equal "
"to %d. But received: rank %u.",
MAX_RANK_SUPPORTED,
target_shape.size()));
out->set_dims(phi::make_ddim(target_shape));
out->set_dtype(x.dtype());
#undef MAX_RANK_SUPPORTED
}
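For reference, the rule above can be exercised in isolation; a minimal standalone sketch (hypothetical names, not Paddle code):

// Standalone sketch of the expand_as_v2 shape rule: the output shape is
// target_shape itself, after checking rank(x) <= rank(target_shape) <= 6.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ExpandAsOutShape(const std::vector<int64_t>& x_dims,
                                      const std::vector<int>& target_shape) {
  constexpr size_t kMaxRank = 6;  // mirrors MAX_RANK_SUPPORTED
  assert(target_shape.size() >= x_dims.size());  // rank may not shrink
  assert(target_shape.size() <= kMaxRank);
  return std::vector<int64_t>(target_shape.begin(), target_shape.end());
}

// e.g. x of shape {3, 1} with target_shape {2, 3, 4} -> out shape {2, 3, 4}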
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
@@ -728,6 +755,24 @@ void IndexSelectInferMeta(const MetaTensor& x,
output->share_lod(x);
}
void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto dim_x = x.dims();
auto dim_y = y.dims();
auto rank_x = dim_x.size();
auto rank_y = dim_y.size();
auto rank = (rank_x > rank_y) ? rank_x : rank_y;
std::vector<int64_t> dim_out;
dim_out.reserve(rank);
for (int i = 0; i < rank; i++) {
int64_t dim_xi = (i < rank - rank_x) ? 1 : dim_x.at(i - (rank - rank_x));
int64_t dim_yi = (i < rank - rank_y) ? 1 : dim_y.at(i - (rank - rank_y));
dim_out.push_back(dim_xi == -1 || dim_yi == -1 ? -1 : dim_xi * dim_yi);
}
out->set_dims(phi::make_ddim(dim_out));
out->set_dtype(x.dtype());
}
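The rule above right-aligns the two input shapes and multiplies them element-wise; a standalone sketch with examples (hypothetical names, not Paddle code):

// Standalone sketch of the kron output-shape rule.
#include <cstdint>
#include <vector>

std::vector<int64_t> KronOutShape(const std::vector<int64_t>& dx,
                                  const std::vector<int64_t>& dy) {
  const int rx = static_cast<int>(dx.size());
  const int ry = static_cast<int>(dy.size());
  const int rank = rx > ry ? rx : ry;
  std::vector<int64_t> out;
  out.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    // Right-align both shapes; missing leading dims behave as 1.
    const int64_t xi = (i < rank - rx) ? 1 : dx[i - (rank - rx)];
    const int64_t yi = (i < rank - ry) ? 1 : dy[i - (rank - ry)];
    // A dim of -1 (unknown until runtime) propagates to the output.
    out.push_back((xi == -1 || yi == -1) ? -1 : xi * yi);
  }
  return out;
}

// e.g. KronOutShape({2, 3}, {4, 5}) -> {8, 15}
//      KronOutShape({3}, {2, 2})    -> {2, 6}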
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
@@ -873,6 +918,60 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
out->share_lod(x);
}
void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
const MetaTensor& value,
bool out_int32,
bool right,
MetaTensor* out) {
auto sequences_dims = sorted_sequence.dims();
auto values_dims = value.dims();
bool flag = true;
if (sequences_dims.size() != values_dims.size()) {
flag = false;
}
const auto& sequences_dims_size = sequences_dims.size();
for (int64_t dim = 0; dim < sequences_dims_size - 1; ++dim) {
if (sequences_dims[dim] != values_dims[dim]) {
flag = false;
break;
}
}
if (sequences_dims.size() != 1) {
PADDLE_ENFORCE_EQ(
flag,
true,
phi::errors::Unavailable(
"The dimensions of sorted_sequence tensor ( %s ) and values "
"tensor ( %s ) can not match. Because the input sorted_sequence "
"tensor must be 1 dimension or the first N-1 dimensions of "
"sorted_sequence tensor and input values tensor must match. "
"Please input appropriate sorted_sequence and values again! ",
sequences_dims,
values_dims));
}
if (out_int32) {
PADDLE_ENFORCE_LT(
sequences_dims[sequences_dims.size() - 1],
std::numeric_limits<int>::max(),
phi::errors::Unavailable(
"The size of sorted_sequence %d exceed the maximum limit d%. "
"Because the size of sorted_sequence should be less than the "
"output maximum value for int32 bit. Please set appropriate "
"sorted_sequence to meet this requirement! ",
sequences_dims[sequences_dims.size() - 1],
std::numeric_limits<int>::max()));
}
out->set_dims(values_dims);
if (out_int32) {
out->set_dtype(DataType::INT32);
} else {
out->set_dtype(DataType::INT64);
}
}
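In short: unless sorted_sequence is 1-D, every dimension except the last must match between sorted_sequence and values; the output takes the shape of values, and its dtype is int32 only when out_int32 is true (in which case the searched dimension must fit in int32). A standalone sketch of the matching check (hypothetical names, not Paddle code):

// Standalone sketch of the searchsorted leading-dimension check.
#include <cstdint>
#include <vector>

bool LeadingDimsMatch(const std::vector<int64_t>& seq,
                      const std::vector<int64_t>& val) {
  if (seq.size() != val.size()) return false;
  for (size_t d = 0; d + 1 < seq.size(); ++d) {
    if (seq[d] != val[d]) return false;  // mismatch before the last dim
  }
  return true;
}

// e.g. seq {4, 8} vs val {4, 3} -> true (last dims may differ)
//      seq {4, 8} vs val {5, 3} -> false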
void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
@@ -90,6 +90,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta,
int axis,
MetaTensor* out);
void ExpandAsInferMeta(const MetaTensor& x,
paddle::optional<const MetaTensor&> y,
const std::vector<int>& target_shape,
MetaTensor* out);
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
@@ -125,6 +130,8 @@ void IndexSelectInferMeta(const MetaTensor& x,
int dim,
MetaTensor* output);
void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
@@ -139,6 +146,12 @@ void MatmulInferMeta(const MetaTensor& x,
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
const MetaTensor& value,
bool out_int32,
bool right,
MetaTensor* out);
void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
@@ -1384,6 +1384,55 @@ void TileInferMeta(const MetaTensor& x,
}
}
void TopKInferMeta(const MetaTensor& x,
const Scalar& k_scalar,
int axis,
bool largest,
bool sorted,
MetaTensor* out,
MetaTensor* indices,
MetaConfig config) {
auto input_dims = x.dims();
const int& dim_size = input_dims.size();
PADDLE_ENFORCE_EQ(
(axis < dim_size) && (axis >= (-1 * dim_size)),
true,
phi::errors::InvalidArgument(
"the axis of topk must be [-%d, %d), but you set axis is %d",
dim_size,
dim_size,
axis));
if (axis < 0) axis += dim_size;
int k = k_scalar.to<int>();
if (k_scalar.FromTensor()) {
k = -1;
} else {
PADDLE_ENFORCE_EQ(k >= 1,
true,
phi::errors::InvalidArgument(
"the attribute of k in the topk must >= 1 or be a "
"Tensor, but received %d .",
k));
}
PADDLE_ENFORCE_GE(
input_dims.size(),
1,
phi::errors::InvalidArgument("the input of topk must be at least 1-D"));
phi::DDim dims = input_dims;
dims[axis] = k;
out->set_dims(dims);
out->share_lod(x);
out->set_dtype(x.dtype());
indices->set_dims(dims);
indices->share_lod(x);
indices->set_dtype(DataType::INT64);
}
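The resulting shape rule is simply dims[axis] = k after normalizing a negative axis, with k = -1 standing in for a k that arrives as a runtime Tensor; a standalone sketch (hypothetical names, not Paddle code):

// Standalone sketch of the top_k_v2 output-shape rule.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> TopKOutShape(std::vector<int64_t> dims, int axis, int k) {
  const int rank = static_cast<int>(dims.size());
  assert(rank >= 1);
  assert(axis >= -rank && axis < rank);
  if (axis < 0) axis += rank;  // normalize a negative axis
  dims[axis] = k;  // k == -1 means "unknown until the K tensor is read"
  return dims;     // both Out and Indices take this shape
}

// e.g. TopKOutShape({4, 10}, -1, 3) -> {4, 3}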
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out) {
int dim1 = axis1;
@@ -215,6 +215,15 @@ void TileInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void TopKInferMeta(const MetaTensor& x,
const Scalar& k_scalar,
int axis,
bool largest,
bool sorted,
MetaTensor* out,
MetaTensor* indices,
MetaConfig config = MetaConfig());
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out);