未验证 提交 09039636 编写于 作者: C chentianyu03 提交者: GitHub

[Phi]rm reduce infershape (#39820)

* modify infershape utils and rm reduce infershape

* merge develop

* fix infermete bug

* add IsForInferShape func in ArgumentMappingContext

* add reduce_mean infermeta

* modify annotation

* add default dims
上级 197da15a
......@@ -88,6 +88,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS;
}
bool IsForInferShape() const override { return true; }
private:
const InferShapeContext& ctx_;
};
......@@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor {
}
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
return phi::make_ddim(var->GetShape());
return var->GetShape().empty() ? phi::make_ddim({0UL})
: phi::make_ddim(var->GetShape());
}
}
......
......@@ -489,6 +489,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return ctx_.OutputVar(name)->IsType<phi::SelectedRows>();
}
bool IsForInferShape() const override { return false; }
private:
const ExecutionContext& ctx_;
};
......
......@@ -18,6 +18,10 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -92,9 +96,13 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};
DELCARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor,
PT_INFER_META(phi::MeanRawInferMeta));
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>);
ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>,
ReduceMeanInferShapeFunctor);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradDescMaker,
ops::ReduceMeanDoubleGradOpBaseMaker,
......
......@@ -16,6 +16,10 @@
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -98,10 +102,14 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_sum"; }
};
DELCARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor,
PT_INFER_META(phi::ReduceInferMetaBase));
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumVarTypeInference,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>,
ReduceSumInferShapeFunctor);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
......
......@@ -46,6 +46,8 @@ class ProtoArgumentMappingContext : public phi::ArgumentMappingContext {
bool IsDenseTensorOutput(const std::string& name) const override;
bool IsSelectedRowsOutput(const std::string& name) const override;
bool IsForInferShape() const override { return false; }
private:
mlir::Operation* op_;
const std::unordered_map<std::string, uint8_t>& input_map_;
......
......@@ -91,6 +91,10 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
// use this function to mark it comes from InferShapeArgumentMappingContext
// and will be used in infershape
virtual bool IsForInferShape() const = 0;
};
} // namespace phi
......@@ -375,7 +375,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config);
}
/* Why not use ReduceInferMeta directly?
/* Why not use ReduceInferMetaBase directly?
Because we need make InferMetaFunction's args follow the design of api.yaml
*/
void SumInferMeta(const MetaTensor& x,
......@@ -383,22 +383,53 @@ void SumInferMeta(const MetaTensor& x,
DataType dtype,
bool keep_dim,
MetaTensor* out) {
ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
bool reduce_all = false;
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out);
}
void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out) {
bool reduce_all = true;
std::set<int64_t> dims_set(axis.begin(), axis.end());
auto x_rank = x.dims().size();
std::vector<int64_t> formated_axis = axis;
for (size_t i = 0; i < axis.size(); ++i) {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
if (axis[i] < 0) {
formated_axis[i] = axis[i] + x_rank;
}
}
bool full_dim = true;
std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
for (int64_t i = 0; i < x.dims().size(); ++i) {
if (dims_set.find(i) == dims_set.end()) {
reduce_all = false;
full_dim = false;
break;
}
}
reduce_all = reduce_all || full_dim;
std::vector<int64_t> out_dim_vector;
if (keep_dim) {
......@@ -441,11 +472,20 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}
void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out) {
ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
void MeanRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out) {
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out);
}
void MeanInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out) {
bool reduce_all = false;
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out);
}
void TransferLayoutInferMeta(const MetaTensor& x,
......
......@@ -86,13 +86,20 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out);
void MeanRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out);
void MeanInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out);
void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
......
......@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
bool keep_dim) {
auto dense_out = phi::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out);
ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out);
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
return dense_out;
}
......
......@@ -17,28 +17,36 @@ limitations under the License. */
namespace phi {
KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) {
return KernelSignature(
"sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "sum_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the "sum_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature("sum_raw",
{"X"},
{"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
}
return KernelSignature("sum_raw",
{"X"},
{"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
return KernelSignature(
"sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) {
return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "mean_raw" KernelSignature.
// And the InferMeta function(i.e. MeanRawInferMeta) is accordance with the
// "mean_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
......
......@@ -80,6 +80,8 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return selected_rows_outputs.count(name) > 0;
}
bool IsForInferShape() const override { return false; }
private:
const std::unordered_set<std::string> dense_tensor_inputs;
const std::unordered_set<std::string> selected_rows_inputs;
......
......@@ -124,7 +124,7 @@
args : (Tensor x, int64_t[] axis={}, bool keep_dim=false)
output : Tensor
infer_meta :
func : ReduceInferMeta
func : MeanInferMeta
kernel :
func : mean
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册