Unverified commit f59c5d8b authored by S sneaxiy, committed by GitHub

Add fused_linear_param_grad_add_kernel (#51805)

* add fused_linear_param_grad_add_kernel

* fix compile error

* remove flag

* fix ci compile error

* fix ci compile error

* revert pylayer revision

* fix ci ut

* improve performance
Parent a765eb26
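In plain terms, the new fused_linear_param_grad_add op computes the weight and bias gradients of a linear layer from x and dout, and can accumulate them into previously computed dweight/dbias tensors in a single fused kernel. Below is a minimal Python sketch of the intended semantics, mirroring the run_ground_truth reference in the unit test added by this commit (the helper name is illustrative, not part of the change):

    import paddle

    def reference_grad_add(x, dout, dweight=None, dbias=None):
        # Collapse all leading dims: x is [*, K], dout is [*, N].
        x2d = x.reshape([-1, x.shape[-1]])
        dout2d = dout.reshape([-1, dout.shape[-1]])
        dweight_out = paddle.matmul(x2d, dout2d, transpose_x=True)  # [K, N]
        dbias_out = dout2d.sum(axis=0)                               # [N]
        if dweight is not None:  # accumulate into existing gradients when given
            dweight_out = dweight + dweight_out
        if dbias is not None:
            dbias_out = dbias + dbias_out
        return dweight_out, dbias_out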
......@@ -587,7 +587,7 @@ static cublasLtEpilogue_t GetEpilogueGradType(
}
}
template <typename T, bool TransX, bool TransY>
template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
......@@ -600,8 +600,12 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto) {
bool use_addto_dx,
bool use_addto_dy) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
static_assert(std::is_same<DXT, T>::value || std::is_same<DXT, MT>::value);
static_assert(std::is_same<DYT, T>::value || std::is_same<DYT, MT>::value);
using Trait = FusedGEMMGradTrait<TransX, TransY>;
cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
......@@ -619,8 +623,8 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
cudaStream_t stream = dev_ctx.stream();
MT alpha = static_cast<MT>(1.0);
MT beta_dx = use_addto ? static_cast<MT>(1.0) : static_cast<MT>(0.0);
MT beta_dy = static_cast<MT>(0.0);
MT beta_dx = use_addto_dx ? static_cast<MT>(1.0) : static_cast<MT>(0.0);
MT beta_dy = use_addto_dy ? static_cast<MT>(1.0) : static_cast<MT>(0.0);
cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
......@@ -687,7 +691,11 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, x_col, x_row, x_col));
&dx_desc,
phi::backends::gpu::ToCudaDataType<DXT>(),
x_col,
x_row,
x_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
......@@ -735,7 +743,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
auto* dx_data = dev_ctx.Alloc<DXT>(dx, dx->numel() * sizeof(DXT));
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
......@@ -806,7 +814,11 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, y_col, y_row, y_col));
&dy_desc,
phi::backends::gpu::ToCudaDataType<DYT>(),
y_col,
y_row,
y_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
......@@ -843,7 +855,8 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
sizeof(epiloque_func_for_dy)));
if (dbias) {
auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
auto* dbias_data =
dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
......@@ -856,7 +869,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
auto* dy_data = dev_ctx.Alloc<DYT>(dy, dy->numel() * sizeof(DYT));
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
......@@ -897,7 +910,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
}
}
template <typename T>
template <typename T, typename DXT = T, typename DYT = T>
void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
......@@ -912,70 +925,79 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto = false) {
bool use_addto_dx = false,
bool use_addto_dy = false) {
VLOG(10) << "M=" << M << ", K=" << K << ", N=" << N << ", trans_x=" << trans_x
<< ", trans_y=" << trans_y
<< ", activation_grad=" << activation_grad;
if (trans_x) {
if (trans_y) {
ComputeFusedGemmEpilogueBackwardImpl<T, true, true>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, true, true>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
} else {
ComputeFusedGemmEpilogueBackwardImpl<T, true, false>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, true, false>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
}
} else {
if (trans_y) {
ComputeFusedGemmEpilogueBackwardImpl<T, false, true>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, false, true>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
} else {
ComputeFusedGemmEpilogueBackwardImpl<T, false, false>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, false, false>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
}
}
}
......
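The hunk above generalizes ComputeFusedGemmEpilogueBackwardImpl in two ways: the output dtypes DXT/DYT may differ from the input dtype T (so the weight/bias gradients can be produced in FP32 while T is FP16/BF16), and dx and dy get independent use_addto flags. Each addto flag is translated into the cuBLASLt beta scalar: beta = 1 accumulates into the existing output buffer, beta = 0 overwrites it. A small Python sketch of that contract (illustrative only, not the actual cuBLASLt call):

    import paddle

    def matmul_with_addto(a, b, out, use_addto):
        # out_new = alpha * (a @ b) + beta * out, with alpha = 1 and beta in {0, 1}
        beta = 1.0 if use_addto else 0.0
        return paddle.matmul(a, b) + beta * out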
......@@ -596,6 +596,16 @@
func : frame
backward : frame_grad
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
optional : dweight, dbias
kernel:
func : fused_linear_param_grad_add
data_type : dout
- op : gather_nd
args : (Tensor x, Tensor index)
output : Tensor
......
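The ops.yaml entry above wires the kernel into the generated eager API: dweight and dbias are optional inputs, and "data_type : dout" selects the kernel dtype from dout. From Python, the call looks like the one used in the unit test below (x, dout, dweight, dbias are placeholders):

    from paddle import _C_ops

    # dweight/dbias may be None; when given, the kernel accumulates into them.
    dweight_out, dbias_out = _C_ops.fused_linear_param_grad_add(
        x, dout, dweight, dbias, True  # multi_precision=True
    )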
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi {
std::vector<DDim> GetMetaTensorsDim(
......@@ -1229,6 +1230,65 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
sequencenum->set_dtype(DataType::FLOAT32);
}
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
const MetaTensor& dbias,
bool multi_precision,
MetaTensor* dweight_out,
MetaTensor* dbias_out) {
const auto dtype = dout.dtype();
PADDLE_ENFORCE_EQ(
x.dtype(),
dtype,
phi::errors::InvalidArgument(
"The data type of Input(x) and Input(dout) must be the same."));
const auto& x_dims = x.dims();
const auto& dout_dims = dout.dims();
int rank = dout_dims.size();
PADDLE_ENFORCE_EQ(
x_dims.size(),
rank,
phi::errors::InvalidArgument(
"The shape of Input(x) and Input(dout) do not match: %s vs %s.",
x_dims,
dout_dims));
for (int i = 0; i + 1 < x_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i],
dout_dims[i],
phi::errors::InvalidArgument(
"The shape of Input(x) and Input(dout) do not match: %s vs %s.",
x_dims,
dout_dims));
}
const phi::DDim& weight_dims = {x_dims[rank - 1], dout_dims[rank - 1]};
if (dweight) {
PADDLE_ENFORCE_EQ(
weight_dims,
dweight.dims(),
phi::errors::InvalidArgument(
"The shape of input(dweight) does not match the other inputs."));
}
const auto mp_dtype =
(dtype == DataType::FLOAT16 || dtype == DataType::BFLOAT16)
? DataType::FLOAT32
: dtype;
if (dbias_out) {
dbias_out->set_dims({weight_dims[1]});
dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
}
if (dweight_out) {
dweight_out->set_dims(weight_dims);
dweight_out->set_dtype(multi_precision ? mp_dtype : dtype);
}
}
void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas,
const MetaTensor& im_shape,
......
......@@ -259,6 +259,14 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
MetaTensor* sequencenum,
MetaTensor* out);
void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
const MetaTensor& dbias,
bool multi_precision,
MetaTensor* dweight_out,
MetaTensor* dbias_out);
void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas,
const MetaTensor& im_shape,
......
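FusedLinearParamGradAddInferMeta checks that x and dout share a dtype and all leading dimensions, then derives dweight_out as [K, N] and dbias_out as [N]; with multi_precision = true, FP16/BF16 inputs yield FP32 outputs. A short sketch of the shape rule, using the shapes from the unit test below as an assumed example:

    x_shape, dout_shape = [3, 4, 32], [3, 4, 128]
    assert x_shape[:-1] == dout_shape[:-1]           # leading dims must match
    dweight_shape = [x_shape[-1], dout_shape[-1]]    # [K, N] = [32, 128]
    dbias_shape = [dout_shape[-1]]                   # [N] = [128]
    # With multi_precision=True and float16/bfloat16 inputs, both outputs are float32.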
......@@ -103,6 +103,7 @@ file(GLOB kernel_h "*.h" "selected_rows/*.h" "sparse/*.h" "strings/*.h")
file(GLOB kernel_impl_h "impl/*.h" "selected_rows/impl/*.h")
file(GLOB kernel_primitive_h "primitive/*.h")
# fusion ops would be included here
file(
GLOB
kernel_cu
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context &ctx,
const DenseTensor &x,
const DenseTensor &dout,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
DenseTensor *dweight_out,
DenseTensor *dbias_out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#endif
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
template <typename T, typename MT, typename Context>
void FusedLinearParamGradAddImpl(const Context &ctx,
const DenseTensor &x,
const DenseTensor &dout,
const paddle::optional<DenseTensor> &dbias,
int64_t M,
int64_t K,
int64_t N,
bool use_addto,
DenseTensor *dweight_out,
DenseTensor *dbias_out) {
constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value;
const bool fuse_bias_grad = kIsMultiPrecision && dweight_out;
if (dweight_out) {
paddle::operators::ComputeFusedGemmEpilogueBackward<T, T, MT>(
ctx,
&dout,
&x,
nullptr,
nullptr,
M,
N,
K,
false,
false,
"none",
nullptr,
dweight_out,
fuse_bias_grad ? dbias_out : nullptr,
false,
use_addto);
}
if (dbias_out == nullptr) return;
if (!fuse_bias_grad) {
auto dout_copy = dout;
dout_copy.Resize({M, N});
if (kIsMultiPrecision) {
*dbias_out = phi::Sum<T, Context>(
ctx,
dout_copy,
{0},
paddle::experimental::CppTypeToDataType<MT>::Type(),
false);
} else {
*dbias_out = phi::Sum<T, Context>(
ctx,
dout_copy,
{0},
paddle::experimental::CppTypeToDataType<T>::Type(),
false);
}
}
if (dbias) {
if (kIsMultiPrecision) {
phi::AddKernel<MT, Context>(ctx, *dbias_out, dbias.get(), dbias_out);
} else {
phi::AddKernel<T, Context>(ctx, *dbias_out, dbias.get(), dbias_out);
}
}
}
template <int LogLevel = 10>
static void PrintMeta(const DenseTensor &t, const char *name) {
PADDLE_ENFORCE_EQ(
t.initialized(),
true,
phi::errors::InvalidArgument("Tensor(%s) is not initialized.", name));
std::stringstream ss;
ss << "Tensor(" << name << "): ";
ss << "dtype(" << t.dtype() << "), ";
ss << "shape(" << t.dims() << "), ";
ss << "place(" << t.place() << "), ";
ss << "ptr(" << t.data() << ")";
VLOG(LogLevel) << ss.str();
}
template <int LogLevel = 10>
static void PrintMeta(const DenseTensor *t, const char *name) {
if (t == nullptr) {
VLOG(LogLevel) << "Tensor(" << name << "): None";
} else {
PrintMeta<LogLevel>(*t, name);
}
}
template <int LogLevel = 10>
static void PrintMeta(const paddle::optional<DenseTensor> &t,
const char *name) {
const auto *t_ptr = t ? &(t.get()) : nullptr;
PrintMeta<LogLevel>(t_ptr, name);
}
template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context &ctx,
const DenseTensor &x,
const DenseTensor &dout,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
DenseTensor *dweight_out,
DenseTensor *dbias_out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
bool use_addto = false;
if (dweight_out) {
if (dweight) {
use_addto = true;
*dweight_out = dweight.get();
if (multi_precision) {
PADDLE_ENFORCE_EQ(
dweight_out->dtype(),
paddle::experimental::CppTypeToDataType<MT>::Type(),
phi::errors::InvalidArgument("Invaid data type error."));
} else {
PADDLE_ENFORCE_EQ(
dweight_out->dtype(),
paddle::experimental::CppTypeToDataType<T>::Type(),
phi::errors::InvalidArgument("Invaid data type error."));
}
} else {
if (multi_precision) {
ctx.template Alloc<MT>(dweight_out);
} else {
ctx.template Alloc<T>(dweight_out);
}
}
}
if (std::is_same<T, MT>::value) {
multi_precision = false;
}
if (dbias_out) {
ctx.template Alloc<T>(dbias_out);
}
int64_t K = x.dims()[x.dims().size() - 1];
int64_t M = x.numel() / K;
int64_t N = dout.dims()[dout.dims().size() - 1];
constexpr int kLogLevel = 10;
if (VLOG_IS_ON(kLogLevel)) {
PrintMeta<kLogLevel>(x, "x");
PrintMeta<kLogLevel>(dout, "dout");
PrintMeta<kLogLevel>(dweight, "dweight");
PrintMeta<kLogLevel>(dbias, "dbias");
PrintMeta<kLogLevel>(dweight_out, "dweight_out");
PrintMeta<kLogLevel>(dbias_out, "dbias_out");
VLOG(kLogLevel) << "multi_precision = " << multi_precision;
VLOG(kLogLevel) << "use_addto = " << use_addto;
VLOG(kLogLevel) << "M = " << M;
VLOG(kLogLevel) << "N = " << N;
VLOG(kLogLevel) << "K = " << K;
}
if (multi_precision) {
FusedLinearParamGradAddImpl<T, MT, Context>(
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
} else {
FusedLinearParamGradAddImpl<T, T, Context>(
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
}
if (VLOG_IS_ON(kLogLevel)) {
ctx.Wait();
}
}
#else
template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context &ctx,
const DenseTensor &x,
const DenseTensor &dout,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
DenseTensor *dweight_out,
DenseTensor *dbias_out) {
PADDLE_THROW(phi::errors::Unimplemented(
"FusedLinearParamGradAdd is only supported when CUDA_VERSION >= 11.6."));
}
#endif
} // namespace phi
PD_REGISTER_KERNEL(fused_linear_param_grad_add,
GPU,
ALL_LAYOUT,
phi::FusedLinearParamGradAdd,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
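To summarize the GPU kernel above: when dweight is passed in, it is reused as dweight_out and the GEMM accumulates into it (use_addto = true), and its dtype must already match the multi_precision setting. multi_precision is forced off when T is already FP32/FP64, and the bias gradient is either fused into the cuBLASLt epilogue (only in the multi-precision path that also produces a weight gradient) or computed separately. A condensed Python restatement of that decision logic, with illustrative names only:

    def plan_grad_add(t_dtype, multi_precision, dweight_in_given, wants_dweight_out):
        mt_dtype = 'float32' if t_dtype in ('float16', 'bfloat16') else t_dtype
        if t_dtype == mt_dtype:
            multi_precision = False              # FP32/FP64 never promote
        use_addto = dweight_in_given             # accumulate into the provided dweight
        dweight_dtype = mt_dtype if multi_precision else t_dtype
        fuse_bias_grad = multi_precision and wants_dweight_out
        return use_addto, dweight_dtype, fuse_bias_grad

When fuse_bias_grad is false, the bias gradient instead falls back to phi::Sum over dout reshaped to [M, N], followed by an elementwise add of the incoming dbias.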
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import unittest
import numpy as np
import paddle
from paddle import _C_ops
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def promote_dtype(x):
if x.dtype in [paddle.float16, paddle.bfloat16]:
return x.astype(paddle.float32)
else:
return x
def recreate(x, multi_precision):
if isinstance(x, (list, tuple)):
return [recreate(item, multi_precision) for item in x]
if x is None:
return None
if multi_precision:
x = promote_dtype(x)
return paddle.to_tensor(x.numpy())
def run_ground_truth(x, dy, dweight, dbias, multi_precision):
x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)
dweight_tmp = paddle.matmul(
x.reshape([-1, x.shape[-1]]),
dy.reshape([-1, dy.shape[-1]]),
transpose_x=True,
)
if dweight is None:
dweight = dweight_tmp
else:
assert dweight.shape == dweight_tmp.shape
assert dweight.dtype == dweight_tmp.dtype
dweight += dweight_tmp
dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
if dbias is None:
dbias = dbias_tmp
else:
assert dbias.shape == dbias_tmp.shape
assert dbias.dtype == dbias_tmp.dtype
dbias += dbias_tmp
return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
x, dy, dweight, dbias, multi_precision
)
if dweight is not None:
assert dweight_new.data_ptr() == dweight.data_ptr()
return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
class TestMainClassBase(unittest.TestCase):
def setUp(self):
self.shape = [3, 4, 32]
self.output_size = 128
self.dtype = paddle.float16
self.config()
def config(self):
pass
def rand(self, shape, dtype=None):
x = np.random.randint(low=-5, high=5, size=shape)
x = paddle.to_tensor(x)
return x.astype(dtype or self.dtype)
def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
x_shape = self.shape
dy_shape = self.shape[:-1] + [self.output_size]
dweight_shape = [self.shape[-1], self.output_size]
dbias_shape = [self.output_size]
x = self.rand(x_shape)
dy = self.rand(dy_shape)
if has_dweight:
dweight = self.rand(dweight_shape)
if multi_precision:
dweight = promote_dtype(dweight)
else:
dweight = None
if has_dbias:
dbias = self.rand(dbias_shape)
if multi_precision:
dbias = promote_dtype(dbias)
else:
dbias = None
return x, dy, dweight, dbias
def check_main(self, has_dweight, has_dbias, multi_precision):
print(has_dweight, has_dbias, multi_precision)
x, dy, dweight, dbias = self.generate_rand_inputs(
has_dweight, has_dbias, multi_precision
)
res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
res2 = run_fused_linear_param_grad_add(
x, dy, dweight, dbias, multi_precision
)
self.assertEqual(len(res1), len(res2))
for r1, r2 in zip(res1, res2):
max_diff = np.max(np.abs(r1 - r2))
self.assertLess(max_diff, 1e-10)
def test_main(self):
if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
return
prop = paddle.device.cuda.get_device_properties()
cap = prop.major * 10 + prop.minor
if self.dtype == paddle.bfloat16 and cap < 80:
return
if get_cuda_version() < 11060:
return
for has_dweight in [False, True]:
for has_dbias in [False, True]:
for multi_precision in [False, True]:
self.check_main(has_dweight, has_dbias, multi_precision)
class TestMainClassBF16(TestMainClassBase):
def config(self):
self.dtype = paddle.bfloat16
class TestMainClassFP32(TestMainClassBase):
def config(self):
self.dtype = paddle.float32
class TestMainClassFP64(TestMainClassBase):
def config(self):
self.dtype = paddle.float64
if __name__ == "__main__":
unittest.main()