未验证 提交 09039636 编写于 作者: C chentianyu03 提交者: GitHub

[Phi]rm reduce infershape (#39820)

* modify infershape utils and rm reduce infershape

* merge develop

* fix infermete bug

* add IsForInferShape func in ArgumentMappingContext

* add reduce_mean infermeta

* modify annotation

* add default dims
上级 197da15a
...@@ -88,6 +88,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -88,6 +88,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS; return var_types[0] == proto::VarType::SELECTED_ROWS;
} }
bool IsForInferShape() const override { return true; }
private: private:
const InferShapeContext& ctx_; const InferShapeContext& ctx_;
}; };
...@@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor { ...@@ -127,7 +129,9 @@ class CompatMetaTensor : public phi::MetaTensor {
} }
} else { } else {
auto* var = BOOST_GET_CONST(VarDesc*, var_); auto* var = BOOST_GET_CONST(VarDesc*, var_);
return phi::make_ddim(var->GetShape());
return var->GetShape().empty() ? phi::make_ddim({0UL})
: phi::make_ddim(var->GetShape());
} }
} }
......
...@@ -489,6 +489,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -489,6 +489,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return ctx_.OutputVar(name)->IsType<phi::SelectedRows>(); return ctx_.OutputVar(name)->IsType<phi::SelectedRows>();
} }
bool IsForInferShape() const override { return false; }
private: private:
const ExecutionContext& ctx_; const ExecutionContext& ctx_;
}; };
......
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -92,9 +96,13 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker { ...@@ -92,9 +96,13 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_mean"; } virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
}; };
DELCARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor,
PT_INFER_META(phi::MeanRawInferMeta));
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__, REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>, ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>); ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>,
ReduceMeanInferShapeFunctor);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp, REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradDescMaker, ops::ReduceMeanDoubleGradDescMaker,
ops::ReduceMeanDoubleGradOpBaseMaker, ops::ReduceMeanDoubleGradOpBaseMaker,
......
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OpDesc; class OpDesc;
...@@ -98,10 +102,14 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker { ...@@ -98,10 +102,14 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_sum"; } virtual std::string GetOpType() const { return "Reduce reduce_sum"; }
}; };
DELCARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor,
PT_INFER_META(phi::ReduceInferMetaBase));
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumVarTypeInference, ops::ReduceSumVarTypeInference,
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>, ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>); ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>,
ReduceSumInferShapeFunctor);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp, REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>, ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>, ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
......
...@@ -46,6 +46,8 @@ class ProtoArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -46,6 +46,8 @@ class ProtoArgumentMappingContext : public phi::ArgumentMappingContext {
bool IsDenseTensorOutput(const std::string& name) const override; bool IsDenseTensorOutput(const std::string& name) const override;
bool IsSelectedRowsOutput(const std::string& name) const override; bool IsSelectedRowsOutput(const std::string& name) const override;
bool IsForInferShape() const override { return false; }
private: private:
mlir::Operation* op_; mlir::Operation* op_;
const std::unordered_map<std::string, uint8_t>& input_map_; const std::unordered_map<std::string, uint8_t>& input_map_;
......
...@@ -91,6 +91,10 @@ class ArgumentMappingContext { ...@@ -91,6 +91,10 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorOutput(const std::string& name) const = 0; virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
// use this function to mark it comes from InferShapeArgumentMappingContext
// and will be used in infershape
virtual bool IsForInferShape() const = 0;
}; };
} // namespace phi } // namespace phi
...@@ -375,7 +375,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -375,7 +375,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config); ReshapeInferMeta(x, shape, out, config);
} }
/* Why not use ReduceInferMeta directly? /* Why not use ReduceInferMetaBase directly?
Because we need make InferMetaFunction's args follow the design of api.yaml Because we need make InferMetaFunction's args follow the design of api.yaml
*/ */
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
...@@ -383,22 +383,53 @@ void SumInferMeta(const MetaTensor& x, ...@@ -383,22 +383,53 @@ void SumInferMeta(const MetaTensor& x,
DataType dtype, DataType dtype,
bool keep_dim, bool keep_dim,
MetaTensor* out) { MetaTensor* out) {
ReduceInferMetaBase(x, axis, keep_dim, dtype, out); bool reduce_all = false;
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out);
} }
void ReduceInferMetaBase(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
bool reduce_all,
DataType dtype, DataType dtype,
MetaTensor* out) { MetaTensor* out) {
bool reduce_all = true; auto x_rank = x.dims().size();
std::set<int64_t> dims_set(axis.begin(), axis.end());
std::vector<int64_t> formated_axis = axis;
for (size_t i = 0; i < axis.size(); ++i) {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [-dimension(X), dimension(X)] "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
if (axis[i] < 0) {
formated_axis[i] = axis[i] + x_rank;
}
}
bool full_dim = true;
std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
for (int64_t i = 0; i < x.dims().size(); ++i) { for (int64_t i = 0; i < x.dims().size(); ++i) {
if (dims_set.find(i) == dims_set.end()) { if (dims_set.find(i) == dims_set.end()) {
reduce_all = false; full_dim = false;
break; break;
} }
} }
reduce_all = reduce_all || full_dim;
std::vector<int64_t> out_dim_vector; std::vector<int64_t> out_dim_vector;
if (keep_dim) { if (keep_dim) {
...@@ -441,11 +472,20 @@ void ReduceInferMetaBase(const MetaTensor& x, ...@@ -441,11 +472,20 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
void ReduceInferMeta(const MetaTensor& x, void MeanRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
MetaTensor* out) { bool reduce_all,
ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out); MetaTensor* out) {
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out);
}
void MeanInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out) {
bool reduce_all = false;
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out);
} }
void TransferLayoutInferMeta(const MetaTensor& x, void TransferLayoutInferMeta(const MetaTensor& x,
......
...@@ -86,13 +86,20 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -86,13 +86,20 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
void ReduceInferMetaBase(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
bool reduce_all,
DataType dtype, DataType dtype,
MetaTensor* out); MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x, void MeanRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
MetaTensor* out); bool reduce_all,
MetaTensor* out);
void MeanInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out);
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
......
...@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx, ...@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
bool keep_dim) { bool keep_dim) {
auto dense_out = phi::Empty<T, Context>(dev_ctx); auto dense_out = phi::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out); ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out);
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out); MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -17,28 +17,36 @@ limitations under the License. */ ...@@ -17,28 +17,36 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
if (ctx.IsDenseTensorInput("X")) { if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) { bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
return KernelSignature( // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
"sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"}); // InferShape, so we must return the "sum_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the "sum_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature("sum_raw",
{"X"},
{"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
} }
return KernelSignature("sum_raw", return KernelSignature(
{"X"}, "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
{"dim", "keep_dim", "reduce_all", "out_dtype"},
{"Out"});
} }
return KernelSignature("unregistered", {}, {}, {}); return KernelSignature("unregistered", {}, {}, {});
} }
KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
if (ctx.IsDenseTensorInput("X")) { if (ctx.IsDenseTensorInput("X")) {
if (!reduce_all) { bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"}); // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "mean_raw" KernelSignature.
// And the InferMeta function(i.e. MeanRawInferMeta) is accordance with the
// "mean_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
} }
return KernelSignature( return KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, {"Out"});
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
} }
return KernelSignature("unregistered", {}, {}, {}); return KernelSignature("unregistered", {}, {}, {});
} }
......
...@@ -80,6 +80,8 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -80,6 +80,8 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return selected_rows_outputs.count(name) > 0; return selected_rows_outputs.count(name) > 0;
} }
bool IsForInferShape() const override { return false; }
private: private:
const std::unordered_set<std::string> dense_tensor_inputs; const std::unordered_set<std::string> dense_tensor_inputs;
const std::unordered_set<std::string> selected_rows_inputs; const std::unordered_set<std::string> selected_rows_inputs;
......
...@@ -124,7 +124,7 @@ ...@@ -124,7 +124,7 @@
args : (Tensor x, int64_t[] axis={}, bool keep_dim=false) args : (Tensor x, int64_t[] axis={}, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ReduceInferMeta func : MeanInferMeta
kernel : kernel :
func : mean func : mean
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册