Unverified · Commit f11a53ee authored by LutaoChu, committed by GitHub

Optimize argsort Op performance on GPU

* Accelerate the argsort op on GPU when the input size equals the length of the 'axis' dimension
Parent 1d3b27ca
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/reverse.h>  // thrust::reverse, used by the descending path
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
@@ -58,6 +60,16 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}
template <typename T, typename IndType>
static __global__ void FillFlattenGrad(const T* dO, const IndType* indices,
int64_t size, T* dX) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
dX[indices[i]] = dO[i];
}
}
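For reference, the kernel above is a plain grid-stride scatter: element i of the flattened gradient dO is written to slot indices[i] of dX, undoing the permutation chosen by the forward argsort. A self-contained sketch (the driver and sample data below are illustrative, not part of this commit):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void ScatterGrad(const float* dO, const long long* indices,
                            long long size, float* dX) {
  long long i = threadIdx.x + (long long)blockIdx.x * blockDim.x;
  long long stride = (long long)blockDim.x * gridDim.x;
  for (; i < size; i += stride) {
    dX[indices[i]] = dO[i];  // scatter: inverse of the forward gather
  }
}

int main() {
  const int n = 4;
  float h_dO[n] = {10.f, 20.f, 30.f, 40.f};
  long long h_idx[n] = {2, 0, 3, 1};  // indices from a forward argsort
  float *d_dO, *d_dX;
  long long* d_idx;
  cudaMalloc(&d_dO, n * sizeof(float));
  cudaMalloc(&d_dX, n * sizeof(float));
  cudaMalloc(&d_idx, n * sizeof(long long));
  cudaMemcpy(d_dO, h_dO, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx, n * sizeof(long long), cudaMemcpyHostToDevice);
  ScatterGrad<<<1, 128>>>(d_dO, d_idx, n, d_dX);
  float h_dX[n];
  cudaMemcpy(h_dX, d_dX, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", h_dX[i]);  // prints: 20 40 10 30
  printf("\n");
  cudaFree(d_dO);
  cudaFree(d_dX);
  cudaFree(d_idx);
  return 0;
}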
template <typename T, typename IndType>
static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
IndType num_rows, IndType num_cols) {
@@ -193,6 +205,23 @@ void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
}
template <typename T>
void ArgFlattenAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
const Tensor* indices, int64_t size, Tensor* dX) {
auto cu_stream = ctx.stream();
const int64_t block_size =
std::min(size, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock()));
int64_t max_threads = ctx.GetMaxPhysicalThreadCount();
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
FillFlattenGrad<<<grid_size, block_size, 0, cu_stream>>>(
dO->data<T>(), indices->data<int64_t>(), size, dX->data<T>());
}
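For reference, a host-only sketch of the launch-configuration arithmetic used above: the block size is capped by the device's per-block thread limit, and the block count is capped both by the physical thread count and by the number of blocks actually needed to cover size elements. The device limits below are hypothetical sample values, not queried from a real context:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size = 10000;                  // elements to process
  const int64_t max_threads_per_block = 1024;  // cf. ctx.GetMaxThreadsPerBlock()
  const int64_t max_physical_threads = 81920;  // cf. ctx.GetMaxPhysicalThreadCount()

  const int64_t block_size = std::min(size, max_threads_per_block);
  // Enough blocks to occupy every physical thread at least once...
  const int64_t max_blocks =
      std::max((max_physical_threads - 1) / block_size + 1,
               static_cast<int64_t>(1));
  // ...but never more blocks than needed to cover all elements.
  const int64_t grid_size =
      std::min(max_blocks, (size + block_size - 1) / block_size);

  printf("block_size=%lld grid_size=%lld\n",
         static_cast<long long>(block_size),
         static_cast<long long>(grid_size));  // block_size=1024 grid_size=10
  return 0;
}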
template <typename DeviceContext, typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -205,8 +234,25 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const T* in_data = input->data<T>();
auto size = input->numel();
T* out_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
// Compared to the following 'Special case for full sort', ascending sort is
// 34 times faster and descending sort is 31 times faster.
if (size == in_dims[axis]) {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data);
if (descending) {
thrust::reverse(thrust::device, out_data, out_data + size);
thrust::reverse(thrust::device, ids_data, ids_data + size);
}
return;
}
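For reference, a standalone sketch of the thrust fast path above (a hypothetical driver, not part of this file): sequence initializes the indices, sort_by_key sorts values and indices together in ascending order, and a descending result is obtained by reversing both arrays in place:

#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reverse.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

int main() {
  const int n = 5;
  const float h_in[n] = {3.f, 1.f, 4.f, 1.5f, 2.f};
  thrust::device_vector<float> out(h_in, h_in + n);  // plays the role of out_data
  thrust::device_vector<long long> ids(n);           // plays the role of ids_data
  thrust::sequence(thrust::device, ids.begin(), ids.end());  // 0, 1, 2, ...
  thrust::sort_by_key(thrust::device, out.begin(), out.end(), ids.begin());

  const bool descending = true;
  if (descending) {  // reuse the ascending result, reversed in place
    thrust::reverse(thrust::device, out.begin(), out.end());
    thrust::reverse(thrust::device, ids.begin(), ids.end());
  }
  for (int i = 0; i < n; ++i)
    printf("val=%g idx=%lld\n", (float)out[i], (long long)ids[i]);
  return 0;
}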
// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
@@ -276,23 +322,28 @@ class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
int axis = ctx.Attr<int>("axis");
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
auto in_dims = dX->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
int64_t size = dX->numel();
const auto& dev_ctx = ctx.cuda_device_context();
// Parallel acceleration when the input size is equal to the length of the
// 'axis' dimension.
// Compared to 'special case for full sort' below, the gradient calculation
// is 10 times faster.
if (size == in_dims[axis]) {
ArgFlattenAssign<T>(dev_ctx, dO, indices, size, dX);
return;
}
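For reference, this flattened gradient assignment has the same effect as a thrust::scatter over the whole tensor; a minimal sketch under that assumption (not the code this commit uses):

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/scatter.h>

// dX[indices[i]] = dO[i] for every i, matching FillFlattenGrad.
void FlattenGradViaScatter(const thrust::device_vector<float>& dO,
                           const thrust::device_vector<long long>& indices,
                           thrust::device_vector<float>* dX) {
  thrust::scatter(thrust::device, dO.begin(), dO.end(), indices.begin(),
                  dX->begin());
}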
// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
input_width);
} else {
@@ -316,7 +367,6 @@ class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
// Do transpose
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
@@ -345,11 +395,17 @@ class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
argsort,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
paddle::operators::ArgsortGradOpCUDAKernel<double>,
......
@@ -348,57 +348,99 @@ class TestArgsortErrorOnGPU(TestArgsortErrorOnCPU):
class TestArgsort(unittest.TestCase):
    def init(self):
        self.input_shape = [10000, ]
        self.axis = 0

    def setUp(self):
        self.init()
        if core.is_compiled_with_cuda():
            self.place = core.CUDAPlace(0)
        else:
            self.place = core.CPUPlace()
        self.data = np.random.rand(*self.input_shape)

    def test_api(self):
        with fluid.program_guard(fluid.Program()):
            input = fluid.data(
                name="input", shape=self.input_shape, dtype="float64")
            output = paddle.argsort(input, axis=self.axis)
            output2 = paddle.argsort(input, axis=self.axis, descending=True)

            exe = fluid.Executor(self.place)
            result, result2 = exe.run(feed={'input': self.data},
                                      fetch_list=[output, output2])

            np_result = np.argsort(self.data, axis=self.axis)
            self.assertEqual((result == np_result).all(), True)

            np_result2 = np.argsort(-self.data, axis=self.axis)
            self.assertEqual((result2 == np_result2).all(), True)
class TestArgsort2(TestArgsort):
    def init(self):
        self.input_shape = [10000, 1]
        self.axis = 0


class TestArgsort3(TestArgsort):
    def init(self):
        self.input_shape = [1, 10000]
        self.axis = 1


class TestArgsort4(TestArgsort):
    def init(self):
        self.input_shape = [2, 3, 4]
        self.axis = 1
class TestArgsortImperative(unittest.TestCase):
    def init(self):
        self.input_shape = [10000, ]
        self.axis = 0

    def setUp(self):
        self.init()
        self.input_data = np.random.rand(*self.input_shape)
        if core.is_compiled_with_cuda():
            self.place = core.CUDAPlace(0)
        else:
            self.place = core.CPUPlace()

    def test_api(self):
        paddle.disable_static(self.place)
        var_x = paddle.to_tensor(self.input_data)
        out = paddle.argsort(var_x, axis=self.axis)
        expect = np.argsort(self.input_data, axis=self.axis)
        self.assertEqual((expect == out.numpy()).all(), True)

        out2 = paddle.argsort(var_x, axis=self.axis, descending=True)
        expect2 = np.argsort(-self.input_data, axis=self.axis)
        self.assertEqual((expect2 == out2.numpy()).all(), True)

        paddle.enable_static()
class TestArgsortImperative2(TestArgsortImperative):
    def init(self):
        self.input_shape = [10000, 1]
        self.axis = 0


class TestArgsortImperative3(TestArgsortImperative):
    def init(self):
        self.input_shape = [1, 10000]
        self.axis = 1


class TestArgsortImperative4(TestArgsortImperative):
    def init(self):
        self.input_shape = [2, 3, 4]
        self.axis = 1
if __name__ == "__main__":
    unittest.main()