Unverified commit 0f154961, authored by limingshu, committed by GitHub

Reimplement the comparison binary ops using the new optimized CUDA function (#33064)

Parent: e8d6ff50
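Scope of the diff below: the CUDA kernels for the six comparison ops drop the generic REGISTER_COMPARE_KERNEL path and are re-registered on top of the new vectorized elementwise launcher. The new functors take both operands through a single pointer, which is what lets the launcher feed them packed, vectorized argument slots. A minimal host-only sketch of that calling convention (the HOSTDEVICE stub, LessThanSketch, and the main harness are illustrative, not part of the patch):

#include <cstdio>

// Host-only stand-in; in Paddle the macro expands to __host__ __device__.
#define HOSTDEVICE

// Same shape as the Cuda*Functor structs in the diff: both operands
// arrive through one pointer (args[0] = lhs, args[1] = rhs).
template <typename T>
struct LessThanSketch {
  HOSTDEVICE bool operator()(const T* args) const { return args[0] < args[1]; }
};

int main() {
  float packed[2] = {1.0f, 2.0f};
  std::printf("%d\n", static_cast<int>(LessThanSketch<float>()(packed)));  // prints 1
  return 0;
}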
@@ -13,18 +13,85 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/controlflow/compare_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 
-REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor,
-                        paddle::operators::GreaterThanFunctor);
-REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor,
-                        paddle::operators::GreaterEqualFunctor);
-REGISTER_COMPARE_KERNEL(greater_than, CUDA,
-                        paddle::operators::GreaterThanFunctor,
-                        paddle::operators::LessThanFunctor);
-REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
-                        paddle::operators::GreaterEqualFunctor,
-                        paddle::operators::LessEqualFunctor);
-REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor,
-                        paddle::operators::EqualFunctor);
-REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor,
-                        paddle::operators::NotEqualFunctor);
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+namespace paddle {
+namespace operators {
+
+#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(Func, op)  \
+  template <typename T, typename Enable = void>                \
+  struct Func##Functor {                                       \
+    using ELEMENT_TYPE = T;                                    \
+    inline HOSTDEVICE bool operator()(const T* args) const {   \
+      return args[0] op args[1];                               \
+    }                                                          \
+  };
+
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessThan, <)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaLessEqual, <=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterThan, >)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaGreaterEqual, >=)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaEqual, ==)
+DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(CudaNotEqual, !=)
+#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
+
+template <typename T>
+struct CudaEqualFunctor<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using ELEMENT_TYPE = T;
+  HOSTDEVICE bool operator()(const T* args) const {
+    return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
+  }
+};
+
+template <typename T>
+struct CudaNotEqualFunctor<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using ELEMENT_TYPE = T;
+  HOSTDEVICE bool operator()(const T* args) const {
+    return fabs(static_cast<double>(args[0] - args[1])) > 1e-8;
+  }
+};
+
+template <typename Functor, typename InverseFunctor>
+class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  using InT = typename Functor::ELEMENT_TYPE;
+  using OutT = bool;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto functor = Functor();
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
+        ctx, ins, &outs, functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func)                          \
+  REGISTER_OP_CUDA_KERNEL(                                                   \
+      op_type, ops::CompareOpKernel<plat::CUDADeviceContext,                 \
+                                    ops::func##Functor<int>, void>,          \
+      ops::CompareOpKernel<plat::CUDADeviceContext,                          \
+                           ops::func##Functor<int64_t>, void>,               \
+      ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
+                           void>,                                            \
+      ops::CompareOpKernel<plat::CUDADeviceContext,                          \
+                           ops::func##Functor<double>, void>);
+
+REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
+REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
+REGISTER_CUDA_COMPARE_KERNEL(less_than, CudaLessThan)
+REGISTER_CUDA_COMPARE_KERNEL(less_equal, CudaLessEqual)
+REGISTER_CUDA_COMPARE_KERNEL(greater_than, CudaGreaterThan)
+REGISTER_CUDA_COMPARE_KERNEL(greater_equal, CudaGreaterEqual)
+#undef REGISTER_CUDA_COMPARE_KERNEL
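One behavioral detail in the hunk above: for floating-point element types, equal and not_equal are specialized to an absolute-tolerance test against 1e-8 instead of exact comparison, so values whose difference is below 1e-8 compare equal regardless of magnitude. A standalone sketch of the same predicate (ApproxEqual is an illustrative name, not Paddle API):

#include <cmath>
#include <cstdio>

// Mirrors the CudaEqualFunctor specialization for floating-point types:
// subtract first, then compare against a fixed absolute tolerance.
bool ApproxEqual(float a, float b) {
  return std::fabs(static_cast<double>(a - b)) < 1e-8;
}

int main() {
  std::printf("%d\n", ApproxEqual(1.0f, 1.0f));  // 1: within tolerance
  std::printf("%d\n", ApproxEqual(1.0f, 1.5f));  // 0: outside tolerance
  return 0;
}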
@@ -42,20 +42,11 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::LoDTensor>("X");
-    auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto* z = ctx.Output<framework::LoDTensor>("Out");
-    z->mutable_data<T>(ctx.GetPlace());
-
-    int axis = ctx.Attr<int>("axis");
-    axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
-
-    std::vector<const framework::Tensor*> ins = {x, y};
-    std::vector<framework::Tensor*> outs = {z};
-    const auto& cuda_ctx =
-        ctx.template device_context<platform::CUDADeviceContext>();
-
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    PackTensorsIntoVector<T>(ctx, &ins, &outs);
     LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        cuda_ctx, ins, &outs, axis, CudaAddFunctor<T>());
+        ctx, ins, &outs, CudaAddFunctor<T>());
   }
 };
...
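Design note on this hunk: the add kernel no longer reads the axis attribute or fetches the device context itself; both responsibilities move behind LaunchElementwiseCudaKernel, which now receives the whole ExecutionContext. A toy sketch of that division of labor, with hypothetical names that only mimic the structure:

#include <cstdio>
#include <vector>

// Hypothetical miniature: the kernel only packs operands, while the
// launcher owns attribute lookup and the broadcast/same-dims decision.
struct CtxSketch {
  int axis = -1;  // stand-in for ctx.Attr<int>("axis")
};

template <typename Functor>
void LaunchSketch(const CtxSketch& ctx, const std::vector<float>& lhs,
                  const std::vector<float>& rhs, std::vector<float>* out,
                  Functor func) {
  (void)ctx;  // the real launcher would consult ctx.axis for broadcasting
  // Same-dims fast path only; the real launcher also handles broadcast.
  for (size_t i = 0; i < lhs.size(); ++i) (*out)[i] = func(lhs[i], rhs[i]);
}

int main() {
  std::vector<float> a{1, 2}, b{3, 4}, c(2);
  LaunchSketch(CtxSketch{}, a, b, &c, [](float x, float y) { return x + y; });
  std::printf("%g %g\n", c[0], c[1]);  // prints 4 6
  return 0;
}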
@@ -343,7 +343,6 @@ template <typename InT, typename OutT, typename BroadcastArgsWarpper,
 __global__ void ElementwiseBroadcastKernel(
     BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
-
   // Vectorized calculation of major data whose length is the max multipler of
   // VecSize,
   // eg: Calcualting the front 1024-length data in total 1027 data once VecSize
@@ -501,23 +500,30 @@ void LaunchBroadcastElementwiseCudaKernel(
   }
 }
 
-template <ElementwiseType ET, typename InT, typename OutType, typename Functor>
+template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
 void LaunchElementwiseCudaKernel(
-    const platform::CUDADeviceContext &cuda_ctx,
+    const framework::ExecutionContext &ctx,
     const std::vector<const framework::Tensor *> &ins,
-    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
+    std::vector<framework::Tensor *> *outs, Functor func) {
+  std::vector<int> dims_size;
   bool no_broadcast_flag = true;
   for (auto *in : ins) {
     no_broadcast_flag = ins[0]->dims() == in->dims();
+    dims_size.emplace_back(in->dims().size());
   }
 
+  const auto &cuda_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
   if (no_broadcast_flag) {
-    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
         cuda_ctx, ins, outs, func);
   } else {
-    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT,
-                                         OutType>(cuda_ctx, ins, outs, axis,
-                                                  func);
+    int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
+    axis = axis == -1
+               ? *std::max_element(dims_size.begin(), dims_size.end()) -
+                     *std::min_element(dims_size.begin(), dims_size.end())
+               : axis;
+    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
+                                                        axis, func);
   }
 }
...
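The default-axis rule above replaces the per-kernel std::abs(x_rank - y_rank) computation: when axis == -1, the launcher derives the broadcast start from the spread between the largest and smallest input rank. A standalone check of that rule (InferAxisSketch is an illustrative name; compile as C++17 for the structured binding):

#include <algorithm>
#include <cstdio>
#include <vector>

// Reproduces the default-axis computation from the hunk above.
int InferAxisSketch(const std::vector<int>& ranks, int axis) {
  auto [mn, mx] = std::minmax_element(ranks.begin(), ranks.end());
  return axis == -1 ? *mx - *mn : axis;
}

int main() {
  // e.g. X of shape [2, 3, 4] (rank 3) broadcast against Y of shape [4]:
  std::printf("%d\n", InferAxisSketch({3, 1}, -1));  // prints 2
  std::printf("%d\n", InferAxisSketch({3, 1}, 0));   // explicit axis wins: 0
  return 0;
}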
@@ -60,6 +60,26 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 namespace paddle {
 namespace operators {
 
+/*
+ * To pack the input and output tensors into vector for
+ * LaunchElementwiseCudaKernel
+ */
+template <typename T>
+void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
+                           std::vector<const framework::Tensor *> *ins,
+                           std::vector<framework::Tensor *> *outs) {
+  auto *x = ctx.Input<framework::LoDTensor>("X");
+  auto *y = ctx.Input<framework::LoDTensor>("Y");
+  auto *z = ctx.Output<framework::LoDTensor>("Out");
+  z->mutable_data<T>(ctx.GetPlace());
+  ins->emplace_back(x);
+  outs->emplace_back(z);
+  if (y != nullptr) {
+    ins->emplace_back(y);
+  }
+}
+
 /*
  * Out = X ⊙ Y
  * If Y's shape does not match X' shape, they will be reshaped.
...
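A note on why PackTensorsIntoVector is templated on the element type: the helper allocates the output buffer via mutable_data<T>, and the output type can differ from the inputs (comparisons take numeric inputs but write bool), which is why the compare kernel calls PackTensorsIntoVector<OutT> while the add kernel calls PackTensorsIntoVector<T>. It also tolerates ops without a Y input by pushing Y only when non-null. A toy illustration with a stand-in tensor type (not Paddle's):

#include <cstdio>
#include <vector>

// Stand-in tensor whose buffer size depends on the element type requested.
struct TensorSketch {
  std::vector<unsigned char> buf;
  template <typename T>
  void mutable_data(size_t numel) { buf.resize(numel * sizeof(T)); }
};

int main() {
  TensorSketch out;
  out.mutable_data<bool>(8);             // compare kernels: OutT = bool
  std::printf("%zu\n", out.buf.size());  // prints 8
  out.mutable_data<float>(8);            // arithmetic kernels: OutT = T
  std::printf("%zu\n", out.buf.size());  // prints 32
  return 0;
}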