未验证 提交 39b704c1 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick] AMP OP&Test support from Hackathon (#53522)

低精度算子支持和单测补充,合并 cherry pick 17个Hackathon PR,共覆盖25个OP的低精度支持及完善
上级 584d6105
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -135,14 +136,17 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluChannelWiseDirectCUDAFunctor<platform::float16>;
template class PreluChannelWiseDirectCUDAFunctor<platform::bfloat16>;
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<platform::float16>;
template class PreluElementWiseDirectCUDAFunctor<platform::bfloat16>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<platform::float16>;
template class PreluScalarDirectCUDAFunctor<platform::bfloat16>;
template class PreluScalarDirectCUDAFunctor<double>;
} // namespace math
......
......@@ -1316,6 +1316,74 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
CUDA_R_16BF,
ldb,
A,
CUDA_R_16BF,
lda,
&h_beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
algo));
});
#else
// raise error
PADDLE_THROW(phi::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
......
......@@ -751,7 +751,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas fp16 gemm requires GPU compute capability >= 80,"
"rocblas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
......@@ -982,6 +982,70 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}
template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_gemm_ex(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
rocblas_datatype_bf16_r,
ldb,
A,
rocblas_datatype_bf16_r,
lda,
&h_beta,
C,
rocblas_datatype_bf16_r,
ldc,
C,
rocblas_datatype_bf16_r,
ldc,
rocblas_datatype_f32_r,
algo,
0,
0));
});
}
template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
......@@ -189,6 +190,17 @@ struct FMinFunctor<dtype::float16> {
}
};
template <>
struct FMinFunctor<dtype::bfloat16> {
inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
const dtype::bfloat16 b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmin(float_a, float_b);
return static_cast<dtype::bfloat16>(result);
}
};
template <>
struct FMinFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
......@@ -228,6 +240,17 @@ struct FMaxFunctor<dtype::float16> {
}
};
template <>
struct FMaxFunctor<dtype::bfloat16> {
inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
const dtype::bfloat16 b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmax(float_a, float_b);
return static_cast<dtype::bfloat16>(result);
}
};
template <>
struct FMaxFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
......
......@@ -24,9 +24,11 @@ namespace funcs {
#define Instantiate_Template_Function(func) \
Instantiate_Template_Function_index_t( \
func, int) Instantiate_Template_Function_index_t(func, float) \
Instantiate_Template_Function_index_t(func, double) \
Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t( \
func, double) Instantiate_Template_Function_index_t(func, int64_t) \
Instantiate_Template_Function_index_t(func, phi::dtype::float16) \
Instantiate_Template_Function_index_t(func, \
phi::dtype::bfloat16) \
Instantiate_Template_Function_index_t(func, unsigned char)
#define Instantiate_Template_Function_index_t(func, tensor_t) \
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/im2col.h"
namespace phi {
......@@ -71,7 +72,7 @@ __global__ void im2col(const T* data_im,
}
*data_col =
(rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
? 0
? T(0)
: data_im[im_idx];
data_col += col_height * col_width;
}
......@@ -173,7 +174,7 @@ __global__ void col2im(int n,
int input_channels = n / im_height / im_width;
if (index < n) {
T val = 0;
T val = static_cast<T>(0);
int w = (data_layout != DataLayout::kNHWC
? index % im_width + padding_width
: (index / input_channels) % im_width + padding_width);
......@@ -309,12 +310,24 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::float16>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::bfloat16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::float16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::bfloat16>;
template <class T>
__global__ void im2colOCF(const T* im_data,
......@@ -560,13 +573,24 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::float16>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::bfloat16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::float16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::bfloat16>;
} // namespace funcs
} // namespace phi
......@@ -1963,7 +1963,7 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads,
wstart = max(wstart, 0);
}
T1 ele = -FLT_MAX;
T1 ele = static_cast<T1>(-FLT_MAX);
int max_index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
......@@ -2015,7 +2015,7 @@ __global__ void AdaptiveKernelMaxPool2dWithIdx(const int nthreads,
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
T1 ele = -FLT_MAX;
T1 ele = static_cast<T1>(-FLT_MAX);
int max_index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
......@@ -2089,7 +2089,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
}
T1 input_grad_data = 0;
T1 input_grad_data = static_cast<T1>(0);
int input_current_featuremap_idx = h_offset * input_width + w_offset;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
......@@ -2259,6 +2259,14 @@ template class MaxPool2dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
template class MaxPool2dWithIndexFunctor<phi::GPUContext, dtype::float16, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext,
dtype::float16,
int>;
template class MaxPool2dWithIndexFunctor<phi::GPUContext, dtype::bfloat16, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext,
dtype::bfloat16,
int>;
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(const int ncd,
......@@ -2324,7 +2332,7 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd,
wstart = max(wstart, 0);
}
T1 ele = -FLT_MAX;
T1 ele = static_cast<T1>(-FLT_MAX);
int max_index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
......@@ -2560,6 +2568,14 @@ template class MaxPool3dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, double, int>;
template class MaxPool3dWithIndexFunctor<phi::GPUContext, dtype::float16, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext,
dtype::float16,
int>;
template class MaxPool3dWithIndexFunctor<phi::GPUContext, dtype::bfloat16, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext,
dtype::bfloat16,
int>;
} // namespace funcs
} // namespace phi
......@@ -18,5 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
addmm_grad, GPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
PD_REGISTER_KERNEL(addmm_grad,
GPU,
ALL_LAYOUT,
phi::AddmmGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -18,4 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"
PD_REGISTER_KERNEL(addmm, GPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
PD_REGISTER_KERNEL(addmm,
GPU,
ALL_LAYOUT,
phi::AddmmKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -16,7 +16,7 @@
#include <vector>
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -104,8 +104,10 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad,
GPU,
ALL_LAYOUT,
phi::BroadcastTensorsGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -14,7 +14,7 @@
#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h"
......@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(broadcast_tensors,
int64_t,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -17,7 +17,7 @@
#include <typeinfo>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
......@@ -34,7 +34,7 @@ void ClipByNormKernel(const Context& dev_ctx,
return ClipByNormFunctor<float, Context>(dev_ctx, in, max_norm, output);
}
auto input = &in;
dev_ctx.template Alloc<dtype::float16>(output);
dev_ctx.template Alloc<T>(output);
PADDLE_ENFORCE_NOT_NULL(input,
phi::errors::InvalidArgument(
......@@ -49,20 +49,14 @@ void ClipByNormKernel(const Context& dev_ctx,
auto* tmp = &tmp_tensor;
tmp->Resize({1});
dev_ctx.template Alloc<float>(tmp);
phi::funcs::ReduceKernel<dtype::float16,
float,
kps::AddFunctor,
kps::SquareFunctor<dtype::float16, float>>(
dev_ctx,
*input,
tmp,
kps::SquareFunctor<dtype::float16, float>(),
reduce_dims);
phi::funcs::
ReduceKernel<T, float, kps::AddFunctor, kps::SquareFunctor<T, float>>(
dev_ctx, *input, tmp, kps::SquareFunctor<T, float>(), reduce_dims);
auto tmp_eigen = phi::EigenVector<float>::Flatten(*tmp);
auto x_norm = tmp_eigen.sqrt();
auto x = phi::EigenVector<dtype::float16>::Flatten(*input);
auto out = phi::EigenVector<dtype::float16>::Flatten(*output);
auto x = phi::EigenVector<T>::Flatten(*input);
auto out = phi::EigenVector<T>::Flatten(*output);
auto* place = dev_ctx.eigen_device();
auto temp = (x_norm <= max_norm).template cast<float>();
......@@ -72,7 +66,7 @@ void ClipByNormKernel(const Context& dev_ctx,
auto scaling =
(temp + (static_cast<float>(1) - temp) * max_norm / (x_norm + epsilon))
.template cast<dtype::float16>();
.template cast<T>();
Eigen::array<int, 1> one_dim{{1}};
Eigen::DSizes<int, 1> m_dsize(input->numel());
......@@ -86,4 +80,5 @@ PD_REGISTER_KERNEL(clip_by_norm,
ALL_LAYOUT,
phi::ClipByNormKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -98,6 +98,7 @@ PD_REGISTER_KERNEL(fmax_grad,
double,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t) {}
PD_REGISTER_KERNEL(fmin_grad,
......@@ -108,6 +109,7 @@ PD_REGISTER_KERNEL(fmin_grad,
double,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t) {}
PD_REGISTER_KERNEL(maximum_grad,
......
......@@ -105,5 +105,6 @@ PD_REGISTER_KERNEL(index_add_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
......@@ -123,5 +123,6 @@ PD_REGISTER_KERNEL(index_add,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/logspace_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
......@@ -25,25 +26,34 @@ namespace phi {
template <typename T>
__global__ void LogspaceKernelInner(
T start, T stop, double step, T base, int64_t size, T* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType mt_start = static_cast<MPType>(start);
MPType mt_stop = static_cast<MPType>(stop);
MPType mt_base = static_cast<MPType>(base);
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) {
out[index] =
static_cast<T>(pow(static_cast<double>(base),
static_cast<double>(start + step * index)));
static_cast<T>(pow(static_cast<double>(mt_base),
static_cast<double>(mt_start + step * index)));
} else {
out[index] = static_cast<T>(
pow(static_cast<double>(base),
static_cast<double>(stop - step * (size - index - 1))));
pow(static_cast<double>(mt_base),
static_cast<double>(mt_stop - step * (size - index - 1))));
}
}
}
template <typename T>
__global__ void LogspaceSpecialKernel(T start, T base, T* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType mt_start = static_cast<MPType>(start);
MPType mt_base = static_cast<MPType>(base);
out[0] = static_cast<T>(
pow(static_cast<double>(base), static_cast<double>(start)));
pow(static_cast<double>(mt_base), static_cast<double>(mt_start)));
}
template <typename T, typename Context>
......@@ -54,6 +64,8 @@ void LogspaceKernel(const Context& ctx,
const DenseTensor& base,
DataType dtype,
DenseTensor* out) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
auto base_t = phi::funcs::TransDataType(ctx, base, dtype);
......@@ -71,6 +83,9 @@ void LogspaceKernel(const Context& ctx,
phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base);
T base_data = n_base.data<T>()[0];
MPType mt_start_data = static_cast<MPType>(start_data);
MPType mt_stop_data = static_cast<MPType>(stop_data);
PADDLE_ENFORCE_GT(
num,
0,
......@@ -86,7 +101,7 @@ void LogspaceKernel(const Context& ctx,
int block = 512;
int grid = (num + block - 1) / block;
if (num != 1) {
step = (static_cast<double>(stop_data - start_data)) / (num - 1);
step = (static_cast<double>(mt_stop_data - mt_start_data)) / (num - 1);
LogspaceKernelInner<T><<<grid, block, 0, stream>>>(
start_data, stop_data, step, base_data, num, out_data);
} else {
......@@ -104,4 +119,6 @@ PD_REGISTER_KERNEL(logspace,
float,
int32_t,
int64_t,
double) {}
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -64,4 +64,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
phi::MatmulWithFlattenDoubleGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -15,16 +15,15 @@ limitations under the License. */
#include "paddle/phi/kernels/multi_dot_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h"
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL(multi_dot_grad,
GPU,
ALL_LAYOUT,
phi::MultiDotGradKernel,
float,
double,
float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,11 +15,15 @@ limitations under the License. */
#include "paddle/phi/kernels/multi_dot_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h"
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL(
multi_dot, GPU, ALL_LAYOUT, phi::MultiDotKernel, float, double, float16) {}
PD_REGISTER_KERNEL(multi_dot,
GPU,
ALL_LAYOUT,
phi::MultiDotKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -38,7 +38,9 @@ PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
ALL_LAYOUT,
phi::MaxPool2dWithIndexGradKernel,
float,
double) {
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
}
......@@ -55,6 +57,8 @@ PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
ALL_LAYOUT,
phi::MaxPool3dWithIndexGradKernel,
float,
double) {
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
}
......@@ -32,7 +32,9 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
ALL_LAYOUT,
phi::MaxPool2dWithIndexKernel,
float,
double) {
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
}
......@@ -49,6 +51,8 @@ PD_REGISTER_KERNEL(max_pool3d_with_index,
ALL_LAYOUT,
phi::MaxPool3dWithIndexKernel,
float,
double) {
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
}
......@@ -189,4 +189,5 @@ PD_REGISTER_KERNEL(prelu_grad,
phi::PReluGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
double) {}
......@@ -79,4 +79,5 @@ PD_REGISTER_KERNEL(prelu,
phi::PReluKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
double) {}
......@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(put_along_axis_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -82,4 +82,5 @@ PD_REGISTER_KERNEL(put_along_axis,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -28,6 +28,7 @@ namespace cub = hipcub;
#include "gflags/gflags.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
......@@ -165,4 +166,6 @@ PD_REGISTER_KERNEL(randperm,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -16,8 +16,63 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
namespace phi {
template <typename T, typename Context>
void ReduceMinGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
dev_ctx.Alloc(x_grad, x.dtype());
reduce_all = recompute_reduce_all(x, dims, reduce_all);
// get reduce_dim
int dim_size = x.dims().size();
auto reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);
auto update_dims = vectorize(x.dims());
for (auto i : reduce_dims) {
update_dims[i] = 1;
}
// make new tensor of out and out_grad
phi::DenseTensor new_out(out.type());
new_out.ShareDataWith(out);
new_out.Resize(phi::make_ddim(update_dims));
phi::DenseTensor new_out_grad(out_grad.type());
new_out_grad.ShareDataWith(out_grad);
new_out_grad.Resize(phi::make_ddim(update_dims));
// make equal_out
phi::DenseTensor* equal_out = new phi::DenseTensor();
equal_out->Resize(x.dims());
dev_ctx.template Alloc<T>(equal_out);
// compute
// 1. equal_out = Equal(x, y)
std::vector<const phi::DenseTensor*> equal_inputs = {&new_out, &x};
std::vector<phi::DenseTensor*> equal_outputs = {equal_out};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor<T>());
// 2. dx = dout * 1
std::vector<const phi::DenseTensor*> mul_inputs = {&new_out_grad, equal_out};
std::vector<phi::DenseTensor*> mul_outputs = {x_grad};
funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor<T>());
delete equal_out;
}
} // namespace phi
PD_REGISTER_KERNEL(min_grad,
GPU,
ALL_LAYOUT,
......@@ -25,4 +80,6 @@ PD_REGISTER_KERNEL(min_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -68,4 +68,5 @@ PD_REGISTER_KERNEL(take_along_axis_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -54,4 +54,5 @@ PD_REGISTER_KERNEL(take_along_axis,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -27,5 +27,6 @@ PD_REGISTER_KERNEL(trace_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -52,5 +52,6 @@ PD_REGISTER_KERNEL(trace,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
unfold_grad, GPU, ALL_LAYOUT, phi::UnfoldGradKernel, float, double) {}
PD_REGISTER_KERNEL(unfold_grad,
GPU,
ALL_LAYOUT,
phi::UnfoldGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,7 +15,15 @@
#include "paddle/phi/kernels/unfold_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_kernel_impl.h"
PD_REGISTER_KERNEL(unfold, GPU, ALL_LAYOUT, phi::UnfoldKernel, float, double) {}
PD_REGISTER_KERNEL(unfold,
GPU,
ALL_LAYOUT,
phi::UnfoldKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/kernels/uniform_inplace_grad_kernel.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
......@@ -41,4 +42,6 @@ PD_REGISTER_KERNEL(uniform_inplace_grad,
ALL_LAYOUT,
phi::UniformInplaceGradKernel,
float,
double) {}
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/random.h>
#include "gflags/gflags.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
......@@ -72,8 +73,12 @@ void UniformInplaceKernel(const Context& ctx,
funcs::distribution_and_transform<T>(ctx, out, dist, trans);
} else {
// Use OP seed
auto func =
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
auto func = UniformGenerator<T>(static_cast<T>(min),
static_cast<T>(max),
seed,
diag_num,
diag_step,
static_cast<T>(diag_val));
IndexKernel<T, UniformGenerator<T>>(ctx, out, func);
}
}
......@@ -85,4 +90,6 @@ PD_REGISTER_KERNEL(uniform_inplace,
ALL_LAYOUT,
phi::UniformInplaceKernel,
float,
double) {}
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -18,13 +18,34 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T>
struct CopyOrScaleFunctor {
CopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel)
: scale_(scale), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_scale = static_cast<MPType>(scale_);
const MPType mp_x = static_cast<MPType>(x_[idx]);
output_[idx] = static_cast<T>(mp_scale * mp_x);
}
private:
const float scale_;
const T* x_;
T* output_;
int64_t numel_;
};
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
......@@ -45,6 +66,13 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor* input_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
bool is_float16_or_bfloat16 = false;
if (std::is_same<T, phi::dtype::float16>::value ||
std::is_same<T, phi::dtype::bfloat16>::value) {
is_float16_or_bfloat16 = true;
}
auto in_dims = input.dims();
if (input.dims().size() == 1) {
in_dims = {1, input.dims()[0]};
......@@ -65,6 +93,7 @@ void AddmmGradKernel(const Context& dev_ctx,
}
auto blas = funcs::GetBlas<Context, T>(dev_ctx);
auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
if (input_grad) {
dev_ctx.template Alloc<T>(input_grad);
total_elems = in_dims[0] * in_dims[1];
......@@ -78,19 +107,60 @@ void AddmmGradKernel(const Context& dev_ctx,
Array2(input_grad->dims()[0], input_grad->dims()[1]);
if (row_compress && col_compress) {
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum()
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else if (row_compress) {
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum(Array1(0))
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else if (col_compress) {
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
blas.VCOPY(total_elems, out_grad.data<T>(), input_grad->data<T>());
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum(Array1(1))
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else {
// The VCOPY does not support the float16, bfloat16
if (!is_float16_or_bfloat16) {
mt_blas.VCOPY(
total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
for_range(functor);
}
}
blas.SCAL(total_elems, beta, input_grad->data<T>());
// The SCAL does not support the float16, bfloat16
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, beta, input_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
for_range(functor);
}
if (input.dims().size() == 1) {
input_grad->Resize(input.dims());
......@@ -101,14 +171,28 @@ void AddmmGradKernel(const Context& dev_ctx,
total_elems = x.dims()[0] * x.dims()[1];
// x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
blas.MatMul(out_grad, false, y, true, x_grad);
blas.SCAL(total_elems, alpha, x_grad->data<T>());
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, alpha, x_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
for_range(functor);
}
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
total_elems = x.dims()[1] * y.dims()[1];
// y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
blas.MatMul(x, true, out_grad, false, y_grad);
blas.SCAL(total_elems, alpha, y_grad->data<T>());
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, alpha, y_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
for_range(functor);
}
}
}
......
......@@ -112,17 +112,19 @@ void AddmmKernel(const Context& dev_ctx,
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);
T t_alpha = static_cast<T>(alpha);
T t_beta = static_cast<T>(beta);
blas.GEMM(false,
false,
x_dims[0],
y_dims[1],
x_dims[1],
alpha,
t_alpha,
x.data<T>(),
x_dims[1],
y.data<T>(),
y_dims[1],
beta,
t_beta,
out->data<T>(),
y_dims[1]);
}
......
......@@ -117,6 +117,7 @@ PD_REGISTER_KERNEL(fmax,
double,
int,
float16,
bfloat16,
int64_t) {}
PD_REGISTER_KERNEL(fmin,
......@@ -127,6 +128,7 @@ PD_REGISTER_KERNEL(fmin,
double,
int,
float16,
bfloat16,
int64_t) {}
PD_REGISTER_KERNEL(maximum_raw,
......
......@@ -36,6 +36,14 @@ void MinRawKernel(const Context& dev_ctx,
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float) {}
#else
PD_REGISTER_KERNEL(
min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(min_raw,
KPS,
ALL_LAYOUT,
phi::MinRawKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
......@@ -39,7 +39,20 @@ void MinKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
min, CPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA)
PD_REGISTER_KERNEL(min,
GPU,
ALL_LAYOUT,
phi::MinKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
#if defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
#endif
......
......@@ -2689,6 +2689,26 @@ class OpTest(unittest.TestCase):
def np_value_to_fluid_value(input):
return input
def cast_bf16_output(self, block, cast_inputs):
output_names = []
for i in range(0, len(cast_inputs)):
cast_output = block.create_var(
dtype="float32", shape=cast_inputs[i].shape
)
cast_op = block.append_op(
inputs={"X": cast_inputs[i]},
outputs={"Out": cast_output},
type="cast",
attrs={
"in_dtype": core.VarDesc.VarType.BF16,
"out_dtype": core.VarDesc.VarType.FP32,
},
)
cast_op.desc.infer_var_type(block.desc)
cast_op.desc.infer_shape(block.desc)
output_names.append(cast_output.name)
return output_names
def _get_gradient(
self,
input_to_check,
......@@ -2712,6 +2732,9 @@ class OpTest(unittest.TestCase):
if user_defined_grad_outputs is None:
if self.dtype == np.uint16:
cast_inputs = list(map(block.var, output_names))
if self.op_type == "broadcast_tensors":
output_names = self.cast_bf16_output(block, cast_inputs)
else:
cast_outputs = block.create_var(
dtype="float32", shape=cast_inputs[0].shape
)
......
......@@ -15,11 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
from paddle.fluid import Program, program_guard
from paddle.fluid import Program, core, program_guard
class TestAddMMOp(OpTest):
......@@ -27,7 +27,6 @@ class TestAddMMOp(OpTest):
def setUp(self):
self.op_type = "addmm"
self.python_api = paddle.addmm
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((100, 1)).astype(self.dtype),
......@@ -40,7 +39,7 @@ class TestAddMMOp(OpTest):
}
def init_dtype_type(self):
pass
self.dtype = np.float64
def test_check_output(self):
self.check_output()
......@@ -58,6 +57,62 @@ class TestAddMMOp(OpTest):
self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMFP16Op(TestAddMMOp):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-2)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestAddMMBF16Op(OpTest):
def setUp(self):
self.op_type = "addmm"
self.python_api = paddle.addmm
self.init_dtype_type()
self.inputs = {
'Input': np.random.random((100, 1)).astype(self.np_dtype),
'X': np.random.random((100, 10)).astype(self.np_dtype),
'Y': np.random.random((10, 20)).astype(self.np_dtype),
}
self.outputs = {
'Out': self.inputs['Input']
+ np.dot(self.inputs['X'], self.inputs['Y'])
}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.inputs['Y'] = convert_float_to_uint16(self.inputs['Y'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def init_dtype_type(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input', 'X', 'Y'], 'Out')
def test_check_grad_x(self):
self.check_grad_with_place(self.place, ['X'], 'Out', no_grad_set=None)
def test_check_grad_y(self):
self.check_grad_with_place(self.place, ['Y'], 'Out', no_grad_set=None)
def test_check_grad_input(self):
self.check_grad_with_place(
self.place, ['Input'], 'Out', no_grad_set=None
)
class TestAddMMOpError(unittest.TestCase):
# test error
def test_errors(self):
......
......@@ -16,7 +16,7 @@ import random
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
......@@ -43,7 +43,7 @@ def find_output_shape(input_list):
return list(reversed(output_shape))
def make_inputs_outputs(input_shapes, dtype):
def make_inputs_outputs(input_shapes, dtype, is_bfloat16=False):
"""Automatically generate formatted inputs and outputs from input_shapes"""
input_list = [
np.random.random(shape).astype(dtype) for shape in input_shapes
......@@ -53,6 +53,16 @@ def make_inputs_outputs(input_shapes, dtype):
x + np.zeros(output_shape).astype(x.dtype) for x in input_list
]
if is_bfloat16:
input_list = [
convert_float_to_uint16(input_list[i])
for i in range(len(input_list))
]
output_list = [
convert_float_to_uint16(output_list[i])
for i in range(len(output_list))
]
output_formatted = {
"Out": [(f"out{i}", output_list[i]) for i in range(len(output_list))]
}
......@@ -63,24 +73,24 @@ def make_inputs_outputs(input_shapes, dtype):
return input_formatted, output_formatted
def gen_rank_diff_test(dtype):
def gen_rank_diff_test(dtype, is_bfloat16=False):
input_shapes = [(2, 60, 1), (6, 2, 1, 10)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
def gen_no_broadcast_test(dtype):
def gen_no_broadcast_test(dtype, is_bfloat16=False):
input_shapes = [(12, 1, 10, 1), (12, 1, 10, 1)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
def gen_mixed_tensors_test(dtype):
def gen_mixed_tensors_test(dtype, is_bfloat16=False):
input_shapes = [(2, 60, 1), (2, 2, 1, 30), (1, 2, 60, 1)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
def gen_empty_tensors_test(dtype):
def gen_empty_tensors_test(dtype, is_bfloat16=False):
input_shapes = [(0), (0), (0)]
return make_inputs_outputs(input_shapes, dtype)
return make_inputs_outputs(input_shapes, dtype, is_bfloat16)
class TestCPUBroadcastTensorsOp(OpTest):
......@@ -125,7 +135,7 @@ class TestCPUBroadcastTensorsOp(OpTest):
def test_check_output(self):
self.run_dual_test(
self.check_output_with_place,
{"place": self.place, "atol": 1e-1},
{"place": self.place},
)
def test_check_grad_normal(self):
......@@ -135,7 +145,6 @@ class TestCPUBroadcastTensorsOp(OpTest):
"place": self.place,
"inputs_to_check": ['x0', 'x1'],
"output_names": ['out0', 'out1'],
"max_relative_error": 0.05,
},
)
self.run_triple_in_test(
......@@ -144,7 +153,6 @@ class TestCPUBroadcastTensorsOp(OpTest):
"place": self.place,
"inputs_to_check": ['x0', 'x1', 'x2'],
"output_names": ['out0', 'out1', "out2"],
"max_relative_error": 0.05,
},
)
......@@ -152,14 +160,77 @@ class TestCPUBroadcastTensorsOp(OpTest):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestCUDABroadcastTensorsOp(TestCPUBroadcastTensorsOp):
class TestBroadcastTensorsFP16Op(TestCPUBroadcastTensorsOp):
def set_place(self):
self.place = core.CUDAPlace(0)
def set_dtypes(self):
self.dtypes = ['float64']
if core.is_float16_supported(self.place):
self.dtypes.append('float16')
self.dtypes = ['float16']
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestBroadcastTensorsBF16Op(OpTest):
def setUp(self):
self.op_type = "broadcast_tensors"
self.dtype = np.uint16
self.np_dtype = "float32"
self.use_mkldnn = False
self.attrs = {'use_mkldnn': self.use_mkldnn}
self.test_gen_func_list = [
gen_rank_diff_test,
gen_no_broadcast_test,
gen_mixed_tensors_test,
]
self.python_api = paddle.broadcast_tensors
self.place = core.CUDAPlace(0)
def run_dual_test(self, test_func, args):
for gen_func in self.test_gen_func_list:
self.inputs, self.outputs = gen_func(self.np_dtype, True)
if len(self.outputs["Out"]) < 3:
self.python_out_sig = [
f"out{i}" for i in range(len(self.outputs["Out"]))
]
test_func(**args)
def run_triple_in_test(self, test_func, args):
self.inputs, self.outputs = self.test_gen_func_list[2](
self.np_dtype, True
)
self.python_out_sig = [
f"out{i}" for i in range(len(self.outputs["Out"]))
]
test_func(**args)
def test_check_output(self):
self.run_dual_test(
self.check_output_with_place,
{"place": self.place},
)
def test_check_grad_normal(self):
self.run_dual_test(
self.check_grad_with_place,
{
"place": self.place,
"inputs_to_check": ['x0', 'x1'],
"output_names": ['out0', 'out1'],
"check_dygraph": False,
},
)
self.run_triple_in_test(
self.check_grad_with_place,
{
"place": self.place,
"inputs_to_check": ['x0', 'x1', 'x2'],
"output_names": ['out0', 'out1', 'out2'],
"check_dygraph": False,
},
)
class TestBroadcastTensorsAPI(unittest.TestCase):
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from op import Operator
import paddle
......@@ -102,6 +102,48 @@ class TestClipByNormOpFp16Case3(TestClipByNormOpFp16):
self.max_norm = 1.0
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestClipByNormBF16Op(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.python_api = clip.clip_by_norm
self.init_dtype()
self.initTestCase()
input = np.random.random(self.shape).astype(self.np_dtype)
input[np.abs(input) < self.max_relative_error] = 0.5
self.op_type = "clip_by_norm"
self.inputs = {
'X': input,
}
self.attrs = {}
self.attrs['max_norm'] = self.max_norm
norm = np.sqrt(np.sum(np.square(input)))
if norm > self.max_norm:
output = self.max_norm * input / norm
else:
output = input
self.outputs = {'Out': output}
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def initTestCase(self):
self.shape = (100,)
self.max_norm = 1.0
def init_dtype(self):
self.dtype = np.uint16
self.np_dtype = np.float32
class TestClipByNormOpWithSelectedRows(unittest.TestCase):
def check_with_place(self, place):
self.config_test_case()
......
......@@ -15,10 +15,16 @@
import unittest
import numpy as np
from eager_op_test import OpTest, paddle_static_guard
from eager_op_test import (
OpTest,
convert_float_to_uint16,
get_numeric_gradient,
paddle_static_guard,
)
import paddle
from paddle.fluid import core
from paddle.fluid.tests.unittests.testsuite import create_op
def conv3d_forward_naive(
......@@ -179,6 +185,77 @@ def create_test_cudnn_class(parent):
globals()[cls_name] = TestCUDNNCase
def create_test_cudnn_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestConv3DCUDNNBF16(parent):
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Output']
)
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=(not self.use_mkldnn)
)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Input')
self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set={'Filter'},
check_dygraph=(not self.use_mkldnn),
user_defined_grads=[numeric_grads],
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set={'Input'},
check_dygraph=(not self.use_mkldnn),
user_defined_grads=[numeric_grads],
)
def test_check_grad(self):
place = core.CUDAPlace(0)
numeric_input_grads = self.get_numeric_grad(place, 'Input')
numeric_fliter_grads = self.get_numeric_grad(place, 'Filter')
self.check_grad_with_place(
place,
{'Input', 'Filter'},
'Output',
user_defined_grads=[numeric_input_grads, numeric_fliter_grads],
check_dygraph=(not self.use_mkldnn),
)
cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP")
TestConv3DCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv3DCUDNNBF16
def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent):
def init_paddings(self):
......@@ -323,19 +400,37 @@ class TestConv3DOp(OpTest):
'dilations': self.dilations,
}
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv3d_forward_naive(
input,
filter,
self.groups,
conv3d_param,
).astype(self.dtype)
)
if self.is_bfloat16_op():
output = convert_float_to_uint16(output)
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter),
}
self.inputs_fp32 = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}
else:
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter),
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
......@@ -358,8 +453,6 @@ class TestConv3DOp(OpTest):
)
def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
......@@ -371,8 +464,7 @@ class TestConv3DOp(OpTest):
)
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
......@@ -385,8 +477,7 @@ class TestConv3DOp(OpTest):
)
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
place = core.CUDAPlace(0) if self.has_cudnn() else core.CPUPlace()
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_grad_with_place(
......@@ -617,6 +708,14 @@ class TestCUDNNExhaustiveSearch(TestCUDNN):
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
# ----------------Conv3DCUDNN bf16----------------
create_test_cudnn_bf16_class(TestConv3DOp)
create_test_cudnn_bf16_class(TestWithGroup1)
create_test_cudnn_bf16_class(TestWithGroup2)
create_test_cudnn_bf16_class(TestWith1x1)
create_test_cudnn_bf16_class(TestWithInput1x1Filter1x1)
# ---- test asymmetric padding ----
......@@ -1114,4 +1213,5 @@ class TestConv3DAPI_Error(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
......@@ -19,11 +19,25 @@ import numpy as np
import paddle
paddle.enable_static()
from eager_op_test import OpTest
from eager_op_test import OpTest, copy_bits_from_float_to_uint16
from paddle.fluid import core
def convert_float_to_uint16(float_list, data_format="NCHW"):
if data_format == "NHWC":
float_list = np.transpose(float_list, [0, 4, 1, 2, 3])
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)
if data_format == "NHWC":
new_output = np.transpose(new_output, [0, 2, 3, 4, 1])
return new_output
def conv3dtranspose_forward_naive(input_, filter_, attrs):
padding_algorithm = attrs['padding_algorithm']
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
......@@ -134,6 +148,86 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
return out
def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestConv3DTransposeCUDNNFP16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set={'Filter'}
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set={'Input'}
)
cls_name = "{}_{}".format(parent.__name__, "CUDNNFP16OP")
TestConv3DTransposeCUDNNFP16.__name__ = cls_name
globals()[cls_name] = TestConv3DTransposeCUDNNFP16
def create_test_cudnn_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestConv3DTransposeCUDNNBF16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
{'Input', 'Filter'},
'Output',
)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Input'],
'Output',
no_grad_set={'Filter'},
)
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['Filter'],
'Output',
no_grad_set={'Input'},
)
cls_name = "{}_{}".format(parent.__name__, "CUDNNBF16OP")
TestConv3DTransposeCUDNNBF16.__name__ = cls_name
globals()[cls_name] = TestConv3DTransposeCUDNNBF16
def conv3d_transpose_wrapper(
x,
weight,
......@@ -172,12 +266,16 @@ class TestConv3DTransposeOp(OpTest):
self.pad = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
self.init_op_type()
self.init_kernel_type()
self.init_test_case()
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
if self.is_bfloat16_op():
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
else:
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
......@@ -189,9 +287,21 @@ class TestConv3DTransposeOp(OpTest):
}
output = conv3dtranspose_forward_naive(
input_, filter_, self.attrs
input, filter, self.attrs
).astype("float32")
if self.is_bfloat16_op():
self.inputs = {
'Input': convert_float_to_uint16(input),
'Filter': convert_float_to_uint16(filter),
}
else:
self.inputs = {
'Input': input,
'Filter': filter,
}
output = output.astype(self.dtype)
self.outputs = {'Output': output}
def test_check_output(self):
......@@ -264,6 +374,9 @@ class TestConv3DTransposeOp(OpTest):
self.op_type = "conv3d_transpose"
self.python_api = conv3d_transpose_wrapper
def init_kernel_type(self):
self.dtype = np.float32
class TestWithSymmetricPad(TestConv3DTransposeOp):
def init_test_case(self):
......@@ -596,6 +709,30 @@ class TestCUDNNWithGroups_NHWC(TestWithGroups):
self.python_api = conv3d_transpose_wrapper
# ----------------Conv3DTransposeCUDNN fp16----------------
create_test_cudnn_fp16_class(TestConv3DTransposeOp)
create_test_cudnn_fp16_class(TestWithSymmetricPad)
create_test_cudnn_fp16_class(TestWithAsymmetricPad)
create_test_cudnn_fp16_class(TestWithSAMEPad)
create_test_cudnn_fp16_class(TestWithVALIDPad)
create_test_cudnn_fp16_class(TestWithStride)
create_test_cudnn_fp16_class(TestWithGroups)
create_test_cudnn_fp16_class(TestWithDilation)
create_test_cudnn_fp16_class(Test_NHWC)
# ----------------Conv3DTransposeCUDNN bf16----------------
create_test_cudnn_bf16_class(TestConv3DTransposeOp)
create_test_cudnn_bf16_class(TestWithSymmetricPad)
create_test_cudnn_bf16_class(TestWithAsymmetricPad)
create_test_cudnn_bf16_class(TestWithSAMEPad)
create_test_cudnn_bf16_class(TestWithVALIDPad)
create_test_cudnn_bf16_class(TestWithStride)
create_test_cudnn_bf16_class(TestWithGroups)
create_test_cudnn_bf16_class(TestWithDilation)
create_test_cudnn_bf16_class(Test_NHWC)
class TestConv3dTranspose(unittest.TestCase):
def error_weight_input(self):
array = np.array([1], dtype=np.float32)
......
......@@ -15,7 +15,11 @@
import unittest
import numpy as np
from test_conv3d_transpose_op import TestConv3DTransposeOp
from test_conv3d_transpose_op import (
TestConv3DTransposeOp,
create_test_cudnn_bf16_class,
create_test_cudnn_fp16_class,
)
import paddle
from paddle import fluid
......@@ -84,6 +88,22 @@ class TestWithDilation_NHWC(TestConv3DTransposeOp):
self.data_format = 'NHWC'
# ----------------Conv3DTransposeCUDNN fp16----------------
create_test_cudnn_fp16_class(TestWithSymmetricPad_NHWC)
create_test_cudnn_fp16_class(TestWithAsymmetricPad_NHWC)
create_test_cudnn_fp16_class(TestWithGroups_NHWC)
create_test_cudnn_fp16_class(TestWithStride_NHWC)
create_test_cudnn_fp16_class(TestWithDilation_NHWC)
# ----------------Conv3DTransposeCUDNN bf16----------------
create_test_cudnn_bf16_class(TestWithSymmetricPad_NHWC)
create_test_cudnn_bf16_class(TestWithAsymmetricPad_NHWC)
create_test_cudnn_bf16_class(TestWithGroups_NHWC)
create_test_cudnn_bf16_class(TestWithStride_NHWC)
create_test_cudnn_bf16_class(TestWithDilation_NHWC)
class TestConv3DTransposeAPI(unittest.TestCase):
def test_case1(self):
data1 = paddle.static.data(
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
......@@ -241,5 +241,34 @@ class TestElementwiseFmax3Op(OpTest):
self.check_grad(['X', 'Y'], 'Out')
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestFmaxBF16OP(OpTest):
def setUp(self):
self.op_type = "elementwise_fmax"
self.python_api = paddle.fmax
self.dtype = np.uint16
x = np.random.uniform(0.1, 1, [13, 17]).astype("float32")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float32")
out = np.fmax(x, y)
self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y),
}
self.outputs = {'Out': convert_float_to_uint16(out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Y'], 'Out')
if __name__ == "__main__":
unittest.main()
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
......@@ -243,6 +243,35 @@ class TestElementwiseFmin3Op(OpTest):
self.check_grad(['X', 'Y'], 'Out')
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestFminBF16OP(OpTest):
def setUp(self):
self.op_type = "elementwise_fmin"
self.python_api = paddle.fmin
self.dtype = np.uint16
x = np.random.uniform(1, 1, [13, 17]).astype("float32")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
y = x + sgn * np.random.uniform(1, 1, [13, 17]).astype("float32")
out = np.fmin(x, y)
self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y),
}
self.outputs = {'Out': convert_float_to_uint16(out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Y'], 'Out')
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -15,10 +15,10 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import Program
from paddle.fluid import Program, core
def compute_index_add_ref(
......@@ -99,6 +99,69 @@ class TestIndexAddOp(OpTest):
self.check_grad(['X', 'AddValue'], 'Out')
class TestIndexAddFP16Op(TestIndexAddOp):
def init_dtype_type(self):
self.axis = 0
self.x_type = np.float16
self.index_type = np.int64
self.x_shape = (101, 3)
self.index_size = 3
self.add_value_shape = (3, 3)
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestIndexAddBF16Op(OpTest):
def setUp(self):
self.python_api = raw_index_add
self.op_type = "index_add"
self.init_dtype_type()
index_np = np.random.randint(
low=0, high=self.x_shape[self.axis], size=self.index_size
)
x_np = np.random.random(self.x_shape).astype(self.x_type)
add_value_np = np.random.random(self.add_value_shape).astype(
self.x_type
)
self.inputs = {
'X': convert_float_to_uint16(x_np),
'Index': index_np,
'AddValue': convert_float_to_uint16(add_value_np),
}
self.attrs = {'axis': self.axis}
out = compute_index_add_ref(
self.axis,
self.x_shape,
x_np,
self.add_value_shape,
add_value_np,
self.index_size,
index_np,
)
self.outputs = {'Out': convert_float_to_uint16(out)}
self.place = core.CUDAPlace(0)
def init_dtype_type(self):
self.axis = 0
self.x_type = np.float32
self.index_type = np.int64
self.x_shape = (101, 3)
self.index_size = 3
self.add_value_shape = (3, 3)
self.dtype = np.uint16
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X', 'AddValue'], 'Out')
class TestIndexAddAPI(unittest.TestCase):
def setUp(self):
self.setType()
......
......@@ -15,10 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
from paddle.fluid import core
class TestIndexSampleOp(OpTest):
......@@ -121,6 +122,49 @@ class TestCase6(TestIndexSampleOp):
self.index_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestIndexSampleBF16Op(OpTest):
def setUp(self):
self.op_type = "index_sample"
self.python_api = paddle.index_sample
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
indexnp = np.random.randint(
low=0, high=self.x_shape[1], size=self.index_shape
).astype(self.index_type)
self.inputs = {'X': xnp, 'Index': indexnp}
index_array = []
for i in range(self.index_shape[0]):
for j in indexnp[i]:
index_array.append(xnp[i, j])
index_array = np.array(index_array).astype(self.x_type)
out = np.reshape(index_array, self.index_shape)
self.outputs = {'Out': out}
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def config(self):
"""
For multi-dimension input
"""
self.x_shape = (10, 20)
self.x_type = "float32"
self.dtype = np.uint16
self.index_shape = (10, 10)
self.index_type = "int32"
class TestIndexSampleShape(unittest.TestCase):
def test_shape(self):
paddle.enable_static()
......
......@@ -15,9 +15,10 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
class TestLogspaceOpCommonCase(OpTest):
......@@ -41,6 +42,54 @@ class TestLogspaceOpCommonCase(OpTest):
self.check_output()
class TestLogspaceFP16Op(TestLogspaceOpCommonCase):
def init_data(self):
self.dtype = np.float16
self.inputs = {
'Start': np.array([0]).astype(self.dtype),
'Stop': np.array([10]).astype(self.dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([2]).astype(self.dtype),
}
self.attrs = {'dtype': int(paddle.float16)}
self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(self.dtype)}
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestLogspaceBF16Op(OpTest):
def setUp(self):
self.op_type = "logspace"
self.python_api = paddle.logspace
self.init_data()
def init_data(self):
self.dtype = np.uint16
self.np_dtype = np.float32
self.inputs = {
'Start': np.array([0]).astype(self.np_dtype),
'Stop': np.array([10]).astype(self.np_dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([2]).astype(self.np_dtype),
}
self.attrs = {'dtype': int(paddle.bfloat16)}
self.outputs = {
'Out': np.power(2, np.arange(0, 11)).astype(self.np_dtype)
}
self.inputs["Start"] = convert_float_to_uint16(self.inputs["Start"])
self.inputs["Stop"] = convert_float_to_uint16(self.inputs["Stop"])
self.inputs["Base"] = convert_float_to_uint16(self.inputs["Base"])
self.outputs["Out"] = convert_float_to_uint16(self.outputs["Out"])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
class TestLogspaceOpReverseCase(TestLogspaceOpCommonCase):
def init_data(self):
dtype = 'float32'
......
......@@ -20,7 +20,7 @@ import numpy as np
from paddle.fluid import core
sys.path.append("..")
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
class TestMulOp(OpTest):
......@@ -114,14 +114,14 @@ class TestMulOp2(OpTest):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFP16MulOp1(TestMulOp):
class TestMulFP16Op1(TestMulOp):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-1, check_dygraph=False)
self.check_output_with_place(place, check_dygraph=False)
def test_check_grad_normal(self):
place = core.CUDAPlace(0)
......@@ -130,7 +130,6 @@ class TestFP16MulOp1(TestMulOp):
place,
['X', 'Y'],
'Out',
max_relative_error=0.5,
check_dygraph=False,
)
......@@ -141,7 +140,6 @@ class TestFP16MulOp1(TestMulOp):
place,
['Y'],
'Out',
max_relative_error=0.5,
no_grad_set=set("X"),
check_dygraph=False,
)
......@@ -153,7 +151,6 @@ class TestFP16MulOp1(TestMulOp):
place,
['X'],
'Out',
max_relative_error=0.5,
no_grad_set=set('Y'),
check_dygraph=False,
)
......@@ -162,14 +159,14 @@ class TestFP16MulOp1(TestMulOp):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFP16MulOp2(TestMulOp2):
class TestMulFP16Op2(TestMulOp2):
def init_dtype_type(self):
self.dtype = np.float16
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-1, check_dygraph=False)
self.check_output_with_place(place, check_dygraph=False)
def test_check_grad_normal(self):
place = core.CUDAPlace(0)
......@@ -178,7 +175,6 @@ class TestFP16MulOp2(TestMulOp2):
place,
['X', 'Y'],
'Out',
max_relative_error=0.9,
check_dygraph=False,
)
......@@ -189,7 +185,6 @@ class TestFP16MulOp2(TestMulOp2):
place,
['Y'],
'Out',
max_relative_error=0.5,
no_grad_set=set("X"),
check_dygraph=False,
)
......@@ -201,7 +196,116 @@ class TestFP16MulOp2(TestMulOp2):
place,
['X'],
'Out',
max_relative_error=0.9,
no_grad_set=set('Y'),
check_dygraph=False,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestMulBF16Op1(OpTest):
def setUp(self):
self.op_type = "mul"
self.init_dtype_type()
self.inputs = {
'X': np.random.random((20, 5)).astype(self.np_dtype),
'Y': np.random.random((5, 21)).astype(self.np_dtype),
}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.inputs['Y'] = convert_float_to_uint16(self.inputs['Y'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def init_dtype_type(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
def test_check_grad_normal(self):
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', check_dygraph=False
)
def test_check_grad_ingore_x(self):
self.check_grad_with_place(
self.place,
['Y'],
'Out',
no_grad_set=set("X"),
check_dygraph=False,
)
def test_check_grad_ingore_y(self):
self.check_grad_with_place(
self.place,
['X'],
'Out',
no_grad_set=set('Y'),
check_dygraph=False,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestMulBF16Op2(TestMulBF16Op1):
def setUp(self):
self.op_type = "mul"
self.init_dtype_type()
self.inputs = {
'X': np.random.random((3, 4, 2, 9)).astype(self.np_dtype),
'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.np_dtype),
}
self.attrs = {
'x_num_col_dims': 2,
'y_num_col_dims': 2,
}
result = np.dot(
self.inputs['X'].reshape(3 * 4, 2 * 9),
self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3),
)
result = result.reshape(3, 4, 1, 2, 3)
self.outputs = {'Out': result}
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.inputs['Y'] = convert_float_to_uint16(self.inputs['Y'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_grad_normal(self):
self.check_grad_with_place(
self.place,
['X', 'Y'],
'Out',
numeric_grad_delta=0.02,
check_dygraph=False,
)
def test_check_grad_ingore_x(self):
self.check_grad_with_place(
self.place,
['Y'],
'Out',
numeric_grad_delta=0.02,
no_grad_set=set("X"),
check_dygraph=False,
)
def test_check_grad_ingore_y(self):
self.check_grad_with_place(
self.place,
['X'],
'Out',
numeric_grad_delta=0.02,
no_grad_set=set('Y'),
check_dygraph=False,
)
......
......@@ -15,10 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from numpy.linalg import multi_dot
import paddle
from paddle.fluid import core
paddle.enable_static()
......@@ -49,6 +50,53 @@ class TestMultiDotOp(OpTest):
self.check_grad(['x1'], 'Out')
class TestMultiDotFP16Op(TestMultiDotOp):
def get_dtype(self):
return "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestMultiDotBF16Op(OpTest):
def setUp(self):
self.op_type = "multi_dot"
self.python_api = paddle.linalg.multi_dot
self.dtype = self.get_dtype()
self.get_inputs_and_outputs()
self.place = core.CUDAPlace(0)
def get_dtype(self):
self.np_dtype = "float32"
return np.uint16
def get_inputs_and_outputs(self):
self.A = np.random.random((2, 8)).astype(self.np_dtype)
self.B = np.random.random((8, 4)).astype(self.np_dtype)
self.inputs = {
'X': [
('x0', convert_float_to_uint16(self.A)),
('x1', convert_float_to_uint16(self.B)),
]
}
self.outputs = {
'Out': convert_float_to_uint16(multi_dot([self.A, self.B]))
}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['x0'], 'Out', numeric_grad_delta=0.01
)
self.check_grad_with_place(
self.place, ['x1'], 'Out', numeric_grad_delta=0.01
)
# (A*B)*C
class TestMultiDotOp3Mat(TestMultiDotOp):
def get_inputs_and_outputs(self):
......
......@@ -15,9 +15,16 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
get_numeric_gradient,
)
import paddle
from paddle.fluid import core
from paddle.fluid.tests.unittests.testsuite import create_op
def adaptive_start_index(index, input_size, output_size):
......@@ -149,9 +156,18 @@ class TestMaxPoolWithIndex_Op(OpTest):
self.init_test_case()
self.init_global()
self.init_adaptive()
self.init_dtype()
input = np.random.random(self.shape).astype("float64")
if self.is_bfloat16_op():
input = np.random.random(self.shape).astype(np.float32)
input = convert_uint16_to_float(
convert_float_to_uint16(np.round(input * 100.0, 2))
)
else:
input = np.random.random(self.shape).astype(self.dtype)
input = np.round(input * 100.0, 2)
output, mask = self.pool_forward_naive(
input,
self.ksize,
......@@ -160,8 +176,11 @@ class TestMaxPoolWithIndex_Op(OpTest):
self.global_pool,
self.adaptive,
)
output = output.astype("float64")
mask = mask.astype("int32")
if self.is_bfloat16_op():
output = output.astype(np.float32)
else:
output = output.astype(self.dtype)
self.attrs = {
'strides': self.strides,
......@@ -171,9 +190,21 @@ class TestMaxPoolWithIndex_Op(OpTest):
'adaptive': self.adaptive,
}
if self.is_bfloat16_op():
self.inputs = {'X': convert_float_to_uint16(input)}
self.outputs = {
'Out': convert_float_to_uint16(output),
"Mask": mask,
}
self.inputs_fp32 = {'X': input}
else:
self.inputs = {'X': input}
self.outputs = {'Out': output, "Mask": mask}
def init_dtype(self):
self.dtype = np.float64
def test_check_output(self):
self.check_output()
......@@ -220,9 +251,90 @@ class TestCase3(TestCase2):
self.global_pool = False
# ----------------max_pool2d_with_index----------------
class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
def init_adaptive(self):
self.adaptive = True
# ----------------max_pool3d_with_index_fp16----------------
def create_test_fp16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMaxPool3dFP16(parent):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(place, {'X'}, ['Out'])
cls_name = "{}_{}".format(parent.__name__, "FP16OP")
TestMaxPool3dFP16.__name__ = cls_name
globals()[cls_name] = TestMaxPool3dFP16
create_test_fp16_class(TestMaxPoolWithIndex_Op)
create_test_fp16_class(TestCase1)
create_test_fp16_class(TestCase2)
create_test_fp16_class(TestCase3)
create_test_fp16_class(TestCastAdaptive3d)
# ----------------max_pool3d_with_index_bf16----------------
def create_test_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestMaxPool3dBF16(parent):
def init_dtype(self):
self.dtype = np.uint16
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Out']
)
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'X')
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
)
cls_name = "{}_{}".format(parent.__name__, "BF16OP")
TestMaxPool3dBF16.__name__ = cls_name
globals()[cls_name] = TestMaxPool3dBF16
create_test_bf16_class(TestMaxPoolWithIndex_Op)
create_test_bf16_class(TestCase1)
create_test_bf16_class(TestCase2)
create_test_bf16_class(TestCase3)
create_test_bf16_class(TestCastAdaptive3d)
# ----------------max_pool2d_with_index----------------
def max_pool2d_with_index_wapper(
x,
kernel_size=[],
......@@ -279,9 +391,82 @@ class TestCastAdaptive2d(TestCase6):
self.adaptive = True
class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
def init_adaptive(self):
self.adaptive = True
# ----------------max_pool2d_with_index_fp16----------------
def create_test_fp16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMaxPool2dFP16(parent):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(place, {'X'}, ['Out'])
cls_name = "{}_{}".format(parent.__name__, "FP16OP")
TestMaxPool2dFP16.__name__ = cls_name
globals()[cls_name] = TestMaxPool2dFP16
create_test_fp16_class(TestCase4)
create_test_fp16_class(TestCase5)
create_test_fp16_class(TestCase6)
create_test_fp16_class(TestCase7)
create_test_fp16_class(TestCastAdaptive2d)
# ----------------max_pool2d_with_index_bf16----------------
def create_test_bf16_class(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and do not support bfloat16",
)
class TestMaxPool2dBF16(parent):
def init_dtype(self):
self.dtype = np.uint16
def get_numeric_grad(self, place, check_name):
scope = core.Scope()
self._check_grad_helper()
op = create_op(
scope, self.op_type, self.inputs, self.outputs, self.attrs
)
return get_numeric_gradient(
place, scope, op, self.inputs_fp32, check_name, ['Out']
)
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
numeric_grads = self.get_numeric_grad(place, 'X')
if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
)
cls_name = "{}_{}".format(parent.__name__, "BF16OP")
TestMaxPool2dBF16.__name__ = cls_name
globals()[cls_name] = TestMaxPool2dBF16
create_test_bf16_class(TestCase4)
create_test_bf16_class(TestCase5)
create_test_bf16_class(TestCase6)
create_test_bf16_class(TestCase7)
create_test_bf16_class(TestCastAdaptive2d)
if __name__ == '__main__':
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, skip_check_grad_ci
from eager_op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
import paddle
import paddle.nn.functional as F
......@@ -174,7 +174,11 @@ class PReluTest(OpTest):
self.op_type = "prelu"
self.python_api = prelu_api_wrapper
x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
if self.dtype == np.uint16:
as_type = self.np_dtype
else:
as_type = self.dtype
x_np = np.random.uniform(-1, 1, self.x_shape).astype(as_type)
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
x_np[np.abs(x_np) < 0.005] = 0.02
......@@ -190,7 +194,7 @@ class PReluTest(OpTest):
alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]])
else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
alpha_np = alpha_np.astype(self.dtype)
alpha_np = alpha_np.astype(as_type)
self.inputs = {'X': x_np, 'Alpha': alpha_np}
......@@ -393,18 +397,48 @@ def create_test_fp16_class(
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and check_grad:
self.check_grad_with_place(
place,
['X', 'Alpha'],
'Out',
max_relative_error=max_relative_error,
)
# Use the default max_relative_error, not use max_relative_error
self.check_grad_with_place(place, ['X', 'Alpha'], 'Out')
cls_name = "{}_{}".format(parent.__name__, "Fp16Op")
TestPReluFp16Case.__name__ = cls_name
globals()[cls_name] = TestPReluFp16Case
def create_test_bf16_class(
parent, check_grad=True, atol=1e-3, max_relative_error=0.05
):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestPReluBF16Op(parent):
def setUp(self):
super().setUp()
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.inputs['Alpha'] = convert_float_to_uint16(self.inputs['Alpha'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
def init_dtype(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def test_check_grad(self):
place = core.CUDAPlace(0)
if check_grad:
# Use the default max_relative_error, not use max_relative_error
self.check_grad_with_place(place, ['X', 'Alpha'], 'Out')
cls_name = "{}_{}".format(parent.__name__, "BF16Op")
TestPReluBF16Op.__name__ = cls_name
globals()[cls_name] = TestPReluBF16Op
create_test_fp16_class(TestModeElt)
create_test_fp16_class(TestModeAllRank3)
create_test_fp16_class(TestModeAllRank6)
......@@ -420,6 +454,21 @@ create_test_fp16_class(TestModeChannelRank6NHWC)
create_test_fp16_class(TestModeElementRank3NHWC)
create_test_fp16_class(TestModeElementRank6NHWC)
create_test_bf16_class(TestModeElt)
create_test_bf16_class(TestModeAllRank3)
create_test_bf16_class(TestModeAllRank6)
create_test_bf16_class(TestModeChannelRank3)
create_test_bf16_class(TestModeChannelRank6)
create_test_bf16_class(TestModeElementRank3)
create_test_bf16_class(TestModeElementRank6)
create_test_bf16_class(TestModeEltNHWC)
create_test_bf16_class(TestModeAllRank3NHWC)
create_test_bf16_class(TestModeAllRank6NHWC)
create_test_bf16_class(TestModeChannelRank3NHWC)
create_test_bf16_class(TestModeChannelRank6NHWC)
create_test_bf16_class(TestModeElementRank3NHWC)
create_test_bf16_class(TestModeElementRank6NHWC)
def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'):
helper = fluid.layer_helper.LayerHelper('prelu', **locals())
......
......@@ -16,7 +16,7 @@ import copy
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.framework import core
......@@ -28,19 +28,18 @@ class TestPutAlongAxisOp(OpTest):
def setUp(self):
self.init_data()
self.reduce_op = "assign"
self.dtype = 'float64'
self.op_type = "put_along_axis"
self.python_api = paddle.tensor.put_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
# numpy put_along_axis is an inplace opearion.
# numpy put_along_axis is an inplace operation.
self.xnp_result = copy.deepcopy(self.xnp)
np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
self.target = self.xnp_result
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.braodcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.braodcast_shape)
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
......@@ -56,6 +55,7 @@ class TestPutAlongAxisOp(OpTest):
self.check_grad(["Input", "Value"], "Result")
def init_data(self):
self.dtype = 'float64'
self.x_type = "float64"
self.x_shape = (10, 10, 10)
self.value_type = "float64"
......@@ -66,6 +66,71 @@ class TestPutAlongAxisOp(OpTest):
self.axis_type = "int64"
class TestPutAlongAxisFP16Op(TestPutAlongAxisOp):
def init_data(self):
self.dtype = np.float16
self.x_type = "float16"
self.x_shape = (10, 10, 10)
self.value_type = "float16"
self.value = np.array([99]).astype(self.value_type)
self.index_type = "int32"
self.index = np.array([[[0]]]).astype(self.index_type)
self.axis = 1
self.axis_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestPutAlongAxisBF16Op(OpTest):
def setUp(self):
self.init_data()
self.reduce_op = "assign"
self.op_type = "put_along_axis"
self.python_api = paddle.tensor.put_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
# numpy put_along_axis is an inplace operation.
self.xnp_result = copy.deepcopy(self.xnp)
np.put_along_axis(self.xnp_result, self.index, self.value, self.axis)
self.target = self.xnp_result
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.value_broadcast = np.broadcast_to(self.value, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
'Value': self.value_broadcast,
}
self.attrs = {'Axis': self.axis, 'Reduce': self.reduce_op}
self.outputs = {'Result': self.target}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.inputs['Value'] = convert_float_to_uint16(self.inputs['Value'])
self.outputs['Result'] = convert_float_to_uint16(self.outputs['Result'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ["Input", "Value"], "Result")
def init_data(self):
self.dtype = np.uint16
self.x_type = "float32"
self.x_shape = (10, 10, 10)
self.value_type = "float32"
self.value = np.array([99]).astype(self.value_type)
self.index_type = "int32"
self.index = np.array([[[0]]]).astype(self.index_type)
self.axis = 1
self.axis_type = "int64"
class TestPutAlongAxisAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
......
......@@ -15,7 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
)
import paddle
from paddle.fluid import core
......@@ -40,12 +44,21 @@ def error_msg(data_np):
def convert_dtype(dtype_str):
dtype_str_list = ["int32", "int64", "float32", "float64"]
dtype_str_list = [
"int32",
"int64",
"float16",
"float32",
"float64",
"uint16",
]
dtype_num_list = [
core.VarDesc.VarType.INT32,
core.VarDesc.VarType.INT64,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.BF16,
]
assert dtype_str in dtype_str_list, (
dtype_str + " should in " + str(dtype_str_list)
......@@ -62,9 +75,9 @@ class TestRandpermOp(OpTest):
self.n = 200
self.dtype = "int64"
self.init_attrs()
self.inputs = {}
self.outputs = {"Out": np.zeros(self.n).astype(self.dtype)}
self.init_attrs()
self.attrs = {
"n": self.n,
"dtype": convert_dtype(self.dtype),
......@@ -103,6 +116,47 @@ class TestRandpermOpFloat64(TestRandpermOp):
self.dtype = "float64"
class TestRandpermFP16Op(TestRandpermOp):
def init_attrs(self):
self.dtype = "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestRandpermBF16Op(OpTest):
def setUp(self):
self.op_type = "randperm"
self.python_api = paddle.randperm
self.n = 200
self.init_attrs()
self.inputs = {}
self.outputs = {"Out": np.zeros(self.n).astype(self.np_dtype)}
self.attrs = {
"n": self.n,
"dtype": convert_dtype(self.dtype),
}
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def init_attrs(self):
self.dtype = "uint16"
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place_customized(self.verify_output, self.place)
def verify_output(self, outs):
out_np = convert_uint16_to_float(np.array(outs[0]))
self.assertTrue(
check_randperm_out(self.n, out_np), msg=error_msg(out_np)
)
class TestRandpermOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
......
......@@ -418,6 +418,51 @@ class TestMin8DOp(OpTest):
self.check_output()
@skip_check_grad_ci(
reason="reduce_min is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework."
)
class TestMinFP16Op(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_min"
self.python_api = paddle.min
self.public_python_api = paddle.min
self.init_dtype()
if self.dtype == np.uint16:
x = np.random.random((5, 6, 10)).astype(np.float32)
self.inputs = {'X': convert_float_to_uint16(x)}
else:
x = np.random.random((5, 6, 10)).astype(self.dtype)
self.inputs = {'X': x}
self.attrs = {'dim': [2], 'keep_dim': True}
out = x.min(axis=tuple(self.attrs['dim']), keepdims=True)
if self.dtype == np.uint16:
self.outputs = {'Out': convert_float_to_uint16(out)}
else:
self.outputs = {'Out': out}
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output()
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestMinBF16Op(TestMinFP16Op):
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
def raw_reduce_prod(x, dim=[0], keep_dim=False):
return paddle.prod(x, dim, keep_dim)
......
......@@ -65,7 +65,7 @@ class TestSplitOp(OpTest):
# test with attr(num)
class TestSplitOp_2(OpTest):
class TestSplitWithNumOp(OpTest):
def setUp(self):
self.python_api = paddle.split
self.public_python_api = paddle.split
......@@ -74,17 +74,31 @@ class TestSplitOp_2(OpTest):
self.prim_op_type = "prim"
self.dtype = self.get_dtype()
self.init_data()
self.inputs = {'X': self.x}
self.attrs = {
'axis': self.axis,
'sections': self.sections,
'num': self.num,
}
if self.dtype == np.uint16:
self.inputs = {'X': convert_float_to_uint16(self.x)}
out = np.split(self.x, self.indices_or_sections, self.axis)
self.outputs = {'Out': [('out%d' % i, out[i]) for i in range(len(out))]}
self.outputs = {
'Out': [
('out%d' % i, convert_float_to_uint16(out[i]))
for i in range(len(out))
]
}
else:
self.inputs = {'X': self.x}
out = np.split(self.x, self.indices_or_sections, self.axis)
self.outputs = {
'Out': [('out%d' % i, out[i]) for i in range(len(out))]
}
def init_data(self):
if self.dtype == np.uint16:
self.x = np.random.random((4, 5, 6)).astype(np.float32)
else:
self.x = np.random.random((4, 5, 6)).astype(self.dtype)
self.axis = 2
self.sections = []
......@@ -240,28 +254,28 @@ def create_test_fp16(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestSplitFp16(parent):
class TestSplitFP16Op(parent):
def get_dtype(self):
return np.float16
def test_check_grad(self):
pass
cls_name = "{}_{}".format(parent.__name__, "Fp16")
TestSplitFp16.__name__ = cls_name
globals()[cls_name] = TestSplitFp16
cls_name = "{}_{}".format(parent.__name__, "FP16Op")
TestSplitFP16Op.__name__ = cls_name
globals()[cls_name] = TestSplitFP16Op
create_test_fp16(TestSplitOp)
create_test_fp16(TestSplitWithNumOp)
# ----------------Split Bf16----------------
def create_test_bf16(parent):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestSplitBf16(parent):
class TestSplitBF16Op(parent):
def get_dtype(self):
return np.uint16
......@@ -270,14 +284,16 @@ def create_test_bf16(parent):
self.check_output_with_place(place)
def test_check_grad(self):
pass
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'out2')
cls_name = "{}_{}".format(parent.__name__, "Bf16")
TestSplitBf16.__name__ = cls_name
globals()[cls_name] = TestSplitBf16
cls_name = "{}_{}".format(parent.__name__, "BF16Op")
TestSplitBF16Op.__name__ = cls_name
globals()[cls_name] = TestSplitBF16Op
create_test_bf16(TestSplitOp)
create_test_bf16(TestSplitWithNumOp)
class TestSplitAPI(unittest.TestCase):
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.framework import core
......@@ -32,8 +32,8 @@ class TestTakeAlongAxisOp(OpTest):
self.target = np.take_along_axis(self.xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.braodcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.braodcast_shape)
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
......@@ -58,6 +58,64 @@ class TestTakeAlongAxisOp(OpTest):
self.axis_type = "int64"
class TestTakeAlongAxisFP16Op(TestTakeAlongAxisOp):
def init_data(self):
self.dtype = np.float16
self.x_type = "float16"
self.x_shape = (5, 5, 5)
self.index_type = "int32"
self.index = np.array([[[1]], [[1]], [[2]], [[4]], [[3]]]).astype(
self.index_type
)
self.axis = 2
self.axis_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestTakeAlongAxisBF16Op(OpTest):
def setUp(self):
self.init_data()
self.op_type = "take_along_axis"
self.python_api = paddle.tensor.take_along_axis
self.xnp = np.random.random(self.x_shape).astype(self.x_type)
self.target = np.take_along_axis(self.xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = 1
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(self.index, self.broadcast_shape)
self.inputs = {
'Input': self.xnp,
'Index': self.index_broadcast,
}
self.attrs = {'Axis': self.axis}
self.outputs = {'Result': self.target}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.outputs['Result'] = convert_float_to_uint16(self.outputs['Result'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['Input'], 'Result')
def init_data(self):
self.dtype = np.uint16
self.x_type = "float32"
self.x_shape = (5, 5, 5)
self.index_type = "int32"
self.index = np.array([[[1]], [[1]], [[2]], [[4]], [[3]]]).astype(
self.index_type
)
self.axis = 2
self.axis_type = "int64"
class TestCase1(TestTakeAlongAxisOp):
def init_data(self):
self.x_type = "float64"
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid, tensor
......@@ -68,6 +68,82 @@ class TestTraceOpCase2(TestTraceOp):
)
class TestTraceFP16Op1(TestTraceOp):
def init_config(self):
self.dtype = np.float16
self.case = np.random.randn(20, 6).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.trace(self.inputs['Input'])
class TestTraceFP16Op2(TestTraceOp):
def init_config(self):
self.dtype = np.float16
self.case = np.random.randn(2, 20, 2, 3).astype(self.dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': -5, 'axis1': 1, 'axis2': -1}
self.target = np.trace(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestTraceBF16Op1(OpTest):
def setUp(self):
self.op_type = "trace"
self.python_api = paddle.trace
self.init_config()
self.outputs = {'Out': self.target}
self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['Input'], 'Out', numeric_grad_delta=0.02
)
def init_config(self):
self.dtype = np.uint16
self.np_dtype = np.float32
self.case = np.random.randn(20, 6).astype(self.np_dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.trace(self.inputs['Input'])
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestTraceBF16Op2(TestTraceBF16Op1):
def init_config(self):
self.dtype = np.uint16
self.np_dtype = np.float32
self.case = np.random.randn(2, 20, 2, 3).astype(self.np_dtype)
self.inputs = {'Input': self.case}
self.attrs = {'offset': -5, 'axis1': 1, 'axis2': -1}
self.target = np.trace(
self.inputs['Input'],
offset=self.attrs['offset'],
axis1=self.attrs['axis1'],
axis2=self.attrs['axis2'],
)
class TestTraceAPICase(unittest.TestCase):
def test_case1(self):
case = np.random.randn(2, 20, 2, 3).astype('float32')
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
......@@ -42,7 +42,11 @@ class TestUnfoldOp(OpTest):
self.input_height,
self.input_width,
]
self.x = np.random.rand(*input_shape).astype(np.float64)
if self.dtype == np.uint16:
as_type = self.np_dtype
else:
as_type = self.dtype
self.x = np.random.rand(*input_shape).astype(as_type)
def calc_unfold(self):
output_shape = [0] * 3
......@@ -77,7 +81,11 @@ class TestUnfoldOp(OpTest):
+ 1
)
output_shape[2] = out_height * out_width
output = np.zeros(output_shape).astype(np.float64)
if self.dtype == np.uint16:
as_type = self.np_dtype
else:
as_type = self.dtype
output = np.zeros(output_shape).astype(as_type)
# ------------ calculate output -------------- #
for i in range(output_shape[0]):
for j in range(output_shape[1]):
......@@ -123,9 +131,13 @@ class TestUnfoldOp(OpTest):
def setUp(self):
self.op_type = 'unfold'
self.init_dtype()
self.python_api = paddle.nn.functional.unfold
self.set_data()
def init_dtype(self):
self.dtype = np.float64
def test_check_output(self):
self.check_output()
......@@ -133,6 +145,55 @@ class TestUnfoldOp(OpTest):
self.check_grad(['X'], 'Y')
class TestUnfoldFP16Op(TestUnfoldOp):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestUnfoldBF16Op(TestUnfoldOp):
# Notice: The test is time consuming, may cause timeout, modify the parameters to reduce the time
def init_data(self):
self.batch_size = 3
self.input_channels = 3
self.input_height = 5
self.input_width = 5
self.kernel_sizes = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1, 1, 1]
self.dilations = [1, 1]
input_shape = [
self.batch_size,
self.input_channels,
self.input_height,
self.input_width,
]
self.x = np.random.rand(*input_shape).astype(self.np_dtype)
def init_dtype(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def setUp(self):
self.op_type = 'unfold'
self.init_dtype()
self.python_api = paddle.nn.functional.unfold
self.set_data()
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Y'] = convert_float_to_uint16(self.outputs['Y'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Y')
class TestUnfoldAPI(TestUnfoldOp):
"""
This is for test on paddle.nn.Unfold
......
......@@ -15,9 +15,19 @@
import unittest
import numpy as np
from eager_op_test import OpTest, convert_uint16_to_float
import paddle
from paddle import fluid
from paddle.fluid import core
def output_hist(out):
hist, _ = np.histogram(out, range=(-1, 1))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones(10)
return hist, prob
class TestUniformRandomInplaceOpDtype(unittest.TestCase):
......@@ -44,6 +54,72 @@ class TestUniformRandomInplaceOpDtype(unittest.TestCase):
test_fp64()
class TestUniformRandomInplaceFP16Op(OpTest):
def setUp(self):
self.op_type = "uniform_random_inplace"
self.dtype = np.float16
self.shape = (1000, 784)
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
self.inputs = {"X": x}
self.outputs = {"Out": y}
self.init_attrs()
def init_attrs(self):
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.001)
# TODO: Due to the lack of the self.python_api=paddle.uniform_random_inplace setting, the dynamic graph is temporarily turned off, set check_dygraph=False
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_dygraph=False)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestUniformRandomInplaceBF16Op(OpTest):
def setUp(self):
self.op_type = "uniform_random_inplace"
self.dtype = np.uint16
self.shape = (1000, 784)
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
self.inputs = {'X': x}
self.outputs = {'Out': y}
self.init_attrs()
self.place = core.CUDAPlace(0)
def init_attrs(self):
self.output_hist = output_hist
def test_check_output(self):
self.check_output_with_place_customized(self.verify_output, self.place)
def verify_output(self, outs):
result = convert_uint16_to_float(np.array(outs[0]))
hist, prob = self.output_hist(result)
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.002)
# TODO: Due to the lack of the self.python_api=paddle.uniform_random_inplace setting, the dynamic graph is temporarily turned off, set check_dygraph=False
def test_check_grad(self):
grads = [paddle.zeros(self.shape, dtype=self.dtype)]
self.check_grad_with_place(
self.place,
['X'],
'Out',
check_dygraph=False,
user_defined_grads=grads,
)
class TestUniformRandomInplaceOpIsInplace(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
......
......@@ -15,10 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest, paddle_static_guard
from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
import paddle
from paddle import fluid
from paddle.fluid import core
from paddle.static.amp import amp_nn
......@@ -81,8 +82,10 @@ class TestUpdateLossScalingOp(OpTest):
def init(self):
self.incr_ratio = 2.0
self.decr_ratio = 0.8
self.dtype = np.float32
self.prev_loss_scaling = np.array([2048]).astype(self.dtype)
self.init_dtype()
self.prev_loss_scaling = np.array([2048]).astype(
self.loss_scaling_dtype
)
self.num_good_steps = np.array([999], dtype=np.int32)
self.num_bad_steps = np.array([1], dtype=np.int32)
self.zero_steps = np.array([0], dtype=np.int32)
......@@ -94,10 +97,77 @@ class TestUpdateLossScalingOp(OpTest):
'decr_ratio': self.decr_ratio,
}
def init_dtype(self):
self.dtype = np.float32
self.loss_scaling_dtype = np.float32
def test_check_output(self):
self.check_output(no_check_set=['Out'])
class TestUpdateLossScalingFP16Op(TestUpdateLossScalingOp):
def init_dtype(self):
self.dtype = np.float16
self.loss_scaling_dtype = np.float32
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestUpdateLossScalingBF16Op(OpTest):
def init(self):
self.incr_ratio = 2.0
self.decr_ratio = 0.8
self.dtype = np.uint16
self.np_dtype = "float32"
self.prev_loss_scaling = np.array([2048]).astype(self.np_dtype)
self.num_good_steps = np.array([999], dtype=np.int32)
self.num_bad_steps = np.array([1], dtype=np.int32)
self.zero_steps = np.array([0], dtype=np.int32)
self.stop_update = np.array([False], dtype=np.bool_)
self.attrs = {
'incr_every_n_steps': 1000,
'decr_every_n_nan_or_inf': 2,
'incr_ratio': self.incr_ratio,
'decr_ratio': self.decr_ratio,
}
def setUp(self):
self.op_type = "update_loss_scaling"
self.init()
self.python_api = update_loss_scaling_wrapper
self.python_out_sig = [
"out0",
"LossScaling",
"OutGoodSteps",
"OutBadSteps",
]
found_inf = np.array([False], dtype=np.bool_)
x = np.random.random((1024, 1024)).astype(self.np_dtype)
self.inputs = {
'X': [('x0', convert_float_to_uint16(x))],
'FoundInfinite': found_inf,
# do not convert
'PrevLossScaling': self.prev_loss_scaling,
'InGoodSteps': self.num_good_steps,
'InBadSteps': self.num_bad_steps,
}
self.outputs = {
'Out': [('out0', convert_float_to_uint16(x))],
# do not convert
'LossScaling': self.prev_loss_scaling * self.incr_ratio,
'OutGoodSteps': self.zero_steps,
'OutBadSteps': self.zero_steps,
}
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0), no_check_set=['Out'])
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
def setUp(self):
self.op_type = "update_loss_scaling"
......
......@@ -63,7 +63,9 @@ def clip_by_norm(x, max_norm, name=None):
return _legacy_C_ops.clip_by_norm(x, 'max_norm', max_norm)
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
check_variable_and_dtype(
x, 'X', ['float16', 'float32', 'uint16'], 'clip_by_norm'
)
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
......
......@@ -538,10 +538,13 @@ def prelu(x, weight, data_format="NCHW", name=None):
return _C_ops.prelu(x, weight, data_format, mode)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'prelu'
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'prelu'
)
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'prelu'
weight,
'weight',
['float16', 'float32', 'float64', 'uint16'],
'prelu',
)
helper = LayerHelper('prelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
......@@ -2498,7 +2498,7 @@ def multi_dot(x, name=None):
check_variable_and_dtype(
item,
'x[' + str(id) + ']',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'multi_dot',
)
if item.dtype != x[0].dtype:
......
......@@ -1240,7 +1240,15 @@ def broadcast_tensors(input, name=None):
check_variable_and_dtype(
x,
'input[' + str(id) + ']',
['bool', 'float32', 'float64', 'int32', 'int64'],
[
'bool',
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'broadcast_tensors',
)
if x.dtype != input[0].dtype:
......@@ -1977,6 +1985,7 @@ def split(x, num_or_sections, axis=0, name=None):
'int32',
'int64',
'uint8',
'uint16',
'int8',
],
'split',
......@@ -4559,7 +4568,15 @@ def take_along_axis(arr, indices, axis):
check_variable_and_dtype(
arr,
'x',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint8',
'uint16',
],
'take_along_axis',
)
check_variable_and_dtype(
......@@ -4631,7 +4648,15 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'):
check_variable_and_dtype(
arr,
'x',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint8',
'uint16',
],
'put_along_axis',
)
check_variable_and_dtype(
......@@ -4713,7 +4738,7 @@ def index_add(x, index, axis, value, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'paddle.tensor.manipulation.index_add',
)
check_variable_and_dtype(
......@@ -4725,7 +4750,7 @@ def index_add(x, index, axis, value, name=None):
check_variable_and_dtype(
value,
'add_value',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'paddle.tensor.manipulation.index_add',
)
......
......@@ -1958,10 +1958,14 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
helper = LayerHelper("addmm", **locals())
check_variable_and_dtype(
input, 'Input', ['float32', 'float64'], 'addmm'
input, 'Input', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
)
check_variable_and_dtype(
x, 'X', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
)
check_variable_and_dtype(
y, 'Y', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
)
check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'addmm')
check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'addmm')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
......@@ -2456,7 +2460,10 @@ def min(x, axis=None, keepdim=False, name=None):
reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
helper = LayerHelper('min', **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'min'
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'min',
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册