未验证 提交 ea1a0d45 编写于 作者: L limingshu 提交者: GitHub

Replace usage of elementwise cuda forward kernel in Compare_all_op (#33754)

上级 4d167240
......@@ -30,29 +30,13 @@ class CompareReduceOpKernel
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
bool shape_same = true;
Tensor tmp;
framework::DDim x_dims = x->dims();
framework::DDim y_dims = y->dims();
// judge the two inputs shape is same, if not same, just return false
if (x_dims.size() != y_dims.size()) {
shape_same = false;
} else {
for (auto i = 0; i < x_dims.size(); i++) {
if (x_dims[i] != y_dims[i]) {
shape_same = false;
break;
}
}
}
bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (!shape_same) {
if (x->dims() != y->dims()) {
z_data[0] = false;
} else {
tmp.mutable_data<bool>(x_dims, context.GetPlace());
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
if (x->numel() == 1 && y->numel() == 1) {
bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
......
......@@ -14,14 +14,18 @@ limitations under the License. */
#include <thrust/fill.h>
#include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
......@@ -33,6 +37,24 @@ struct BitwiseAdd {
return a & b;
}
};
template <typename T, typename Enable = void>
struct CudaEqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return (args[0] == args[1]);
}
};
template <typename T>
struct CudaEqualReduceFunctor<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T args[]) const {
return fabs(static_cast<double>(args[0] - args[1])) < 1e-8;
}
};
template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
......@@ -44,32 +66,22 @@ class CompareReduceOpKernel
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
bool shape_same = true;
bool* z_data = z->mutable_data<bool>(context.GetPlace());
Tensor tmp;
framework::DDim x_dims = x->dims();
framework::DDim y_dims = y->dims();
if (x_dims.size() != y_dims.size()) {
shape_same = false;
} else {
for (auto i = 0; i < x_dims.size(); i++) {
if (x_dims[i] != y_dims[i]) {
shape_same = false;
break;
}
}
}
bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (!shape_same) {
if (x->dims() != y->dims()) {
thrust::device_ptr<bool> z_dev_ptr(z_data);
thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
return;
} else {
tmp.mutable_data<bool>(x_dims, context.GetPlace());
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, 0,
Functor(), &tmp);
tmp.mutable_data<bool>(x->dims(), context.GetPlace());
const auto& cuda_ctx =
context.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {&tmp};
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
cuda_ctx, ins, &outs, Functor());
// Reduce by 'bitwise and' operator
std::vector<int> reduce_dims;
reduce_dims.resize(tmp.dims().size());
......@@ -85,18 +97,17 @@ class CompareReduceOpKernel
} // namespace operators
} // namespace paddle
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<bool>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int64_t>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<float>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
paddle::operators::EqualReduceFunctor);
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<int64_t>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<float>>, \
ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
ops::functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
#undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
......@@ -59,7 +59,6 @@ struct CudaNotEqualFunctor<
template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
public:
using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册