diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 55cc5a675b46b7ecc6b36743f83cacf9f9ba3791..13759633d0168a4d38796a88fe8db215cfcfe380 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -19,6 +19,43 @@ limitations under the License. */ namespace paddle { namespace operators { +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; +}; + +template +inline int VectorizedSize(const T* pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec4 = std::alignment_of>::value; // NOLINT + if (address % vec4 == 0) { + return 4; + } + return 1; +} + +template +__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = AlignedVector; + using StoreT = AlignedVector; + for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) { + InT in_vec[VecSize]; + LoadT* in_value = reinterpret_cast(&in_vec); + *in_value = *reinterpret_cast(&in[i]); + + OutT out_vec[VecSize]; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + out_vec[ii] = static_cast(in_vec[ii]); + } + + *(reinterpret_cast(&out[i])) = + *reinterpret_cast(&out_vec[0]); + } +} + template __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast(in[index]); } @@ -40,8 +77,16 @@ struct CastOpFunctor { auto* out = out_->mutable_data(ctx_.GetPlace()); platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(ctx_, size); - CastCUDAKernel<<>>(in, size, out); + int vec_size = VectorizedSize(out); + if (!std::is_same::value && vec_size == 4 && size % 4 == 0) { + VecCastCUDAKernel<<< + config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>( + in, size, out); + } else { + CastCUDAKernel<<>>( + in, size, out); + } } }; diff --git a/paddle/fluid/operators/tril_triu_op.cc b/paddle/fluid/operators/tril_triu_op.cc index 445163f03f6627a14feb50d792edf066d6db8816..8fb0b3809503ecc86e33796a4bc7f7cb2d21f8bb 100644 --- a/paddle/fluid/operators/tril_triu_op.cc +++ b/paddle/fluid/operators/tril_triu_op.cc @@ -99,6 +99,7 @@ class TrilTriuGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker, ops::TrilTriuGradOpMaker, ops::TrilTriuGradOpMaker); @@ -107,10 +108,13 @@ REGISTER_OP_CPU_KERNEL( tril_triu, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, - ops::TrilTriuOpKernel); + ops::TrilTriuOpKernel, + ops::TrilTriuOpKernel); REGISTER_OP_CPU_KERNEL( tril_triu_grad, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, - ops::TrilTriuGradOpKernel); + ops::TrilTriuGradOpKernel, + ops::TrilTriuGradOpKernel); diff --git a/paddle/fluid/operators/tril_triu_op.cu b/paddle/fluid/operators/tril_triu_op.cu index b81939053181f011f04c250b588a6e8a8b411a53..d04acd340597928ba0fbbbebf2dfc7eda1d698ac 100644 --- a/paddle/fluid/operators/tril_triu_op.cu +++ b/paddle/fluid/operators/tril_triu_op.cu @@ -15,16 +15,20 @@ limitations under the License. */ #include "paddle/fluid/operators/tril_triu_op.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( tril_triu, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, ops::TrilTriuOpKernel, - ops::TrilTriuOpKernel); + ops::TrilTriuOpKernel, + ops::TrilTriuOpKernel); REGISTER_OP_CUDA_KERNEL( tril_triu_grad, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, ops::TrilTriuGradOpKernel, - ops::TrilTriuGradOpKernel); + ops::TrilTriuGradOpKernel, + ops::TrilTriuGradOpKernel); diff --git a/paddle/fluid/operators/tril_triu_op.h b/paddle/fluid/operators/tril_triu_op.h index ed9b244d346356466bd22638db6768828a36184b..3150b7617d10a8f9c2f60dd2e74ab2cbbb2d655e 100644 --- a/paddle/fluid/operators/tril_triu_op.h +++ b/paddle/fluid/operators/tril_triu_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { diff --git a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py index 2cd2599f2ea2f4fb26b2d2730ca45384a3b664a7..cdb5f66f578924f1050993fb99dea7e44ac0efe4 100644 --- a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py +++ b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py @@ -16,8 +16,10 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid as fluid import paddle.tensor as tensor +from paddle.fluid.framework import Program, program_guard class TrilTriuOpDefaultTest(OpTest): @@ -68,6 +70,8 @@ def case_generator(op_type, Xshape, diagonal, expected): class FailureCase(unittest.TestCase): def test_failure(self): + paddle.enable_static() + data = fluid.data(shape=Xshape, dtype='float64', name=cls_name) with self.assertRaisesRegexp( eval(expected.split(':')[-1]), errmsg[expected]): @@ -75,6 +79,8 @@ def case_generator(op_type, Xshape, diagonal, expected): class SuccessCase(TrilTriuOpDefaultTest): def initTestCase(self): + paddle.enable_static() + self.real_op_type = op_type self.diagonal = diagonal self.X = np.random.random(Xshape).astype("float64") @@ -120,39 +126,58 @@ class TestTrilTriuOpAPI(unittest.TestCase): """ def test_api(self): - data = np.random.random([1, 9, 9, 4]).astype('float32') - x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x') - tril_out, triu_out = tensor.tril(x), tensor.triu(x) - - place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( - ) else fluid.CPUPlace() - exe = fluid.Executor(place) - tril_out, triu_out = exe.run( - fluid.default_main_program(), - feed={"x": data}, - fetch_list=[tril_out, triu_out], ) - self.assertTrue(np.allclose(tril_out, np.tril(data))) - self.assertTrue(np.allclose(triu_out, np.triu(data))) + paddle.enable_static() + + dtypes = ['float16', 'float32'] + for dtype in dtypes: + prog = Program() + startup_prog = Program() + with program_guard(prog, startup_prog): + data = np.random.random([1, 9, 9, 4]).astype(dtype) + x = fluid.data(shape=[1, 9, -1, 4], dtype=dtype, name='x') + tril_out, triu_out = tensor.tril(x), tensor.triu(x) + + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + tril_out, triu_out = exe.run( + fluid.default_main_program(), + feed={"x": data}, + fetch_list=[tril_out, triu_out], ) + self.assertTrue(np.allclose(tril_out, np.tril(data))) + self.assertTrue(np.allclose(triu_out, np.triu(data))) def test_api_with_dygraph(self): - with fluid.dygraph.guard(): - data = np.random.random([1, 9, 9, 4]).astype('float32') - x = fluid.dygraph.to_variable(data) - tril_out, triu_out = tensor.tril(x).numpy(), tensor.triu(x).numpy() - self.assertTrue(np.allclose(tril_out, np.tril(data))) - self.assertTrue(np.allclose(triu_out, np.triu(data))) + paddle.disable_static() + + dtypes = ['float16', 'float32'] + for dtype in dtypes: + with fluid.dygraph.guard(): + data = np.random.random([1, 9, 9, 4]).astype(dtype) + x = fluid.dygraph.to_variable(data) + tril_out, triu_out = tensor.tril(x).numpy(), tensor.triu( + x).numpy() + self.assertTrue(np.allclose(tril_out, np.tril(data))) + self.assertTrue(np.allclose(triu_out, np.triu(data))) def test_fluid_api(self): - data = np.random.random([1, 9, 9, 4]).astype('float32') - x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x') - triu_out = fluid.layers.triu(x) - - place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( - ) else fluid.CPUPlace() - exe = fluid.Executor(place) - triu_out = exe.run(fluid.default_main_program(), - feed={"x": data}, - fetch_list=[triu_out]) + paddle.enable_static() + + dtypes = ['float16', 'float32'] + for dtype in dtypes: + prog = Program() + startup_prog = Program() + with program_guard(prog, startup_prog): + data = np.random.random([1, 9, 9, 4]).astype(dtype) + x = fluid.data(shape=[1, 9, -1, 4], dtype=dtype, name='x') + triu_out = fluid.layers.triu(x) + + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + triu_out = exe.run(fluid.default_main_program(), + feed={"x": data}, + fetch_list=[triu_out]) if __name__ == '__main__': diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index fd5ca15840076022870af033a5d23a37fd49995d..056a0226723ca1797e1ed8bff99733bba61d84a8 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -558,8 +558,8 @@ def _tril_triu_op(helper): x = helper.kwargs.get('x', None) assert x is not None, 'x cannot be None in {}'.format(op_type) - check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], - op_type) + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type) if len(x.shape) < 2: raise ValueError("x shape in {} must be at least 2-D".format(op_type)) diagonal = helper.kwargs.get('diagonal', 0)