未验证 提交 c46e661d 编写于 作者: C chentianyu03 提交者: GitHub

[Phi]move reduce_min/any/all kernel (#40374)

* add reduce_min kernel

* remove raw reduce_min kernel

* add reduce min

* add reduce any all impl

* add bool reduce Kernel

* remove raw any/all kernel

* add any all kernel

* rm comment
上级 36db75b4
......@@ -14,6 +14,10 @@
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -28,9 +32,17 @@ class CPUDeviceContext;
} // namespace platform
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(reduce_all, ReduceAllInferShapeFunctor,
PD_INFER_META(phi::ReduceInferMetaBase));
class ReduceAllOpMaker : public ops::ReduceOpMaker {
protected:
virtual std::string GetName() const { return "reduce_all"; }
virtual std::string GetOpType() const { return "Reduce reduce_all"; }
};
// kernel's device type is decided by input tensor place, to be consistent with
// compare and logical ops
REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all, UseInputPlace);
REGISTER_OP_CPU_KERNEL(reduce_all,
ops::BoolReduceKernel<paddle::platform::CPUDeviceContext,
bool, ops::AllFunctor>);
REGISTER_OPERATOR(
reduce_all, ops::ReduceOpUseInputPlace, ReduceAllOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ReduceAllInferShapeFunctor);
......@@ -14,6 +14,9 @@
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -28,9 +31,18 @@ class CPUDeviceContext;
} // namespace platform
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(reduce_any, ReduceAnyInferShapeFunctor,
PD_INFER_META(phi::ReduceInferMetaBase));
class ReduceAnyOpMaker : public ops::ReduceOpMaker {
protected:
virtual std::string GetName() const { return "reduce_any"; }
virtual std::string GetOpType() const { return "Reduce reduce_any"; }
};
// kernel's device type is decided by input tensor place, to be consistent with
// compare and logical ops
REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any, UseInputPlace);
REGISTER_OP_CPU_KERNEL(reduce_any,
ops::BoolReduceKernel<paddle::platform::CPUDeviceContext,
bool, ops::AnyFunctor>);
REGISTER_OPERATOR(
reduce_any, ops::ReduceOpUseInputPlace, ReduceAnyOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ReduceAnyInferShapeFunctor);
......@@ -35,7 +35,7 @@ namespace p = paddle::platform;
using Tensor = paddle::framework::Tensor;
USE_OP(reduce_any);
USE_OP_ITSELF(reduce_any);
USE_OP_DEVICE_KERNEL(reduce_any, NPU);
template <typename T>
......
......@@ -14,15 +14,28 @@
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
REGISTER_REDUCE_OP(reduce_min);
REGISTER_OP_CPU_KERNEL(
reduce_min, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::MinFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MinFunctor>);
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace ops = paddle::operators;
class ReduceMinOpMaker : public ops::ReduceOpMaker {
protected:
virtual std::string GetName() const { return "reduce_min"; }
virtual std::string GetOpType() const { return "Reduce reduce_min"; }
};
DECLARE_INFER_SHAPE_FUNCTOR(reduce_min, ReduceMinInferShapeFunctor,
PD_INFER_META(phi::ReduceInferMetaBase));
REGISTER_OPERATOR(
reduce_min, ops::ReduceOp, ReduceMinOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ReduceMinInferShapeFunctor);
REGISTER_OPERATOR(reduce_min_grad, ops::ReduceGradOp)
REGISTER_OP_CPU_KERNEL(
reduce_min_grad, ops::ReduceGradKernel<paddle::platform::CPUDeviceContext,
float, ops::MaxOrMinGradFunctor>,
......
......@@ -48,6 +48,9 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"matmul_grad_grad",
"mean",
"max",
"min",
"any",
"all",
"reshape",
"reshape_grad",
"expand",
......
......@@ -239,4 +239,29 @@ void Reduce(const DeviceContext& dev_ctx,
}
}
template <typename DeviceContext, typename OutT, typename Functor>
void BoolReduceKernel(const DeviceContext& dev_ctx,
const phi::DenseTensor& input,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
phi::DenseTensor* output) {
dev_ctx.template Alloc<OutT>(output);
// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = input.dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
ReduceKernelImpl<DeviceContext, bool, OutT, Functor>(
dev_ctx, input, output, dims, keep_dim, reduce_all);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
namespace phi {
template <typename T, typename Context>
void AllRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
phi::BoolReduceKernel<CPUContext, T, phi::funcs::AllFunctor>(
dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(all_raw, CPU, ALL_LAYOUT, phi::AllRawKernel, bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_any_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
namespace phi {
template <typename T, typename Context>
void AnyRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
phi::BoolReduceKernel<CPUContext, T, phi::funcs::AnyFunctor>(
dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(any_raw, CPU, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
namespace phi {
template <typename T, typename Context>
void MinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<CPUContext, T, phi::funcs::MinFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
min_raw, CPU, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {}
......@@ -49,5 +49,29 @@ struct MaxFunctor {
}
};
//////// Min Functor ///////
struct MinFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->minimum(dim);
}
};
//////// All Functor ///////
struct AllFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->all(dim);
}
};
//////// Any Functor ///////
struct AnyFunctor {
template <typename DeviceContext, typename X, typename Y, typename Dim>
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
y->device(place) = x->any(dim);
}
};
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
template <typename T, typename Context>
void AllRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<T, kps::LogicalAndFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(all_raw, GPU, ALL_LAYOUT, phi::AllRawKernel, bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_any_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
template <typename T, typename Context>
void AnyRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<T, kps::LogicalOrFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(any_raw, GPU, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
template <typename T, typename Context>
void MinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MinFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
min_raw, GPU, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void AllKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(all, CPU, ALL_LAYOUT, phi::AllKernel, bool) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(all, GPU, ALL_LAYOUT, phi::AllKernel, bool) {}
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,8 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#pragma once
REGISTER_OP_CUDA_KERNEL(
reduce_all,
ops::ReduceCudaKernel<bool, kps::LogicalAndFunctor, kps::IdentityFunctor>);
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AllRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
template <typename T, typename Context>
void AllKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_any_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void AnyKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(any, CPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(any, GPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
#endif
// Copyright (c) 2018 PaddlePaddle Authors. Any Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,9 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#pragma once
REGISTER_OP_CUDA_KERNEL(
reduce_any,
ops::ReduceCudaKernel<bool, kps::LogicalOrFunctor, kps::IdentityFunctor>);
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AnyRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
template <typename T, typename Context>
void AnyKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out);
} // namespace phi
......@@ -15,9 +15,6 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MinKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
min, CPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -11,13 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
// reduce_min
REGISTER_OP_CUDA_KERNEL(
reduce_min,
ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MinFunctor, kps::IdentityFunctor>);
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out);
template <typename T, typename Context>
void MinKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out);
} // namespace phi
......@@ -41,8 +41,7 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "mean_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the
// "mean_raw" KernelSignature
// the "mean_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
......@@ -63,8 +62,7 @@ KernelSignature ReduceMaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "max_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the
// "max_raw" KernelSignature
// the "max_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"max_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
......@@ -74,6 +72,54 @@ KernelSignature ReduceMaxOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceMinOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "min_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the "min_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"min_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature("min", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceAnyOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "any_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the "any_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"any_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature("any", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceAllOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
bool reduce_all = paddle::any_cast<bool>(ctx.Attr("reduce_all"));
// When ctx is InferShapeArgumentMappingContext, the reduce_all is used in
// InferShape, so we must return the "all_raw" KernelSignature.
// And the InferMeta function(i.e. ReduceInferMetaBase) is accordance with
// the "all_raw" KernelSignature
if (ctx.IsForInferShape() || reduce_all) {
return KernelSignature(
"all_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
}
return KernelSignature("all", {"X"}, {"dim", "keep_dim"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ReduceSumGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
......@@ -88,11 +134,19 @@ KernelSignature ReduceSumGradOpArgumentMapping(
PD_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
PD_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);
PD_REGISTER_BASE_KERNEL_NAME(reduce_max, max);
PD_REGISTER_BASE_KERNEL_NAME(reduce_min, min);
PD_REGISTER_BASE_KERNEL_NAME(reduce_all, all);
PD_REGISTER_BASE_KERNEL_NAME(reduce_any, any);
PD_REGISTER_BASE_KERNEL_NAME(reduce_sum_grad, sum_grad);
PD_REGISTER_ARG_MAPPING_FN(reduce_sum, phi::ReduceSumOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_mean, phi::ReduceMeanOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_prod, phi::ReduceProdOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_max, phi::ReduceMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_min, phi::ReduceMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_all, phi::ReduceAllOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_any, phi::ReduceAnyOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reduce_sum_grad,
phi::ReduceSumGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册