未验证 提交 2ddd0473 编写于 作者: H Hui Zhang 提交者: GitHub

log/Log10/log2/log1p support int32/int64/float16/bfloat16 forward (#54089)

* fix for log xxx

* add int32/int64 for cpu/gpu; add float16/bfloat16 for cpu forward

* fix docstring

* fix bug

* fix bugs

* fix bugs

* fix bugs

* fix bugs

* fix bug

* using cast

* fix test

* fix api

* fix other bugs

* fix ci bug for not using dygraph guard

* add bfloat16 test

* fix ut

* bf16
上级 685c0a49
......@@ -26,7 +26,19 @@ namespace phi {
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
funcs::functor_class<T> functor; \
ActivationImpl<T, Context, funcs::functor_class<T>>( \
ActivationImpl<T, T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
#define DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(name, \
functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
funcs::functor_class<T> functor; \
using U = \
typename std::conditional_t<std::is_integral<T>::value, float, T>; \
ActivationImpl<T, U, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
......@@ -39,7 +51,7 @@ namespace phi {
funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr; \
ActivationImpl<T, Context, funcs::functor_class<T>>( \
ActivationImpl<T, T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
......@@ -55,7 +67,7 @@ namespace phi {
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \
ActivationImpl<T, Context, funcs::functor_class<T>>( \
ActivationImpl<T, T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
......@@ -83,15 +95,16 @@ DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, RsqrtFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Softsign, SoftsignFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Round, RoundFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, Log1pFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
......@@ -124,7 +137,7 @@ void HardSwishKernel(const Context& dev_ctx,
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationImpl<T, Context, funcs::HardSwishFunctor<T>>(
ActivationImpl<T, T, Context, funcs::HardSwishFunctor<T>>(
dev_ctx, x, out, functor);
}
......@@ -178,10 +191,48 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_KERNEL(log,
CPU,
ALL_LAYOUT,
phi::LogKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log2,
CPU,
ALL_LAYOUT,
phi::Log2Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log10,
CPU,
ALL_LAYOUT,
phi::Log10Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log1p,
CPU,
ALL_LAYOUT,
phi::Log1pKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
......
......@@ -1996,12 +1996,33 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct Log {
HOSTDEVICE T operator()(const T& val) const { return std::log(val); }
};
template <>
struct Log<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log(static_cast<float>(val)));
}
};
template <>
struct Log<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log(static_cast<float>(val)));
}
};
// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log();
out.device(d) = x.template cast<U>().unaryExpr(Log<U>());
}
};
......@@ -2019,12 +2040,33 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Log2 {
HOSTDEVICE T operator()(const T& val) const { return std::log2(val); }
};
template <>
struct Log2<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log2(static_cast<float>(val)));
}
};
template <>
struct Log2<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log2(static_cast<float>(val)));
}
};
// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(2));
out.device(d) = x.template cast<U>().unaryExpr(Log2<U>());
}
};
......@@ -2043,12 +2085,33 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Log10 {
HOSTDEVICE T operator()(const T& val) const { return std::log10(val); }
};
template <>
struct Log10<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log10(static_cast<float>(val)));
}
};
template <>
struct Log10<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log10(static_cast<float>(val)));
}
};
// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(10));
out.device(d) = x.template cast<U>().unaryExpr(Log10<U>());
}
};
......@@ -2067,12 +2130,33 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Log1p {
HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); }
};
template <>
struct Log1p<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log1p(static_cast<float>(val)));
}
};
template <>
struct Log1p<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log1p(static_cast<float>(val)));
}
};
// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = (static_cast<T>(1) + x).log();
out.device(d) = x.template cast<U>().unaryExpr(Log1p<U>());
}
};
......@@ -3665,14 +3749,35 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
log_local(T x) {
static_assert(!std::is_same<T, double>::value,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __logf fast approximation for peak bandwidth
return __logf(x);
#else
return ::log(x);
#endif
}
template <>
__device__ __forceinline__ double log_local<double>(double x) {
return ::log(x);
}
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log(x) = log(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(x));
return static_cast<U>(log_local(x));
}
};
......@@ -3690,11 +3795,12 @@ template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log1p(x) = log(1 + x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(one + x));
return static_cast<U>(log_local(one + x));
}
};
......@@ -3710,14 +3816,35 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
log2_local(T x) {
static_assert(!std::is_same<T, double>::value,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __logf fast approximation for peak bandwidth
return __log2f(x);
#else
return ::log2(x);
#endif
}
template <>
__device__ __forceinline__ double log2_local<double>(double x) {
return ::log2(x);
}
template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log2(x) = log2(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log2(x));
return static_cast<U>(log2_local(x));
}
};
......@@ -3734,14 +3861,35 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
log10_local(T x) {
static_assert(!std::is_same<T, double>::value,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __logf fast approximation for peak bandwidth
return __log10f(x);
#else
return ::log10(x);
#endif
}
template <>
__device__ __forceinline__ double log10_local(double x) {
return ::log10(x);
}
template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log10(x) = log10(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log10(x));
return static_cast<U>(log10_local(x));
}
};
......
......@@ -47,6 +47,18 @@ void ActivationGPUImpl(const Context& dev_ctx,
dev_ctx, x, out, functor); \
}
#define DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(name, \
functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
funcs::functor_class<T> functor; \
using U = \
typename std::conditional_t<std::is_integral<T>::value, float, T>; \
ActivationGPUImpl<U, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
#define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
......@@ -100,14 +112,15 @@ DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, CudaRsqrtFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Softsign, CudaSoftsignFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, CudaLog1pFunctor)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
......@@ -246,10 +259,6 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
......@@ -258,6 +267,46 @@ PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)
PD_REGISTER_KERNEL(log,
GPU,
ALL_LAYOUT,
phi::LogKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log2,
GPU,
ALL_LAYOUT,
phi::Log2Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log10,
GPU,
ALL_LAYOUT,
phi::Log10Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log1p,
GPU,
ALL_LAYOUT,
phi::Log1pKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(pow,
GPU,
ALL_LAYOUT,
......
......@@ -23,17 +23,17 @@ namespace phi {
#define ToString(x) #x
template <typename T, typename Context, typename Functor>
template <typename T, typename U, typename Context, typename Functor>
void ActivationImpl(const Context& dev_ctx,
const DenseTensor& X,
DenseTensor* Out,
const Functor& functor) {
PADDLE_ENFORCE_NOT_NULL(Out,
errors::NotFound("Output Out should not be nullptr"));
dev_ctx.template Alloc<T>(Out);
dev_ctx.template Alloc<U>(Out);
auto x = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(&X, "Input", "X", "Activation"));
auto out = phi::EigenVector<T>::Flatten(
auto out = phi::EigenVector<U>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
auto* place = dev_ctx.eigen_device();
// use 32bit index to speed up computation
......
......@@ -137,7 +137,7 @@ def log(x, name=None):
Out = \ln(x)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64.
x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
......@@ -159,7 +159,10 @@ def log(x, name=None):
return _C_ops.log(x)
else:
check_variable_and_dtype(
x, 'x', ['uint16', 'float16', 'float32', 'float64'], "log"
x,
'x',
['int32', 'int64', 'uint16', 'float16', 'float32', 'float64'],
"log",
)
inputs = {'X': [x]}
helper = LayerHelper('log', **locals())
......@@ -2763,7 +2766,7 @@ def log1p(x, name=None):
Out = \ln(x+1)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64.
x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -2783,7 +2786,10 @@ def log1p(x, name=None):
return _C_ops.log1p(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log1p"
x,
'x',
['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
"log1p",
)
inputs = {'X': [x]}
helper = LayerHelper('log1p', **locals())
......@@ -2802,7 +2808,7 @@ def log2(x, name=None):
Out = \log_2x
Args:
x (Tensor): Input tensor must be one of the following types: float32, float64.
x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -2835,7 +2841,10 @@ def log2(x, name=None):
return _C_ops.log2(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log2"
x,
'x',
['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
"log2",
)
inputs = {'X': [x]}
helper = LayerHelper('log2', **locals())
......@@ -2854,7 +2863,7 @@ def log10(x, name=None):
Out = \log_10_x
Args:
x (Tensor): Input tensor must be one of the following types: float32, float64.
x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -2887,7 +2896,10 @@ def log10(x, name=None):
return _C_ops.log10(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log10"
x,
'x',
['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
"log10",
)
inputs = {'X': [x]}
helper = LayerHelper('log10', **locals())
......
......@@ -15,6 +15,7 @@
import os
import unittest
import warnings
from contextlib import contextmanager
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
......@@ -27,6 +28,15 @@ from paddle.fluid import Program, core, program_guard
from paddle.fluid.layer_helper import LayerHelper
@contextmanager
def dynamic_guad():
paddle.disable_static()
try:
yield
finally:
paddle.enable_static()
class TestSqrtOpError(unittest.TestCase):
def test_errors(self):
with paddle.fluid.framework._static_guard():
......@@ -200,10 +210,14 @@ class TestExpm1API(unittest.TestCase):
run(place)
def test_dygraph_api(self):
with dynamic_guad():
def run(place):
X = paddle.to_tensor(self.x)
out = paddle.expm1(X)
np.testing.assert_allclose(self.out_ref, out.numpy(), rtol=1e-05)
np.testing.assert_allclose(
self.out_ref, out.numpy(), rtol=1e-05
)
for place in self.place:
run(place)
......@@ -381,6 +395,7 @@ class TestSiluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
paddle.disable_static()
x = paddle.to_tensor(self.x_np)
out1 = F.silu(x)
m = paddle.nn.Silu()
......@@ -388,6 +403,7 @@ class TestSiluAPI(unittest.TestCase):
out_ref = self.x_np / (1 + np.exp(-self.x_np))
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
paddle.enable_static()
def test_errors(self):
with paddle.fluid.framework._static_guard():
......@@ -457,6 +473,7 @@ class TestLogSigmoidAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
paddle.disable_static()
x = paddle.to_tensor(self.x_np)
out1 = F.log_sigmoid(x)
m = paddle.nn.LogSigmoid()
......@@ -464,6 +481,7 @@ class TestLogSigmoidAPI(unittest.TestCase):
out_ref = np.log(1 / (1 + np.exp(-self.x_np)))
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
paddle.enable_static()
def test_errors(self):
with paddle.fluid.framework._static_guard():
......@@ -549,6 +567,7 @@ class TestTanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.tanh(x)
out2 = paddle.tanh(x)
......@@ -869,6 +888,7 @@ class TestTanhshrinkAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.tanhshrink(x)
tanhshrink = paddle.nn.Tanhshrink()
......@@ -969,6 +989,7 @@ class TestHardShrinkAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.hardshrink(x)
hd = paddle.nn.Hardshrink()
......@@ -1034,6 +1055,7 @@ class TestHardtanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.hardtanh(x)
m = paddle.nn.Hardtanh()
......@@ -1129,6 +1151,7 @@ class TestSoftshrinkAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.softshrink(x, self.threshold)
softshrink = paddle.nn.Softshrink(self.threshold)
......@@ -1574,6 +1597,7 @@ class TestTanAPI(unittest.TestCase):
)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out_test = paddle.tan(x)
out_ref = np.tan(self.x_np)
......@@ -1875,6 +1899,7 @@ class TestReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
m = paddle.nn.ReLU()
out1 = m(x)
......@@ -1998,6 +2023,7 @@ class TestLeakyReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.leaky_relu(x)
m = paddle.nn.LeakyReLU()
......@@ -2153,6 +2179,7 @@ class TestGELUAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.gelu(x)
m = paddle.nn.GELU()
......@@ -2278,6 +2305,7 @@ class TestRelu6API(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.relu6(x)
relu6 = paddle.nn.ReLU6()
......@@ -2421,6 +2449,7 @@ class TestHardswishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor([11648.0, 11448.0])
out1 = F.hardswish(x)
m = paddle.nn.Hardswish()
......@@ -2439,6 +2468,7 @@ class TestHardswishAPI(unittest.TestCase):
out_ref = ref_hardswish(self.x_np)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out = paddle.nn.functional.hardswish(x)
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
......@@ -2567,6 +2597,7 @@ class TestELUAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = self.elu(x)
x = paddle.to_tensor(self.x_np)
......@@ -2607,6 +2638,7 @@ class TestELUInplaceAPI(TestELUAPI):
self.elu = F.elu_
def test_alpha_error(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
self.assertRaises(Exception, F.elu_, x, -0.2)
......@@ -2676,6 +2708,7 @@ class TestCELUAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = self.celu(x, 1.5)
x = paddle.to_tensor(self.x_np)
......@@ -2770,19 +2803,6 @@ class TestLog(TestActivation):
return
self.check_grad(['X'], 'Out', check_prim=True)
def test_error(self):
with paddle.fluid.framework._static_guard():
with paddle.fluid.framework._static_guard():
in1 = paddle.static.data(
name="in1", shape=[11, 17], dtype="int32"
)
in2 = paddle.static.data(
name="in2", shape=[11, 17], dtype="int64"
)
self.assertRaises(TypeError, paddle.log, in1)
self.assertRaises(TypeError, paddle.log, in2)
class Test_Log_Op_Fp16(unittest.TestCase):
def test_api_fp16(self):
......@@ -2798,6 +2818,31 @@ class Test_Log_Op_Fp16(unittest.TestCase):
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
def test_api_bf16(self):
with paddle.fluid.framework._static_guard():
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = [[2, 3, 4], [7, 8, 9]]
x = paddle.to_tensor(x, dtype='bfloat16')
out = paddle.log(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
class Test_Log_Op_Int(unittest.TestCase):
def test_api_int(self):
paddle.disable_static()
for dtype in ('int32', 'int64', 'float16'):
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.log(x)
x_expect = np.log(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
paddle.enable_static()
class TestLog_ZeroDim(TestLog):
def init_shape(self):
......@@ -2823,14 +2868,6 @@ class TestLog2(TestActivation):
return
self.check_grad(['X'], 'Out')
def test_error(self):
with paddle.fluid.framework._static_guard():
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64")
self.assertRaises(TypeError, paddle.log2, in1)
self.assertRaises(TypeError, paddle.log2, in2)
def test_api(self):
with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(
......@@ -2867,6 +2904,31 @@ class TestLog2_ZeroDim(TestLog2):
self.shape = []
class TestLog2_Op_Int(unittest.TestCase):
def test_api_int(self):
paddle.disable_static()
for dtype in ['int32', 'int64', 'float16']:
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.log2(x)
x_expect = np.log2(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
paddle.enable_static()
def test_api_bf16(self):
with paddle.fluid.framework._static_guard():
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = [[2, 3, 4], [7, 8, 9]]
x = paddle.to_tensor(x, dtype='bfloat16')
out = paddle.log2(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
class TestLog10(TestActivation):
def setUp(self):
self.op_type = "log10"
......@@ -2892,15 +2954,32 @@ class TestLog10_ZeroDim(TestLog10):
self.shape = []
class TestLog10API(unittest.TestCase):
def test_error(self):
class TestLog10_Op_Int(unittest.TestCase):
def test_api_int(self):
paddle.disable_static()
for dtype in ['int32', 'int64', 'float16']:
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.log10(x)
x_expect = np.log10(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
paddle.enable_static()
def test_api_bf16(self):
with paddle.fluid.framework._static_guard():
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64")
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = [[2, 3, 4], [7, 8, 9]]
x = paddle.to_tensor(x, dtype='bfloat16')
out = paddle.log10(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
self.assertRaises(TypeError, paddle.log10, in1)
self.assertRaises(TypeError, paddle.log10, in2)
class TestLog10API(unittest.TestCase):
def test_api(self):
with paddle.fluid.framework._static_guard():
with paddle.static.program_guard(
......@@ -2968,6 +3047,31 @@ class Test_Log1p_Op_Fp16(unittest.TestCase):
(res,) = exe.run(fetch_list=[out])
class TestLog1p_Op_Int(unittest.TestCase):
def test_api_int(self):
paddle.disable_static()
for dtype in ['int32', 'int64', 'float16']:
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.log1p(x)
x_expect = np.log1p(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
paddle.enable_static()
def test_api_bf16(self):
with paddle.fluid.framework._static_guard():
with static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = [[2, 3, 4], [7, 8, 9]]
x = paddle.to_tensor(x, dtype='bfloat16')
out = paddle.log1p(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
(res,) = exe.run(fetch_list=[out])
class TestLog1p_ZeroDim(TestLog1p):
def init_shape(self):
self.shape = []
......@@ -3239,6 +3343,7 @@ class TestSTanhAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out = paddle.stanh(x, self.scale_a, self.scale_b)
out_ref = ref_stanh(self.x_np, self.scale_a, self.scale_b)
......@@ -3381,6 +3486,7 @@ class TestSoftplusAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.softplus(x, self.beta, self.threshold)
softplus = paddle.nn.Softplus(self.beta, self.threshold)
......@@ -3466,6 +3572,7 @@ class TestSoftsignAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.softsign(x)
softsign = paddle.nn.Softsign()
......@@ -3556,7 +3663,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
paddle.disable_static()
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.thresholded_relu(x, self.threshold)
thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
......@@ -3660,6 +3767,7 @@ class TestHardsigmoidAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.hardsigmoid(x)
m = paddle.nn.Hardsigmoid()
......@@ -3763,6 +3871,7 @@ class TestSwishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.swish(x)
swish = paddle.nn.Swish()
......@@ -3862,6 +3971,7 @@ class TestMishAPI(unittest.TestCase):
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
def test_dygraph_api(self):
with dynamic_guad():
x = paddle.to_tensor(self.x_np)
out1 = F.mish(x)
mish = paddle.nn.Mish()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册