Unverified · Commit 0969a4eb authored by From00, committed by GitHub

Move compare OPs to phi (#39970)

* Move compare OPs to phi

* Fix bug

* Use BroadcastKernel and ElementwiseKernel in phi
Parent 4c0511fa
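For orientation before the diff: every compare OP this commit migrates boils down to applying a binary predicate elementwise, optionally with broadcasting. Below is a minimal host-side sketch of that contract; ElementwiseBinary is a hypothetical helper, not a Paddle API, and the real phi funcs::ElementwiseKernel and funcs::BroadcastKernel additionally handle broadcasting, vectorization, and device placement.

#include <functional>
#include <vector>

// Sketch: apply a binary functor element-by-element over same-shape inputs.
template <typename InT, typename OutT, typename Functor>
std::vector<OutT> ElementwiseBinary(const std::vector<InT>& x,
                                    const std::vector<InT>& y,
                                    Functor func) {
  std::vector<OutT> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = func(x[i], y[i]);
  return out;
}

// Usage: a boolean mask, as less_than would produce.
// auto mask = ElementwiseBinary<float, bool>(a, b, std::less<float>());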
@@ -19,6 +19,6 @@ else()
target_link_libraries(conditional_block_infer_op conditional_block_op)
endif()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP_ITSELF(less_than);\nUSE_OP_ITSELF(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n")
file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n")
@@ -12,49 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using Tensor = framework::Tensor;
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
Tensor tmp;
bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (x->dims() != y->dims()) {
z_data[0] = false;
} else {
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
if (x->numel() == 1 && y->numel() == 1) {
bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
} else {
ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
context, x, y, 0, Functor(), &tmp);
}
auto ipt = framework::EigenVector<bool>::Flatten(tmp);
auto out = framework::EigenScalar<bool>::From(*z);
auto& place =
*context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
out.device(place) = ipt.all(reduce_dim);
}
}
};
template <typename OpComment>
class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
@@ -81,26 +46,6 @@ template <typename OpComment>
class CompareReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OpComment comment;
PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
platform::errors::InvalidArgument(
"%s operator must have input X", comment.type));
PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
platform::errors::InvalidArgument(
"%s operator must have input Y", comment.type));
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_GE(
dim_x.size(), dim_y.size(),
platform::errors::InvalidArgument(
"The size of dim_y should not be greater than dim_x's."));
context->SetOutputDim("Out", {1});
context->ShareLoD("X", "Out");
}
};
} // namespace operators
@@ -113,25 +58,13 @@ class CompareReduceOp : public framework::OperatorWithKernel {
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
DELCARE_INFER_SHAPE_FUNCTOR(op_type, op_type##_InferShapeFunctor, \
PT_INFER_META(phi::CompareAllInferMeta)); \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareReduceOp<_##op_type##Comment>, \
::paddle::operators::CompareReduceOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, \
op_type##_InferShapeFunctor);
#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \
REGISTER_OP_CPU_KERNEL( \
op_type, ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<bool>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int64_t>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<float>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y");
REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_all,
paddle::operators::EqualReduceFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/fill.h>
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
namespace paddle {
namespace operators {
template <typename T>
struct BitwiseAdd {
// Bitwise AND functor used as the reduction operator: with initial() ==
// true, folding a & b over booleans implements a logical all().
inline T initial() { return static_cast<T>(true); }
__host__ __device__ __forceinline__ T operator()(const T& a,
const T& b) const {
return a & b;
}
};
template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using Tensor = framework::Tensor;
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
bool* z_data = z->mutable_data<bool>(context.GetPlace());
Tensor tmp;
if (x->dims() != y->dims()) {
thrust::device_ptr<bool> z_dev_ptr(z_data);
thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
return;
} else {
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
const auto& cuda_ctx =
context.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {&tmp};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<bool>(
cuda_ctx, ins, &outs, Functor());
// Reduce by 'bitwise and' operator
std::vector<int> reduce_dims;
reduce_dims.resize(tmp.dims().size());
for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
auto stream = context.cuda_device_context().stream();
TensorReduceImpl<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
context.cuda_device_context(), tmp, z, kps::IdentityFunctor<bool>(),
reduce_dims, stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<int64_t>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<float>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, EqualReduceFunctor)
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
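Why BitwiseAdd works as the reduction operator here: with initial() == true, folding a & b across a boolean array yields true exactly when every element is true, i.e. a logical all(). A host-side sketch of the same fold, assuming the device reduction applies the functor with these semantics:

#include <cstdio>

int main() {
  bool equal_flags[] = {true, true, false, true};  // per-element == results
  bool acc = true;                                 // BitwiseAdd::initial()
  for (bool v : equal_flags) acc = acc & v;        // BitwiseAdd::operator()
  std::printf("%d\n", acc);  // prints 0: not all elements compared equal
  return 0;
}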
@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -60,31 +58,6 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
if (context->GetInputDim("X") == context->GetInputDim("Y")) {
context->ShareDim("X", /*->*/ "Out");
context->ShareLoD("X", /*->*/ "Out");
} else {
int max_dim = std::max(dim_x.size(), dim_y.size());
int axis = std::abs(dim_x.size() - dim_y.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
context->SetOutputDim("Out", phi::make_ddim(out_dims_array));
// to do
context->ShareLoD("X", /*->*/ "Out");
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
@@ -116,37 +89,31 @@ class CompareOp : public framework::OperatorWithKernel {
"In order to force fill output variable to gpu memory.", \
false));
#define REGISTER_COMPARE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); \
#define REGISTER_COMPARE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
DELCARE_INFER_SHAPE_FUNCTOR(op_type, op_type##_InferShapeFunctor, \
PT_INFER_META(phi::CompareInferMeta)); \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, \
op_type##_InferShapeFunctor); \
REGISTER_COMPARE_OP_VERSION(op_type);
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor,
paddle::operators::GreaterThanFunctor);
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor,
paddle::operators::GreaterEqualFunctor);
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_KERNEL(greater_than, CPU,
paddle::operators::GreaterThanFunctor,
paddle::operators::LessThanFunctor);
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
paddle::operators::GreaterEqualFunctor,
paddle::operators::LessEqualFunctor);
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor,
paddle::operators::EqualFunctor);
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor,
paddle::operators::NotEqualFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
using InT = typename Functor::ELEM_TYPE;
using OutT = bool;
void Compute(const framework::ExecutionContext& ctx) const override {
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
InT, OutT>(
cuda_ctx, ins, &outs, axis, functor);
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<bool>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int16_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);
REGISTER_CUDA_COMPARE_KERNEL(equal, EqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, NotEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_than, LessThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(less_equal, LessEqualFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_than, GreaterThanFunctor)
REGISTER_CUDA_COMPARE_KERNEL(greater_equal, GreaterEqualFunctor)
#undef REGISTER_CUDA_COMPARE_KERNEL
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
#define COMPARE_FUNCTOR(func_name, op) \
template <typename InT, typename OutT = bool> \
struct func_name { \
using ELEM_TYPE = InT; \
HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
return static_cast<OutT>(a op b); \
} \
};
COMPARE_FUNCTOR(LessThanFunctor, <)
COMPARE_FUNCTOR(LessEqualFunctor, <=)
COMPARE_FUNCTOR(GreaterThanFunctor, >)
COMPARE_FUNCTOR(GreaterEqualFunctor, >=)
#undef COMPARE_FUNCTOR
template <typename InT, typename OutT = bool>
struct EqualFunctor {
using ELEM_TYPE = InT;
HOSTDEVICE OutT operator()(const InT a, const InT b) const {
if (std::is_floating_point<InT>::value) {
// This branch is optimized out at compile time when InT is an integral \
// type, so the cast of a and b to double is safe. \
return static_cast<OutT>(fabs(static_cast<double>(a - b)) < 1e-8);
} else {
return static_cast<OutT>(a == b);
}
}
};
template <typename InT, typename OutT = bool>
struct NotEqualFunctor {
using ELEM_TYPE = InT;
HOSTDEVICE bool operator()(const InT a, const InT b) const {
return !EqualFunctor<InT, OutT>()(a, b);
}
};
template <typename DeviceContext, typename Functor, typename InverseFunctor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using Tensor = framework::Tensor;
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
int axis = context.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
Functor(), z);
} else {
ElementwiseComputeEx<InverseFunctor, DeviceContext, T, bool>(
context, x, y, axis, InverseFunctor(), z);
}
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \
REGISTER_OP_##dev##_KERNEL(op_type, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<bool>, inverse_functor<bool>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int>, inverse_functor<int>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int16_t>, inverse_functor<int16_t>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int64_t>, inverse_functor<int64_t>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<float>, inverse_functor<float>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<double>, inverse_functor<double>>);
@@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......
@@ -17,6 +17,7 @@
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
@@ -224,15 +225,15 @@ class MatrixRankCPUKernel : public framework::OpKernel<T> {
int axis = -1;
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
ElementwiseComputeEx<GreaterThanFunctor<T, int64_t>,
ElementwiseComputeEx<phi::funcs::GreaterThanFunctor<T, int64_t>,
platform::CPUDeviceContext, T, int>(
context, &eigenvalue_tensor, &tol_tensor, axis,
GreaterThanFunctor<T, int64_t>(), &compare_result);
phi::funcs::GreaterThanFunctor<T, int64_t>(), &compare_result);
} else {
ElementwiseComputeEx<LessThanFunctor<T, int64_t>,
ElementwiseComputeEx<phi::funcs::LessThanFunctor<T, int64_t>,
platform::CPUDeviceContext, T, int>(
context, &eigenvalue_tensor, &tol_tensor, axis,
LessThanFunctor<T, int64_t>(), &compare_result);
phi::funcs::LessThanFunctor<T, int64_t>(), &compare_result);
}
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext,
......
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -129,10 +130,10 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
int axis = -1;
ElementwiseComputeEx<GreaterThanFunctor<T, int64_t>,
ElementwiseComputeEx<phi::funcs::GreaterThanFunctor<T, int64_t>,
platform::CUDADeviceContext, T, int64_t>(
context, &eigenvalue_tensor, &tol_tensor, axis,
GreaterThanFunctor<T, int64_t>(), &compare_result);
phi::funcs::GreaterThanFunctor<T, int64_t>(), &compare_result);
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
int64_t>(context);
......
@@ -15,7 +15,6 @@
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
......
@@ -12,7 +12,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/metrics/accuracy_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
@@ -14,12 +14,13 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/controlflow/compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/gather.h"
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
@@ -353,8 +354,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
BinaryOperation<DeviceContext, SubFunctor, int64_t> SubInt;
if (include_bos_eos_tag) {
AddFloat(dev_ctx, logit0, start_trans, &alpha);
GetMask<DeviceContext, EqualFunctor, T>()(ctx, left_length, one,
&float_mask);
GetMask<DeviceContext, phi::funcs::EqualFunctor, T>()(ctx, left_length,
one, &float_mask);
MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
} else {
@@ -375,8 +376,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
alpha.Resize({batch_size, n_labels});
// mask = paddle.cast((left_length > 0), dtype='float32')
// alpha = mask * alpha_nxt + (1 - mask) * alpha
GetMask<DeviceContext, GreaterThanFunctor, T>()(ctx, left_length, zero,
&float_mask);
GetMask<DeviceContext, phi::funcs::GreaterThanFunctor, T>()(
ctx, left_length, zero, &float_mask);
// alpha_nxt = mask * alpha_nxt
MulFloat(dev_ctx, alpha_nxt, float_mask, &alpha_nxt);
// inv_mask = 1 - mask
@@ -386,8 +387,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
// alpha += alpha_nxt
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
if (include_bos_eos_tag) {
GetMask<DeviceContext, EqualFunctor, T>()(ctx, left_length, one,
&float_mask);
GetMask<DeviceContext, phi::funcs::EqualFunctor, T>()(ctx, left_length,
one, &float_mask);
// alpha += mask * trans_exp[:, self.stop_idx]
MulFloat(dev_ctx, stop_trans, float_mask, &alpha_nxt);
AddFloat(dev_ctx, alpha, alpha_nxt, &alpha);
@@ -396,8 +397,8 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
}
argmax(ctx, alpha, &last_ids, scores, 1);
left_length.Resize({batch_size});
GetMask<DeviceContext, GreaterEqualFunctor, int64_t>()(ctx, left_length,
zero, &int_mask);
GetMask<DeviceContext, phi::funcs::GreaterEqualFunctor, int64_t>()(
ctx, left_length, zero, &int_mask);
// last_ids_update = last_ids * tag_mask
int last_ids_index = 1;
int actual_len = (std::min)(seq_len, static_cast<int>(max_seq_len));
@@ -416,17 +417,17 @@ class ViterbiDecodeKernel : public framework::OpKernel<T> {
batch_path[actual_len - last_ids_index];
hist->Resize({batch_size * n_labels});
gather(dev_ctx, *hist, gather_idx, &last_ids_update);
GetMask<DeviceContext, GreaterThanFunctor, int64_t>()(ctx, left_length,
zero, &int_mask);
GetMask<DeviceContext, phi::funcs::GreaterThanFunctor, int64_t>()(
ctx, left_length, zero, &int_mask);
MulInt(dev_ctx, last_ids_update, int_mask, &last_ids_update);
GetMask<DeviceContext, EqualFunctor, int64_t>()(ctx, left_length, zero,
&zero_len_mask);
GetMask<DeviceContext, phi::funcs::EqualFunctor, int64_t>()(
ctx, left_length, zero, &zero_len_mask);
MulInt(dev_ctx, last_ids, zero_len_mask, &last_ids_tmp);
SubInt(dev_ctx, one, zero_len_mask, &zero_len_mask);
MulInt(dev_ctx, last_ids_update, zero_len_mask, &last_ids_update);
AddInt(dev_ctx, last_ids_update, last_ids_tmp, &last_ids_update);
GetMask<DeviceContext, LessThanFunctor, int64_t>()(ctx, left_length, zero,
&int_mask);
GetMask<DeviceContext, phi::funcs::LessThanFunctor, int64_t>()(
ctx, left_length, zero, &int_mask);
MulInt(dev_ctx, last_ids, int_mask, &last_ids);
AddInt(dev_ctx, last_ids_update, last_ids, &last_ids);
}
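The masked-update idiom used throughout this kernel (see the mask comments above) selects between two values without branching: out = mask * a + (1 - mask) * b keeps a where mask is 1 and b where mask is 0. A scalar sketch of that identity; the assumption is only that the tensor version applies it elementwise:

#include <cstdio>

float MaskedSelect(float mask, float a, float b) {
  return mask * a + (1.0f - mask) * b;  // mask==1 -> a, mask==0 -> b
}

int main() {
  std::printf("%g %g\n", MaskedSelect(1.0f, 5.0f, 9.0f),  // 5
              MaskedSelect(0.0f, 5.0f, 9.0f));            // 9
  return 0;
}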
......
@@ -13,11 +13,60 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/infermeta/binary.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
void CompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
int axis,
MetaTensor* out) {
auto dim_x = x.dims();
auto dim_y = y.dims();
if (dim_x == dim_y) {
out->share_meta(x);
} else {
int max_dim = std::max(dim_x.size(), dim_y.size());
int axis = std::abs(dim_x.size() - dim_y.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
funcs::GetBroadcastDimsArrays(dim_x,
dim_y,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out->set_dims(make_ddim(out_dims_array));
out->share_lod(x);
}
out->set_dtype(DataType::BOOL);
}
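The broadcast rule that CompareInferMeta delegates to funcs::GetBroadcastDimsArrays is trailing-aligned, NumPy-style broadcasting, with axis set to the rank difference. A simplified stand-in for that rule (BroadcastDims is illustrative, not the phi helper's signature):

#include <cassert>
#include <vector>

// Align trailing dims; a size-1 dim stretches to match the other operand.
std::vector<int> BroadcastDims(std::vector<int> x, std::vector<int> y) {
  if (x.size() < y.size()) x.swap(y);  // make x the higher-rank operand
  size_t offset = x.size() - y.size();
  std::vector<int> out(x);
  for (size_t i = 0; i < y.size(); ++i) {
    int xd = x[offset + i], yd = y[i];
    assert(xd == yd || xd == 1 || yd == 1);  // broadcast-compatible dims
    out[offset + i] = (xd == 1) ? yd : xd;
  }
  return out;  // e.g. {2, 3, 4} vs {3, 4} -> {2, 3, 4}
}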
void CompareAllInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
auto dim_x = x.dims();
auto dim_y = y.dims();
PADDLE_ENFORCE_GE(
dim_x.size(),
dim_y.size(),
errors::InvalidArgument(
"The size of dim_y should not be greater than dim_x's."));
out->share_lod(x);
out->set_dims(make_ddim({1}));
out->set_dtype(DataType::BOOL);
}
void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto x_dims = x.dims();
auto x_rank = static_cast<size_t>(x_dims.size());
......
@@ -29,6 +29,15 @@ namespace phi {
// Functions in this file not only infer shape, but may also need to
// infer LoD or other useful metadata.
void CompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
int axis,
MetaTensor* out);
void CompareAllInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void MatmulInferMeta(const MetaTensor& x,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,31 +13,35 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <algorithm>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T>
struct EqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) {
// This branch is optimized out at compile time when T is an integral
// type, so the cast of a and b to double is safe.
return fabs(static_cast<double>(a - b)) < 1e-8;
} else {
return (a == b);
}
}
};
} // namespace operators
} // namespace paddle
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
#define DECALRE_COMPARE_KERNEL(compare_kernel) \
template <typename T, typename Context> \
void compare_kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
int axis, \
DenseTensor* out);
DECALRE_COMPARE_KERNEL(LessThanKernel)
DECALRE_COMPARE_KERNEL(LessEqualKernel)
DECALRE_COMPARE_KERNEL(GreaterThanKernel)
DECALRE_COMPARE_KERNEL(GreaterEqualKernel)
DECALRE_COMPARE_KERNEL(EqualKernel)
DECALRE_COMPARE_KERNEL(NotEqualKernel)
#undef DECALRE_COMPARE_KERNEL
#define DECALRE_COMPARE_ALL_KERNEL(compare_all_kernel) \
template <typename T, typename Context> \
void compare_all_kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out);
DECALRE_COMPARE_ALL_KERNEL(EqualAllKernel)
#undef DECALRE_COMPARE_ALL_KERNEL
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/impl/compare_kernel_impl.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace phi {
template <typename T,
typename Context,
typename Functor,
typename InverseFunctor>
inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
ctx.template Alloc<bool>(out);
if (x.dims().size() >= y.dims().size()) {
funcs::ElementwiseCompute<Functor, T, bool>(
ctx, x, y, axis, Functor(), out);
} else {
funcs::ElementwiseCompute<InverseFunctor, T, bool>(
ctx, x, y, axis, InverseFunctor(), out);
}
}
template <typename T, typename Context, typename Functor>
inline void CompareAllKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
bool* out_data = ctx.template Alloc<bool>(out);
if (x.dims() != y.dims()) {
out_data[0] = false;
} else {
DenseTensor tmp;
tmp.Resize(x.dims());
ctx.template Alloc<bool>(&tmp);
if (x.numel() == 1 && y.numel() == 1) {
bool* tmp_data = tmp.data<bool>();
tmp_data[0] = Functor()(x.data<T>()[0], y.data<T>()[0]);
} else {
funcs::ElementwiseCompute<Functor, T, bool>(
ctx, x, y, 0, Functor(), &tmp);
}
auto tmp_flat = EigenVector<bool>::Flatten(tmp);
auto out_es = EigenScalar<bool>::From(*out);
auto& place = *ctx.eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
out_es.device(place) = tmp_flat.all(reduce_dim);
}
}
} // namespace phi
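The equal_all flow above, in miniature: a shape mismatch short-circuits to false; otherwise the elementwise equality mask is AND-reduced to a single scalar. A plain C++ sketch of the same semantics (exact == is used here for brevity, where EqualFunctor applies a 1e-8 tolerance to floating-point inputs):

#include <algorithm>
#include <vector>

bool EqualAll(const std::vector<float>& x, const std::vector<float>& y) {
  if (x.size() != y.size()) return false;  // dims differ -> false outright
  std::vector<char> mask(x.size());
  for (size_t i = 0; i < x.size(); ++i) mask[i] = (x[i] == y[i]);
  return std::all_of(mask.begin(), mask.end(), [](char m) { return m != 0; });
}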
PD_REGISTER_KERNEL(less_than,
CPU,
ALL_LAYOUT,
phi::LessThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(less_equal,
CPU,
ALL_LAYOUT,
phi::LessEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(greater_than,
CPU,
ALL_LAYOUT,
phi::GreaterThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(greater_equal,
CPU,
ALL_LAYOUT,
phi::GreaterEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(equal,
CPU,
ALL_LAYOUT,
phi::EqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(not_equal,
CPU,
ALL_LAYOUT,
phi::NotEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(equal_all,
CPU,
ALL_LAYOUT,
phi::EqualAllKernel,
bool,
int,
int64_t,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
#define COMPARE_FUNCTOR(func_name, op) \
template <typename InT, typename OutT = bool> \
struct func_name { \
HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
return static_cast<OutT>(a op b); \
} \
};
COMPARE_FUNCTOR(LessThanFunctor, <)
COMPARE_FUNCTOR(LessEqualFunctor, <=)
COMPARE_FUNCTOR(GreaterThanFunctor, >)
COMPARE_FUNCTOR(GreaterEqualFunctor, >=)
#undef COMPARE_FUNCTOR
template <typename InT, typename OutT = bool>
struct EqualFunctor {
HOSTDEVICE OutT operator()(const InT a, const InT b) const {
if (std::is_floating_point<InT>::value) {
return static_cast<OutT>(fabs(static_cast<double>(a - b)) < 1e-8);
} else {
return static_cast<OutT>(a == b);
}
}
};
template <typename InT, typename OutT = bool>
struct NotEqualFunctor {
HOSTDEVICE bool operator()(const InT a, const InT b) const {
return !EqualFunctor<InT, OutT>()(a, b);
}
};
} // namespace funcs
} // namespace phi
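A quick check of what the floating-point branch of EqualFunctor buys: values whose difference stays below 1e-8 compare equal, so benign rounding error does not break equality. Self-contained illustration, on the assumption that it mirrors the functor's tolerance exactly:

#include <cassert>
#include <cmath>
#include <type_traits>

template <typename T>
bool AlmostEqual(T a, T b) {
  if (std::is_floating_point<T>::value)
    return std::fabs(static_cast<double>(a - b)) < 1e-8;
  return a == b;  // integral types compare exactly
}

int main() {
  assert(AlmostEqual(0.1 + 0.2, 0.3));    // true despite rounding error
  assert(!AlmostEqual(1.0, 1.0 + 1e-6));  // difference exceeds the tolerance
  assert(AlmostEqual(3, 3));              // integral path: exact ==
  return 0;
}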
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/impl/compare_kernel_impl.h"
#include <thrust/fill.h>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
namespace phi {
template <typename T>
struct BitwiseAdd {
// Bitwise AND functor used as the reduction operator: with initial() ==
// true, folding a & b over booleans implements a logical all().
inline T initial() { return static_cast<T>(true); }
__host__ __device__ __forceinline__ T operator()(const T& a,
const T& b) const {
return a & b;
}
};
template <typename T,
typename Context,
typename Functor,
typename InverseFunctor>
inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
ctx.template Alloc<bool>(out);
std::vector<const DenseTensor*> ins{&x, &y};
std::vector<DenseTensor*> outs{out};
funcs::BroadcastKernel<ElementwiseType::kBinary, T, bool>(
ctx, ins, &outs, axis, Functor());
}
template <typename T, typename Context, typename Functor>
inline void CompareAllKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
bool* out_data = ctx.template Alloc<bool>(out);
if (x.dims() != y.dims()) {
thrust::device_ptr<bool> out_dev_ptr(out_data);
thrust::fill(out_dev_ptr, out_dev_ptr + 1, false);
return;
}
DenseTensor tmp;
tmp.Resize(x.dims());
ctx.template Alloc<bool>(&tmp);
std::vector<const DenseTensor*> ins{&x, &y};
std::vector<DenseTensor*> outs{&tmp};
funcs::ElementwiseKernel<bool>(ctx, ins, &outs, Functor());
// Reduce by 'bitwise and' operator
std::vector<int> reduce_dims;
reduce_dims.resize(tmp.dims().size());
for (int i = 0; i < reduce_dims.size(); ++i) {
reduce_dims[i] = i;
}
kernels::TensorReduceImpl<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims, ctx.stream());
}
} // namespace phi
PD_REGISTER_KERNEL(less_than,
GPU,
ALL_LAYOUT,
phi::LessThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(less_equal,
GPU,
ALL_LAYOUT,
phi::LessEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(greater_than,
GPU,
ALL_LAYOUT,
phi::GreaterThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(greater_equal,
GPU,
ALL_LAYOUT,
phi::GreaterEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(equal,
GPU,
ALL_LAYOUT,
phi::EqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(not_equal,
GPU,
ALL_LAYOUT,
phi::NotEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(equal_all,
GPU,
ALL_LAYOUT,
phi::EqualAllKernel,
bool,
int,
int64_t,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
namespace phi {
template <typename T,
typename Context,
typename Functor,
typename InverseFunctor>
inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context, typename Functor>
inline void CompareAllKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
#define DEFINE_COMPARE_KERNEL(compare_kernel, functor, inverse_functor) \
template <typename T, typename Context> \
void compare_kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
int axis, \
DenseTensor* out) { \
CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>( \
ctx, x, y, axis, out); \
}
DEFINE_COMPARE_KERNEL(LessThanKernel,
funcs::LessThanFunctor,
funcs::GreaterThanFunctor)
DEFINE_COMPARE_KERNEL(LessEqualKernel,
funcs::LessEqualFunctor,
funcs::GreaterEqualFunctor)
DEFINE_COMPARE_KERNEL(GreaterThanKernel,
funcs::GreaterThanFunctor,
funcs::LessThanFunctor)
DEFINE_COMPARE_KERNEL(GreaterEqualKernel,
funcs::GreaterEqualFunctor,
funcs::LessEqualFunctor)
DEFINE_COMPARE_KERNEL(EqualKernel, funcs::EqualFunctor, funcs::EqualFunctor)
DEFINE_COMPARE_KERNEL(NotEqualKernel,
funcs::NotEqualFunctor,
funcs::NotEqualFunctor)
#undef DEFINE_COMPARE_KERNEL
#define DEFINE_COMPARE_ALL_KERNEL(compare_all_kernel, functor) \
template <typename T, typename Context> \
void compare_all_kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
CompareAllKernelImpl<T, Context, functor<T>>(ctx, x, y, out); \
}
DEFINE_COMPARE_ALL_KERNEL(EqualAllKernel, funcs::EqualFunctor)
#undef DEFINE_COMPARE_ALL_KERNEL
} // namespace phi
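Why each kernel is defined with a functor/inverse-functor pair: when x has lower rank than y, CompareKernelImpl evaluates the comparison with the operands effectively swapped by the broadcast machinery, so the mirrored predicate must be used. The underlying identities, checked in isolation (nothing phi-specific is assumed):

#include <cassert>

int main() {
  double x = 1.0, y = 2.0;
  assert((x < y) == (y > x));    // LessThan    <-> GreaterThan
  assert((x <= y) == (y >= x));  // LessEqual   <-> GreaterEqual
  assert((x == y) == (y == x));  // Equal is its own inverse
  assert((x != y) == (y != x));  // NotEqual is its own inverse
  return 0;
}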
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LessThanArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("less_than", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature LessEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("less_equal", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature GreaterThanArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("greater_than", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature GreaterEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("greater_equal", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature EqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("equal", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature NotEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("not_equal", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature EqualAllArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("equal_all", {"X", "Y"}, {}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(less_than, phi::LessThanArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(less_equal, phi::LessEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(greater_than, phi::GreaterThanArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(greater_equal, phi::GreaterEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(equal, phi::EqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(not_equal, phi::NotEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(equal_all, phi::EqualAllArgumentMapping);
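What each mapping above encodes, reduced to plain data: the fluid op's named inputs, attributes, and outputs are matched in order onto the phi kernel's parameters. A hypothetical free-standing model of that record (Signature is illustrative only; the real type is phi::KernelSignature):

#include <string>
#include <vector>

// Hypothetical stand-in for phi::KernelSignature.
struct Signature {
  std::string kernel_name;
  std::vector<std::string> inputs, attrs, outputs;
};

Signature LessThanMapping() {
  // Mirrors LessThanArgumentMapping above: X, Y and the axis attribute
  // feed the phi less_than kernel, which writes Out.
  return {"less_than", {"X", "Y"}, {"axis"}, {"Out"}};
}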