未验证 提交 8031a4dc 编写于 作者: C chentianyu03 提交者: GitHub

[Phi] move Reduce max kernel into phi (#40225)

* add reduce_max kernel

* add reduce max kernel

* update reduce max Argumentmapping

* remove reduce_max kernel

* remove reduce_max kernel

* add reduce max infermeta

* rename reduce infermeta
上级 fb4215b2
...@@ -14,15 +14,28 @@ ...@@ -14,15 +14,28 @@
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_max); #include "paddle/fluid/framework/infershape_utils.h"
REGISTER_OP_CPU_KERNEL( #include "paddle/phi/core/infermeta_utils.h"
reduce_max, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float, #include "paddle/phi/infermeta/unary.h"
ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double, namespace ops = paddle::operators;
ops::MaxFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::MaxFunctor>, class ReduceMaxOpMaker : public ops::ReduceOpMaker {
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t, protected:
ops::MaxFunctor>); virtual std::string GetName() const { return "reduce_max"; }
virtual std::string GetOpType() const { return "Reduce reduce_max"; }
};
DECLARE_INFER_SHAPE_FUNCTOR(reduce_max, ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::ReduceInferMetaBase));
REGISTER_OPERATOR(
reduce_max, ops::ReduceOp, ReduceMaxOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ReduceMaxInferShapeFunctor);
REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp)
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, reduce_max_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MaxOrMinGradFunctor>, float, ops::MaxOrMinGradFunctor>,
......
...@@ -97,7 +97,7 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker { ...@@ -97,7 +97,7 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
}; };
DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor, DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean, ReduceMeanInferShapeFunctor,
PD_INFER_META(phi::MeanRawInferMeta)); PD_INFER_META(phi::ReduceInferMetaBase));
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__, REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>, ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>,
......
...@@ -103,7 +103,7 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker { ...@@ -103,7 +103,7 @@ class ReduceSumOpMaker : public ops::ReduceOpMaker {
}; };
DECLARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor, DECLARE_INFER_SHAPE_FUNCTOR(reduce_sum, ReduceSumInferShapeFunctor,
PD_INFER_META(phi::ReduceInferMetaBase)); PD_INFER_META(phi::SumRawInferMeta));
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker, REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
ops::ReduceSumVarTypeInference, ops::ReduceSumVarTypeInference,
......
...@@ -47,6 +47,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag", ...@@ -47,6 +47,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"matmul_grad", "matmul_grad",
"matmul_grad_grad", "matmul_grad_grad",
"mean", "mean",
"max",
"reshape", "reshape",
"reshape_grad", "reshape_grad",
"expand", "expand",
......
...@@ -406,7 +406,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -406,7 +406,7 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config); ReshapeInferMeta(x, shape, out, config);
} }
/* Why not use ReduceInferMetaBase directly? /* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of api.yaml Because we need make InferMetaFunction's args follow the design of api.yaml
*/ */
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
...@@ -415,15 +415,13 @@ void SumInferMeta(const MetaTensor& x, ...@@ -415,15 +415,13 @@ void SumInferMeta(const MetaTensor& x,
bool keep_dim, bool keep_dim,
MetaTensor* out) { MetaTensor* out) {
bool reduce_all = false; bool reduce_all = false;
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, dtype, out); SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out);
} }
void ReduceInferMetaBase(const MetaTensor& x, DDim ReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all) {
DataType dtype,
MetaTensor* out) {
auto x_rank = x.dims().size(); auto x_rank = x.dims().size();
std::vector<int64_t> formated_axis = axis; std::vector<int64_t> formated_axis = axis;
...@@ -486,6 +484,17 @@ void ReduceInferMetaBase(const MetaTensor& x, ...@@ -486,6 +484,17 @@ void ReduceInferMetaBase(const MetaTensor& x,
} }
DDim out_dim = phi::make_ddim(out_dim_vector); DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}
void SumRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out) {
DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all);
DataType out_dtype; DataType out_dtype;
if (dtype != DataType::UNDEFINED) { if (dtype != DataType::UNDEFINED) {
out_dtype = dtype; out_dtype = dtype;
...@@ -503,20 +512,23 @@ void ReduceInferMetaBase(const MetaTensor& x, ...@@ -503,20 +512,23 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
void MeanRawInferMeta(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
MetaTensor* out) { MetaTensor* out) {
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out); DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all);
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
} }
void MeanInferMeta(const MetaTensor& x, void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
MetaTensor* out) { MetaTensor* out) {
bool reduce_all = false; bool reduce_all = false;
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, DataType::UNDEFINED, out); ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out);
} }
void TransferLayoutInferMeta(const MetaTensor& x, void TransferLayoutInferMeta(const MetaTensor& x,
......
...@@ -94,23 +94,23 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -94,23 +94,23 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void SumRawInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all,
DataType dtype,
MetaTensor* out);
void ReduceInferMetaBase(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
bool reduce_all, bool reduce_all,
DataType dtype,
MetaTensor* out); MetaTensor* out);
void MeanRawInferMeta(const MetaTensor& x, void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
bool reduce_all, MetaTensor* out);
MetaTensor* out);
void MeanInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out);
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
namespace phi {
template <typename T, typename Context>
void MaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MaxFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
max_raw, CPU, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
...@@ -41,5 +41,13 @@ struct ProdFunctor { ...@@ -41,5 +41,13 @@ struct ProdFunctor {
} }
}; };
//////// Max Functor ///////
struct MaxFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->maximum(dim);
}
};
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -11,13 +11,27 @@ ...@@ -11,13 +11,27 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
// reduce_max #include "paddle/phi/kernels/reduce_max_kernel.h"
REGISTER_OP_CUDA_KERNEL(
reduce_max, #include "paddle/phi/core/kernel_registry.h"
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>, #include "paddle/phi/kernels/gpu/reduce.h"
ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>, namespace phi {
ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);
template <typename T, typename Context>
void MaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MaxFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
max_raw, GPU, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
...@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx, ...@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
bool keep_dim) { bool keep_dim) {
DenseTensor dense_out; DenseTensor dense_out;
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out); SumRawInferMeta(x, axis, keep_dim, false, x.dtype(), &meta_out);
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out); MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
return dense_out; return dense_out;
} }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
max, CPU, ALL_LAYOUT, phi::MaxKernel, float, double, int, int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
max, GPU, ALL_LAYOUT, phi::MaxKernel, float, double, int, int64_t) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
template <typename T, typename Context>
void MaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
template <typename T, typename Context>
void MaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out);
} // namespace phi
...@@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all")); bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "sum_raw" KernelSignature. // InferShape, so we must return the "sum_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with // And the InferMeta function(i.e. SumRawInferMeta) is accordance with
// the "sum_raw" KernelSignature // the "sum_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) { if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature("sum_raw", return KernelSignature("sum_raw",
...@@ -40,7 +40,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -40,7 +40,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all")); bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in // When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "mean_raw" KernelSignature. // InferShape, so we must return the "mean_raw" KernelSignature.
// And the InferMeta function(i.e. MeanRawInferMeta) is accordance with the // And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the
// "mean_raw" KernelSignature // "mean_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) { if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature( return KernelSignature(
...@@ -56,11 +57,30 @@ KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -56,11 +57,30 @@ KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) {
"reduce_prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); "reduce_prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
} }
KernelSignature ReduceMaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "max_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the
// "max_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"max_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature("max", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum); PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean); PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);
PD_REGISTER_BASE_KERNEL_NAME(reduce_max, max);
PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_max, phi::ReduceMaxOpArgumentMapping);
...@@ -124,7 +124,7 @@ ...@@ -124,7 +124,7 @@
args : (Tensor x, int64[] axis={}, bool keep_dim=false) args : (Tensor x, int64[] axis={}, bool keep_dim=false)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : MeanInferMeta func : ReduceInferMeta
kernel : kernel :
func : mean func : mean
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册