You need to sign in or sign up before continuing.
未验证 提交 b1365d25 编写于 作者: Y Yiqun Liu 提交者: GitHub

Unify the functor of elementwise and logical ops. (#35767)

上级 dfa242e4
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -12,9 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
......@@ -22,9 +22,10 @@ template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool;
void Compute(const framework::ExecutionContext& ctx) const override {
using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool;
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
......@@ -45,6 +46,9 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \
REGISTER_OP_CUDA_KERNEL( \
op_name, \
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
......@@ -22,6 +22,7 @@ namespace paddle {
namespace operators {
// Define the binary functors used in elementwise ops.
// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.
// Add
template <typename T>
......
......@@ -66,8 +66,8 @@ REGISTER_OP_CPU_KERNEL(
elementwise_mod,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseModFPKernel<paddle::platform::CPUDeviceContext, double>);
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(elementwise_mod)
.AddCheckpoint(
......
......@@ -14,9 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
......@@ -38,6 +35,9 @@ class ElementwiseModKernel<platform::CUDADeviceContext, T>
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>,
......
......@@ -44,9 +44,9 @@ struct ModFunctor<T,
}
};
template <typename T>
template <typename T, typename Enable = void>
struct InverseModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = b % a;
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
return res;
......@@ -54,8 +54,9 @@ struct InverseModFunctor {
};
template <typename T>
struct InverseModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const {
struct InverseModFunctor<
T, typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = fmod(b, a);
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
return res;
......@@ -78,22 +79,6 @@ void elementwise_mod(const framework::ExecutionContext &ctx,
}
}
template <typename DeviceContext, typename T>
void elementwise_mod_fp(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
ctx, x, y, axis, InverseModFunctorFP<T>(), z);
}
}
template <typename DeviceContext, typename T>
class ElementwiseModKernel : public framework::OpKernel<T> {
public:
......@@ -109,20 +94,5 @@ class ElementwiseModKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseModFPKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// dtype of x and y is float or double
elementwise_mod_fp<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
......@@ -199,10 +199,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) {
z->mutable_data<OutType>(ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor *> ins = {x, y};
......@@ -217,6 +213,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
return;
}
z->mutable_data<OutType>(ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
const auto &dev_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
pten::ElementwiseCompute<Functor, T, OutType>(
......
......@@ -16,9 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
......@@ -27,8 +24,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace kps = paddle::operators::kernel_primitives;
using ElementwiseType = pten::ElementwiseType;
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -21,7 +24,7 @@ namespace operators {
template <typename T>
struct PowFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
inline HOSTDEVICE T operator()(const T a, const T b) const {
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
......
......@@ -26,7 +26,6 @@
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
......
......@@ -22,6 +22,7 @@ namespace pten {
namespace funcs {
// Define the binary functors used in elementwise ops.
// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.
// Add
template <typename T>
......@@ -48,10 +49,22 @@ template <typename T>
struct MultiplyFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; }
};
template <>
struct MultiplyFunctor<bool> {
inline HOSTDEVICE bool operator()(const bool a, const bool b) const {
return a && b;
}
};
template <typename T>
struct InverseMultiplyFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const { return b * a; }
};
template <>
struct InverseMultiplyFunctor<bool> {
inline HOSTDEVICE bool operator()(const bool a, const bool b) const {
return b && a;
}
};
// Divide
#define DIV_ERROR_INFO \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册