Unverified commit 841efcd4, authored by co63oc, committed by GitHub

【Hackathon No.59】Improve FP16/BF16 unit tests for the addmm operator (#53111)

* Add addmm tests

* Fix code
Parent 74074a8d
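For context, a minimal usage sketch (not part of this commit) of what the change enables: calling `paddle.addmm` on float16/bfloat16 GPU tensors. It assumes a CUDA build of Paddle; the bfloat16 path additionally requires CUDA >= 11 and compute capability >= 80.

```python
# Minimal usage sketch (not part of this commit): exercising the new FP16/BF16
# addmm GPU kernels in dynamic-graph mode. Assumes a CUDA build of Paddle.
import paddle

paddle.set_device("gpu")

inp = paddle.rand([100, 1]).astype("float16")
x = paddle.rand([100, 10]).astype("float16")
y = paddle.rand([10, 20]).astype("float16")

# out = beta * input + alpha * (x @ y), with `input` broadcast to (100, 20).
out = paddle.addmm(inp, x, y, beta=1.0, alpha=1.0)
print(out.dtype, out.shape)

# bfloat16 additionally requires CUDA >= 11 and compute capability >= 80.
out_bf16 = paddle.addmm(
    inp.astype("bfloat16"), x.astype("bfloat16"), y.astype("bfloat16")
)
```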
...@@ -1316,6 +1316,74 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
  });
}
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
CUDA_R_16BF,
ldb,
A,
CUDA_R_16BF,
lda,
&h_beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
algo));
});
#else
// raise error
PADDLE_THROW(phi::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
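The call above passes `B` before `A` and swaps `N`/`M` because cuBLAS expects column-major (Fortran-order) storage while phi tensors are row-major; `alpha`/`beta` are upcast to float and the compute type is `CUDA_R_32F`, so accumulation happens in FP32. A small NumPy illustration (an aside, not Paddle code) of why the swap yields the row-major product without an explicit transpose:

```python
# NumPy illustration (not Paddle code) of the row-major/column-major swap used
# in the cublasGemmEx call above: the column-major product B^T @ A^T occupies
# memory identically to the row-major product A @ B.
import numpy as np

M, K, N = 4, 3, 5
A = np.random.rand(M, K).astype(np.float32)   # row-major M x K
B = np.random.rand(K, N).astype(np.float32)   # row-major K x N

C = A @ B        # the result the row-major caller wants (M x N)
C_t = B.T @ A.T  # what a column-major BLAS computes with A/B and M/N swapped
assert np.allclose(C_t.T, C)  # same values, so the same buffer serves both views
```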
template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
......
...@@ -751,7 +751,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
      context_.GetComputeCapability(),
      80,
      phi::errors::InvalidArgument(
-          "rocblas fp16 gemm requires GPU compute capability >= 80,"
+          "rocblas bf16 gemm requires GPU compute capability >= 80,"
          "but received %d",
          context_.GetComputeCapability()));
...@@ -982,6 +982,70 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
  });
}
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_gemm_ex(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
rocblas_datatype_bf16_r,
ldb,
A,
rocblas_datatype_bf16_r,
lda,
&h_beta,
C,
rocblas_datatype_bf16_r,
ldc,
C,
rocblas_datatype_bf16_r,
ldc,
rocblas_datatype_f32_r,
algo,
0,
0));
});
}
template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
......
...@@ -18,5 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"
-PD_REGISTER_KERNEL(
-    addmm_grad, GPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
+PD_REGISTER_KERNEL(addmm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AddmmGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
...@@ -18,4 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"
-PD_REGISTER_KERNEL(addmm, GPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
+PD_REGISTER_KERNEL(addmm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AddmmKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
...@@ -18,13 +18,34 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T>
struct CopyOrScaleFunctor {
CopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel)
: scale_(scale), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_scale = static_cast<MPType>(scale_);
const MPType mp_x = static_cast<MPType>(x_[idx]);
output_[idx] = static_cast<T>(mp_scale * mp_x);
}
private:
const float scale_;
const T* x_;
T* output_;
int64_t numel_;
};
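`CopyOrScaleFunctor` is the FP16/BF16 replacement for the `VCOPY`/`SCAL` calls below: it multiplies each element by `scale` (a plain copy when `scale == 1`), doing the arithmetic in the higher-precision `MPType` before casting back. A rough NumPy analogue (illustrative only, with float32 standing in for `MPType`):

```python
# Rough NumPy analogue (illustrative only) of CopyOrScaleFunctor: multiply in
# float32 ("MPType") and cast the result back to the low-precision dtype.
import numpy as np

def copy_or_scale(scale, x):
    return (np.float32(scale) * x.astype(np.float32)).astype(x.dtype)

x = np.random.rand(8).astype(np.float16)
scaled = copy_or_scale(0.3, x)   # scaled copy, still float16
copied = copy_or_scale(1.0, x)   # plain element-wise copy when scale == 1
assert copied.dtype == np.float16 and np.array_equal(copied, x)
```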
template <typename T,
          size_t D,
          int MajorType = Eigen::RowMajor,
...@@ -45,6 +66,13 @@ void AddmmGradKernel(const Context& dev_ctx,
                     DenseTensor* input_grad,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
bool is_float16_or_bfloat16 = false;
if (std::is_same<T, phi::dtype::float16>::value ||
std::is_same<T, phi::dtype::bfloat16>::value) {
is_float16_or_bfloat16 = true;
}
  auto in_dims = input.dims();
  if (input.dims().size() == 1) {
    in_dims = {1, input.dims()[0]};
...@@ -65,6 +93,7 @@ void AddmmGradKernel(const Context& dev_ctx,
  }
  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
  if (input_grad) {
    dev_ctx.template Alloc<T>(input_grad);
    total_elems = in_dims[0] * in_dims[1];
...@@ -78,19 +107,60 @@ void AddmmGradKernel(const Context& dev_ctx,
        Array2(input_grad->dims()[0], input_grad->dims()[1]);
    if (row_compress && col_compress) {
-      eigen_dinput.device(place) =
-          eigen_dout.sum().eval().reshape(eigen_dinput_shape);
+      if (!is_float16_or_bfloat16) {
+        eigen_dinput.device(place) =
+            eigen_dout.sum().eval().reshape(eigen_dinput_shape);
+      } else {
+        eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
+                                         .sum()
+                                         .eval()
+                                         .reshape(eigen_dinput_shape)
+                                         .template cast<T>();
+      }
    } else if (row_compress) {
-      eigen_dinput.device(place) =
-          eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
+      if (!is_float16_or_bfloat16) {
+        eigen_dinput.device(place) =
+            eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
+      } else {
+        eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
+                                         .sum(Array1(0))
+                                         .eval()
+                                         .reshape(eigen_dinput_shape)
+                                         .template cast<T>();
+      }
    } else if (col_compress) {
-      eigen_dinput.device(place) =
-          eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
+      if (!is_float16_or_bfloat16) {
+        eigen_dinput.device(place) =
+            eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
+      } else {
+        eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
+                                         .sum(Array1(1))
+                                         .eval()
+                                         .reshape(eigen_dinput_shape)
+                                         .template cast<T>();
+      }
    } else {
-      blas.VCOPY(total_elems, out_grad.data<T>(), input_grad->data<T>());
+      // The VCOPY does not support the float16, bfloat16
+      if (!is_float16_or_bfloat16) {
+        mt_blas.VCOPY(
+            total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
+      } else {
+        phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+        CopyOrScaleFunctor<T> functor(
+            1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
+        for_range(functor);
+      }
    }
-    blas.SCAL(total_elems, beta, input_grad->data<T>());
+    // The SCAL does not support the float16, bfloat16
+    if (!is_float16_or_bfloat16) {
+      mt_blas.SCAL(total_elems, beta, input_grad->data<MPType>());
+    } else {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
    if (input.dims().size() == 1) {
      input_grad->Resize(input.dims());
...@@ -101,14 +171,28 @@ void AddmmGradKernel(const Context& dev_ctx,
    total_elems = x.dims()[0] * x.dims()[1];
    // x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
    blas.MatMul(out_grad, false, y, true, x_grad);
-    blas.SCAL(total_elems, alpha, x_grad->data<T>());
+    if (!is_float16_or_bfloat16) {
+      mt_blas.SCAL(total_elems, alpha, x_grad->data<MPType>());
+    } else {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
  }
  if (y_grad) {
    dev_ctx.template Alloc<T>(y_grad);
    total_elems = x.dims()[1] * y.dims()[1];
    // y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
    blas.MatMul(x, true, out_grad, false, y_grad);
-    blas.SCAL(total_elems, alpha, y_grad->data<T>());
+    if (!is_float16_or_bfloat16) {
+      mt_blas.SCAL(total_elems, alpha, y_grad->data<MPType>());
+    } else {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
  }
}
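Taken together, the branches above implement `input_grad = beta * reduce(out_grad)` (summing over the dimensions `input` was broadcast along), `x_grad = alpha * out_grad @ y^T`, and `y_grad = alpha * x^T * out_grad`. A NumPy reference of the same gradients (illustrative only, 2-D `input`):

```python
# NumPy reference (illustrative only, 2-D `input`) for the gradients computed
# by AddmmGradKernel above.
import numpy as np

def addmm_grad(input, x, y, out_grad, alpha=1.0, beta=1.0):
    d_input = beta * out_grad
    if input.shape[0] == 1:                  # row dimension was broadcast
        d_input = d_input.sum(axis=0, keepdims=True)
    if input.shape[1] == 1:                  # column dimension was broadcast
        d_input = d_input.sum(axis=1, keepdims=True)
    d_x = alpha * out_grad @ y.T             # M x K
    d_y = alpha * x.T @ out_grad             # K x N
    return d_input, d_x, d_y

x = np.random.rand(5, 3)
y = np.random.rand(3, 4)
inp = np.random.rand(5, 1)
d_inp, d_x, d_y = addmm_grad(inp, x, y, np.ones((5, 4)), alpha=2.0, beta=0.5)
```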
......
...@@ -112,17 +112,19 @@ void AddmmKernel(const Context& dev_ctx,
  funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
      place, eigen_out, eigen_input, bcast_dims);
+  T t_alpha = static_cast<T>(alpha);
+  T t_beta = static_cast<T>(beta);
  blas.GEMM(false,
            false,
            x_dims[0],
            y_dims[1],
            x_dims[1],
-            alpha,
+            t_alpha,
            x.data<T>(),
            x_dims[1],
            y.data<T>(),
            y_dims[1],
-            beta,
+            t_beta,
            out->data<T>(),
            y_dims[1]);
}
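In this hunk `out` already holds the broadcast `input` when GEMM is called, so `beta` scales that buffer while `alpha` scales the product; `alpha`/`beta` are cast to `T` so the templated FP16/BF16 GEMM overloads can be used. A NumPy reference of the forward pass (illustrative only):

```python
# NumPy reference (illustrative only) of the forward computation in this hunk:
# `out` starts as the broadcast `input`, then GEMM accumulates
# out = alpha * (x @ y) + beta * out.
import numpy as np

def addmm(input, x, y, alpha=1.0, beta=1.0):
    out = np.broadcast_to(input, (x.shape[0], y.shape[1])).astype(x.dtype)
    return alpha * (x @ y) + beta * out

x = np.random.rand(100, 10).astype(np.float32)
y = np.random.rand(10, 20).astype(np.float32)
inp = np.random.rand(100, 1).astype(np.float32)
ref = addmm(inp, x, y)          # shape (100, 20)
```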
......
...@@ -15,11 +15,11 @@
import unittest
import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
-from paddle.fluid import Program, program_guard
+from paddle.fluid import Program, core, program_guard
class TestAddMMOp(OpTest):
...@@ -27,7 +27,6 @@ class TestAddMMOp(OpTest):
    def setUp(self):
        self.op_type = "addmm"
        self.python_api = paddle.addmm
-        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'Input': np.random.random((100, 1)).astype(self.dtype),
...@@ -40,7 +39,7 @@ class TestAddMMOp(OpTest):
        }
    def init_dtype_type(self):
-        pass
+        self.dtype = np.float64
    def test_check_output(self):
        self.check_output()
...@@ -58,6 +57,62 @@ class TestAddMMOp(OpTest):
        self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMFP16Op(TestAddMMOp):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-2)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestAddMMBF16Op(OpTest):
def setUp(self):
self.op_type = "addmm"
self.python_api = paddle.addmm
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((100, 1)).astype(self.np_dtype),
'X': np.random.random((100, 10)).astype(self.np_dtype),
'Y': np.random.random((10, 20)).astype(self.np_dtype),
}
self.outputs = {
'Out': self.inputs['Input']
+ np.dot(self.inputs['X'], self.inputs['Y'])
}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.inputs['Y'] = convert_float_to_uint16(self.inputs['Y'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def init_dtype_type(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input', 'X', 'Y'], 'Out')
def test_check_grad_x(self):
self.check_grad_with_place(self.place, ['X'], 'Out', no_grad_set=None)
def test_check_grad_y(self):
self.check_grad_with_place(self.place, ['Y'], 'Out', no_grad_set=None)
def test_check_grad_input(self):
self.check_grad_with_place(
self.place, ['Input'], 'Out', no_grad_set=None
)
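The BF16 test feeds `uint16` data because that is how bfloat16 is represented on the NumPy side; `convert_float_to_uint16` packs float32 values into bfloat16 bit patterns. A rough sketch of that packing (assumption: simple truncation; the real test helper may round differently):

```python
# Rough sketch (assumption: truncation; the real convert_float_to_uint16 helper
# may round) of packing float32 into bfloat16 bit patterns stored as uint16.
import numpy as np

def float_to_bf16_bits(a):
    a = np.ascontiguousarray(a, dtype=np.float32)
    return (a.view(np.uint32) >> 16).astype(np.uint16)

x = np.array([1.0, 3.140625, -2.5], dtype=np.float32)
print(float_to_bf16_bits(x))   # the uint16 values the test passes as inputs
```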
class TestAddMMOpError(unittest.TestCase):
    # test error
    def test_errors(self):
......
...@@ -1959,10 +1959,14 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
    helper = LayerHelper("addmm", **locals())
    check_variable_and_dtype(
-        input, 'Input', ['float32', 'float64'], 'addmm'
+        input, 'Input', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
+    )
+    check_variable_and_dtype(
+        x, 'X', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
+    )
+    check_variable_and_dtype(
+        y, 'Y', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
    )
-    check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'addmm')
-    check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'addmm')
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
......
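In static-graph mode these widened checks are what let a float16 (or `uint16`, i.e. bfloat16) variable reach the addmm op at all. A hypothetical static-graph sketch (names and shapes are illustrative, not from the commit):

```python
# Hypothetical static-graph sketch: with the widened dtype checks, float16
# inputs pass check_variable_and_dtype instead of raising a TypeError.
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    inp = paddle.static.data("input", [100, 1], dtype="float16")
    x = paddle.static.data("x", [100, 10], dtype="float16")
    y = paddle.static.data("y", [10, 20], dtype="float16")
    out = paddle.addmm(inp, x, y)    # previously rejected with a dtype error
```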