未验证 提交 ea1a0d45 编写于 作者: L limingshu 提交者: GitHub

Replace usage of elementwise cuda forward kernel in Compare_all_op (#33754)

上级 4d167240
...@@ -30,29 +30,13 @@ class CompareReduceOpKernel ...@@ -30,29 +30,13 @@ class CompareReduceOpKernel
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y"); auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out"); auto* z = context.Output<Tensor>("Out");
bool shape_same = true;
Tensor tmp; Tensor tmp;
framework::DDim x_dims = x->dims();
framework::DDim y_dims = y->dims();
// judge the two inputs shape is same, if not same, just return false
if (x_dims.size() != y_dims.size()) {
shape_same = false;
} else {
for (auto i = 0; i < x_dims.size(); i++) {
if (x_dims[i] != y_dims[i]) {
shape_same = false;
break;
}
}
}
bool* z_data = z->mutable_data<bool>(context.GetPlace()); bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (!shape_same) {
if (x->dims() != y->dims()) {
z_data[0] = false; z_data[0] = false;
} else { } else {
tmp.mutable_data<bool>(x_dims, context.GetPlace()); tmp.mutable_data<bool>(x->dims(), context.GetPlace());
if (x->numel() == 1 && y->numel() == 1) { if (x->numel() == 1 && y->numel() == 1) {
bool* z_data = tmp.mutable_data<bool>(context.GetPlace()); bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]); z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
......
...@@ -14,14 +14,18 @@ limitations under the License. */ ...@@ -14,14 +14,18 @@ limitations under the License. */
#include <thrust/fill.h> #include <thrust/fill.h>
#include "paddle/fluid/operators/controlflow/compare_all_op.h" #include "paddle/fluid/operators/controlflow/compare_all_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Pass-through functor: returns its input unchanged. Used as the
// element transform when reducing the boolean comparison results.
template <typename T>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor() {}
  HOSTDEVICE inline T operator()(const T& x) const { return x; }
};
...@@ -33,6 +37,24 @@ struct BitwiseAdd { ...@@ -33,6 +37,24 @@ struct BitwiseAdd {
return a & b; return a & b;
} }
}; };
// Element-wise equality functor for the equal_all reduction kernel.
// Primary template: exact comparison for non-floating-point types.
// `args` holds the pair of operands supplied by the elementwise launcher.
template <typename T, typename Enable = void>
struct CudaEqualReduceFunctor {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T args[]) const {
    const T lhs = args[0];
    const T rhs = args[1];
    return lhs == rhs;
  }
};
// Floating-point specialization: treats the operands as equal when their
// difference is within an absolute tolerance of 1e-8. The subtraction is
// performed in T and then widened to double, matching the original code.
template <typename T>
struct CudaEqualReduceFunctor<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T args[]) const {
    const double diff = static_cast<double>(args[0] - args[1]);
    return fabs(diff) < 1e-8;
  }
};
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel class CompareReduceOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
...@@ -44,32 +66,22 @@ class CompareReduceOpKernel ...@@ -44,32 +66,22 @@ class CompareReduceOpKernel
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y"); auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out"); auto* z = context.Output<Tensor>("Out");
bool shape_same = true; bool* z_data = z->mutable_data<bool>(context.GetPlace());
Tensor tmp; Tensor tmp;
framework::DDim x_dims = x->dims();
framework::DDim y_dims = y->dims();
if (x_dims.size() != y_dims.size()) { if (x->dims() != y->dims()) {
shape_same = false;
} else {
for (auto i = 0; i < x_dims.size(); i++) {
if (x_dims[i] != y_dims[i]) {
shape_same = false;
break;
}
}
}
bool* z_data = z->mutable_data<bool>(context.GetPlace());
if (!shape_same) {
thrust::device_ptr<bool> z_dev_ptr(z_data); thrust::device_ptr<bool> z_dev_ptr(z_data);
thrust::fill(z_dev_ptr, z_dev_ptr + 1, false); thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
return; return;
} else { } else {
tmp.mutable_data<bool>(x_dims, context.GetPlace()); tmp.mutable_data<bool>(x->dims(), context.GetPlace());
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, 0, const auto& cuda_ctx =
Functor(), &tmp); context.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {&tmp};
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, bool>(
cuda_ctx, ins, &outs, Functor());
// Reduce by 'bitwise and' operator // Reduce by 'bitwise and' operator
std::vector<int> reduce_dims; std::vector<int> reduce_dims;
reduce_dims.resize(tmp.dims().size()); reduce_dims.resize(tmp.dims().size());
...@@ -85,18 +97,17 @@ class CompareReduceOpKernel ...@@ -85,18 +97,17 @@ class CompareReduceOpKernel
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ #define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \ REGISTER_OP_CUDA_KERNEL( \
op_type, paddle::operators::CompareReduceOpKernel< \ op_type, \
paddle::platform::CUDADeviceContext, functor<bool>>, \ ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<bool>>, \
paddle::operators::CompareReduceOpKernel< \ ops::CompareReduceOpKernel<plat::CUDADeviceContext, ops::functor<int>>, \
paddle::platform::CUDADeviceContext, functor<int>>, \ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
paddle::operators::CompareReduceOpKernel< \ ops::functor<int64_t>>, \
paddle::platform::CUDADeviceContext, functor<int64_t>>, \ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
paddle::operators::CompareReduceOpKernel< \ ops::functor<float>>, \
paddle::platform::CUDADeviceContext, functor<float>>, \ ops::CompareReduceOpKernel<plat::CUDADeviceContext, \
paddle::operators::CompareReduceOpKernel< \ ops::functor<double>>);
paddle::platform::CUDADeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, CudaEqualReduceFunctor)
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, #undef REGISTER_COMPARE_REDUCE_CUDA_KERNEL
paddle::operators::EqualReduceFunctor);
...@@ -59,7 +59,6 @@ struct CudaNotEqualFunctor< ...@@ -59,7 +59,6 @@ struct CudaNotEqualFunctor<
template <typename Functor, typename InverseFunctor> template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor> class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
public: public:
using InT = typename Functor::ELEMENT_TYPE; using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool; using OutT = bool;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册