未验证 提交 080024f0 编写于 作者: Z zyfncg 提交者: GitHub

refactor unary infermeta (#40365)

上级 ec09ef26
此差异已折叠。
...@@ -32,32 +32,20 @@ class MetaConfig; ...@@ -32,32 +32,20 @@ class MetaConfig;
// Because functions in this file not only can infer shape, but also need // Because functions in this file not only can infer shape, but also need
// infer lod or other useful data. // infer lod or other useful data.
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ArgsortInferMeta(const MetaTensor& input, void ArgsortInferMeta(const MetaTensor& input,
int axis, int axis,
bool descending, bool descending,
MetaTensor* output, MetaTensor* output,
MetaTensor* indices); MetaTensor* indices);
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1]
void UnchangedInferMetaCheckAxis(const MetaTensor& x,
int axis,
MetaTensor* out);
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
MetaTensor* out);
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
int axis,
MetaTensor* out);
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);
...@@ -76,6 +64,30 @@ void CumsumInferMeta(const MetaTensor& x, ...@@ -76,6 +64,30 @@ void CumsumInferMeta(const MetaTensor& x,
bool reverse, bool reverse,
MetaTensor* out); MetaTensor* out);
void DiagInferMeta(const MetaTensor& x,
int offset,
float padding_value,
MetaTensor* out);
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
MetaTensor* out_v);
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
MetaTensor* out);
void GumbelSoftmaxInferMeta(const MetaTensor& x,
float temperature,
bool hard,
int axis,
MetaTensor* out);
void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out); void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out);
void InferMetaFromVecValue(const MetaTensor& x, void InferMetaFromVecValue(const MetaTensor& x,
...@@ -84,11 +96,37 @@ void InferMetaFromVecValue(const MetaTensor& x, ...@@ -84,11 +96,37 @@ void InferMetaFromVecValue(const MetaTensor& x,
void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out); void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out);
void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out);
void MultinomialInferMeta(const MetaTensor& x, void MultinomialInferMeta(const MetaTensor& x,
int num_samples, int num_samples,
bool replacement, bool replacement,
MetaTensor* out); MetaTensor* out);
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
MetaTensor* out,
MetaConfig config = MetaConfig());
void PixelShuffleInferMeta(const MetaTensor& x,
int upscale_factor,
const std::string& data_format,
MetaTensor* out);
void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out);
void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out);
void ReshapeInferMeta(const MetaTensor& x, void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
MetaTensor* out, MetaTensor* out,
...@@ -100,28 +138,23 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -100,28 +138,23 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void TileInferMeta(const MetaTensor& x, void ShardIndexInferMeta(const MetaTensor& in,
const ScalarArray& repeat_times, int index_num,
MetaTensor* out, int nshards,
MetaConfig config = MetaConfig()); int shard_id,
int ignore_value,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SumRawInferMeta(const MetaTensor& x, void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out);
void ReduceInferMetaBase(const MetaTensor& x, void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out);
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x, void SplitInferMeta(const MetaTensor& x_meta,
const std::vector<int64_t>& axis, const ScalarArray& num_or_sections,
bool keep_dim, const Scalar& axis,
MetaTensor* out); std::vector<MetaTensor*> out,
MetaConfig config = MetaConfig());
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
...@@ -129,21 +162,39 @@ void SumInferMeta(const MetaTensor& x, ...@@ -129,21 +162,39 @@ void SumInferMeta(const MetaTensor& x,
bool keep_dim, bool keep_dim,
MetaTensor* out); MetaTensor* out);
void SumRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out);
void TileInferMeta(const MetaTensor& x,
const ScalarArray& repeat_times,
MetaTensor* out,
MetaConfig config = MetaConfig());
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out);
void TransferLayoutInferMeta(const MetaTensor& x, void TransferLayoutInferMeta(const MetaTensor& x,
DataLayout layout, DataLayout layout,
MetaTensor* out); MetaTensor* out);
void SplitInferMeta(const MetaTensor& x_meta, void TransposeInferMeta(const MetaTensor& x,
const ScalarArray& num_or_sections, const std::vector<int>& axis,
const Scalar& axis, MetaTensor* out);
std::vector<MetaTensor*> out,
MetaConfig config = MetaConfig());
void UnbindInferMeta(const MetaTensor& x, void UnbindInferMeta(const MetaTensor& x,
int axis, int axis,
std::vector<MetaTensor>* outs); std::vector<MetaTensor>* outs);
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out); void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1]
void UnchangedInferMetaCheckAxis(const MetaTensor& x,
int axis,
MetaTensor* out);
void UnfoldInferMeta(const MetaTensor& x, void UnfoldInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_sizes, const std::vector<int>& kernel_sizes,
...@@ -153,56 +204,6 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -153,56 +204,6 @@ void UnfoldInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void DiagInferMeta(const MetaTensor& x,
int offset,
float padding_value,
MetaTensor* out);
void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
MetaTensor* out,
MetaConfig config = MetaConfig());
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
void PixelShuffleInferMeta(const MetaTensor& x,
int upscale_factor,
const std::string& data_format,
MetaTensor* out);
void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out);
void TransposeInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
MetaTensor* out_v);
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
void ShardIndexInferMeta(const MetaTensor& in,
int index_num,
int nshards,
int shard_id,
int ignore_value,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SoftmaxInferMeta(const MetaTensor& x, int axis, MetaTensor* out);
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册