From 2ddd047346cf6fb99a13f2cdd218a0b8764df646 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 12 Jun 2023 17:48:36 +0800 Subject: [PATCH] log/Log10/log2/log1p support int32/int64/float16/bfloat16 forward (#54089) * fix for log xxx * add int32/int64 for cpu/gpu; add float16/bfloat16 for cpu forward * fix docstring * fix bug * fix bugs * fix bugs * fix bugs * fix bugs * fix bug * using cast * fix test * fix api * fix other bugs * fix ci bug for not using dygraph guard * add bfloat16 test * fix ut * bf16 --- paddle/phi/kernels/cpu/activation_kernel.cc | 101 +++- paddle/phi/kernels/funcs/activation_functor.h | 172 +++++- paddle/phi/kernels/gpu/activation_kernel.cu | 65 ++- paddle/phi/kernels/impl/activation_impl.h | 6 +- python/paddle/tensor/math.py | 28 +- test/legacy_test/test_activation_op.py | 540 +++++++++++------- 6 files changed, 641 insertions(+), 271 deletions(-) diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 355dc3547f8..6d5e08cae1d 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -26,10 +26,22 @@ namespace phi { void name##Kernel( \ const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ funcs::functor_class functor; \ - ActivationImpl>( \ + ActivationImpl>( \ dev_ctx, x, out, functor); \ } +#define DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(name, \ + functor_class) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ + funcs::functor_class functor; \ + using U = \ + typename std::conditional_t::value, float, T>; \ + ActivationImpl>( \ + dev_ctx, x, out, functor); \ + } + #define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ template \ void name##Kernel(const Context& dev_ctx, \ @@ -39,24 +51,24 @@ namespace phi { funcs::functor_class functor; \ auto attrs = functor.GetAttrs(); \ *(attrs[0].second) = attr; \ - ActivationImpl>( \ + ActivationImpl>( \ dev_ctx, x, out, functor); \ } -#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \ - name, functor_class, attr1, attr2) \ - template \ - void name##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - float attr1, \ - float attr2, \ - DenseTensor* out) { \ - funcs::functor_class functor; \ - auto attrs = functor.GetAttrs(); \ - *(attrs[0].second) = attr1; \ - *(attrs[1].second) = attr2; \ - ActivationImpl>( \ - dev_ctx, x, out, functor); \ +#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \ + name, functor_class, attr1, attr2) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr1, \ + float attr2, \ + DenseTensor* out) { \ + funcs::functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr1; \ + *(attrs[1].second) = attr2; \ + ActivationImpl>( \ + dev_ctx, x, out, functor); \ } DEFINE_CPU_ACTIVATION_KERNEL(Sin, SinFunctor) @@ -83,15 +95,16 @@ DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, RsqrtFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Softsign, SoftsignFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor) DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor) -DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor) -DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor) -DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Round, RoundFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor) DEFINE_CPU_ACTIVATION_KERNEL(Negative, 
NegativeFunctor) +DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, LogFunctor) +DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, Log2Functor) +DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, Log10Functor) +DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, Log1pFunctor) + DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, ThresholdedReluFunctor, @@ -124,7 +137,7 @@ void HardSwishKernel(const Context& dev_ctx, *(attrs[0].second) = threshold; *(attrs[1].second) = scale; *(attrs[2].second) = offset; - ActivationImpl>( + ActivationImpl>( dev_ctx, x, out, functor); } @@ -178,10 +191,48 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) -PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) -PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) -PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) -PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) + +PD_REGISTER_KERNEL(log, + CPU, + ALL_LAYOUT, + phi::LogKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(log2, + CPU, + ALL_LAYOUT, + phi::Log2Kernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(log10, + CPU, + ALL_LAYOUT, + phi::Log10Kernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(log1p, + CPU, + ALL_LAYOUT, + phi::Log1pKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} + PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 8cdf2dc1281..d18d8dd302e 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -1996,12 +1996,33 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { } }; +template +struct Log { + HOSTDEVICE T operator()(const T& val) const { return std::log(val); } +}; + +template <> +struct Log { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(std::log(static_cast(val))); + } +}; + +template <> +struct Log { + HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const { + return dtype::bfloat16(std::log(static_cast(val))); + } +}; + // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { + using U = typename std::conditional_t::value, float, T>; + template void operator()(Device d, X x, Out out) const { - out.device(d) = x.log(); + out.device(d) = x.template cast().unaryExpr(Log()); } }; @@ -2019,12 +2040,33 @@ struct LogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct Log2 { + HOSTDEVICE T operator()(const T& val) const { return std::log2(val); } +}; + +template <> +struct Log2 { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(std::log2(static_cast(val))); + } +}; + +template <> +struct Log2 { + HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const { + return 
dtype::bfloat16(std::log2(static_cast(val))); + } +}; + // log2(x) = logarithm to the base 2 of the elements of x template struct Log2Functor : public BaseActivationFunctor { + using U = typename std::conditional_t::value, float, T>; + template void operator()(Device d, X x, Out out) const { - out.device(d) = x.log() / static_cast(log(2)); + out.device(d) = x.template cast().unaryExpr(Log2()); } }; @@ -2043,12 +2085,33 @@ struct Log2GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct Log10 { + HOSTDEVICE T operator()(const T& val) const { return std::log10(val); } +}; + +template <> +struct Log10 { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(std::log10(static_cast(val))); + } +}; + +template <> +struct Log10 { + HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const { + return dtype::bfloat16(std::log10(static_cast(val))); + } +}; + // log10(x) = logarithm to the base 10 of the elements of x template struct Log10Functor : public BaseActivationFunctor { + using U = typename std::conditional_t::value, float, T>; + template void operator()(Device d, X x, Out out) const { - out.device(d) = x.log() / static_cast(log(10)); + out.device(d) = x.template cast().unaryExpr(Log10()); } }; @@ -2067,12 +2130,33 @@ struct Log10GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct Log1p { + HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); } +}; + +template <> +struct Log1p { + HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { + return dtype::float16(std::log1p(static_cast(val))); + } +}; + +template <> +struct Log1p { + HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const { + return dtype::bfloat16(std::log1p(static_cast(val))); + } +}; + // log1p(x) = natural logarithm of x+1 template struct Log1pFunctor : public BaseActivationFunctor { + using U = typename std::conditional_t::value, float, T>; + template void operator()(Device d, X x, Out out) const { - out.device(d) = (static_cast(1) + x).log(); + out.device(d) = x.template cast().unaryExpr(Log1p()); } }; @@ -3665,14 +3749,35 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor { } }; +template +__device__ __forceinline__ + std::conditional_t::value, float, T> + log_local(T x) { + static_assert(!std::is_same::value, + "this template must be used with float or less precise type"); + +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + return __logf(x); +#else + return ::log(x); +#endif +} + +template <> +__device__ __forceinline__ double log_local(double x) { + return ::log(x); +} + template struct CudaLogFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; + using U = typename std::conditional_t::value, float, T>; // log(x) = log(x) - __device__ __forceinline__ T operator()(const T arg_x) const { + __device__ __forceinline__ U operator()(const T arg_x) const { MPType x = static_cast(arg_x); - return static_cast(log(x)); + return static_cast(log_local(x)); } }; @@ -3690,11 +3795,12 @@ template struct CudaLog1pFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; MPType one = static_cast(1.0f); + using U = typename std::conditional_t::value, float, T>; // log1p(x) = log(1 + x) - __device__ 
__forceinline__ T operator()(const T arg_x) const { + __device__ __forceinline__ U operator()(const T arg_x) const { MPType x = static_cast(arg_x); - return static_cast(log(one + x)); + return static_cast(log_local(one + x)); } }; @@ -3710,14 +3816,35 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +__device__ __forceinline__ + std::conditional_t::value, float, T> + log2_local(T x) { + static_assert(!std::is_same::value, + "this template must be used with float or less precise type"); + +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + return __log2f(x); +#else + return ::log2(x); +#endif +} + +template <> +__device__ __forceinline__ double log2_local(double x) { + return ::log2(x); +} + template struct CudaLog2Functor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; + using U = typename std::conditional_t::value, float, T>; // log2(x) = log2(x) - __device__ __forceinline__ T operator()(const T arg_x) const { + __device__ __forceinline__ U operator()(const T arg_x) const { MPType x = static_cast(arg_x); - return static_cast(log2(x)); + return static_cast(log2_local(x)); } }; @@ -3734,14 +3861,35 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +__device__ __forceinline__ + std::conditional_t::value, float, T> + log10_local(T x) { + static_assert(!std::is_same::value, + "this template must be used with float or less precise type"); + +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + return __log10f(x); +#else + return ::log10(x); +#endif +} + +template <> +__device__ __forceinline__ double log10_local(double x) { + return ::log10(x); +} + template struct CudaLog10Functor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; + using U = typename std::conditional_t::value, float, T>; // log10(x) = log10(x) - __device__ __forceinline__ T operator()(const T arg_x) const { + __device__ __forceinline__ U operator()(const T arg_x) const { MPType x = static_cast(arg_x); - return static_cast(log10(x)); + return static_cast(log10_local(x)); } }; diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 794d442ce2a..adf917bb878 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -47,6 +47,18 @@ void ActivationGPUImpl(const Context& dev_ctx, dev_ctx, x, out, functor); \ } +#define DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(name, \ + functor_class) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ + funcs::functor_class functor; \ + using U = \ + typename std::conditional_t::value, float, T>; \ + ActivationGPUImpl>( \ + dev_ctx, x, out, functor); \ + } + #define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ template \ void name##Kernel(const Context& dev_ctx, \ @@ -100,14 +112,15 @@ DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, CudaRsqrtFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Softsign, CudaSoftsignFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor) DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor) -DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor) -DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor) 
-DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor) -DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor) DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor) +DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, CudaLogFunctor) +DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, CudaLog2Functor) +DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, CudaLog10Functor) +DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, CudaLog1pFunctor) + DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, @@ -246,10 +259,6 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) -PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) -PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) -PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) -PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) @@ -258,6 +267,46 @@ PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel) PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel) +PD_REGISTER_KERNEL(log, + GPU, + ALL_LAYOUT, + phi::LogKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(log2, + GPU, + ALL_LAYOUT, + phi::Log2Kernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(log10, + GPU, + ALL_LAYOUT, + phi::Log10Kernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(log1p, + GPU, + ALL_LAYOUT, + phi::Log1pKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(pow, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/impl/activation_impl.h b/paddle/phi/kernels/impl/activation_impl.h index 92f3a9e4328..e539d6d5d56 100644 --- a/paddle/phi/kernels/impl/activation_impl.h +++ b/paddle/phi/kernels/impl/activation_impl.h @@ -23,17 +23,17 @@ namespace phi { #define ToString(x) #x -template +template void ActivationImpl(const Context& dev_ctx, const DenseTensor& X, DenseTensor* Out, const Functor& functor) { PADDLE_ENFORCE_NOT_NULL(Out, errors::NotFound("Output Out should not be nullptr")); - dev_ctx.template Alloc(Out); + dev_ctx.template Alloc(Out); auto x = phi::EigenVector::Flatten( GET_DATA_SAFELY(&X, "Input", "X", "Activation")); - auto out = phi::EigenVector::Flatten( + auto out = phi::EigenVector::Flatten( GET_DATA_SAFELY(Out, "Output", "Out", "Activation")); auto* place = dev_ctx.eigen_device(); // use 32bit index to speed up computation diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 91d6aae9dac..cd548be5a51 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -137,7 +137,7 @@ def log(x, name=None): Out = \ln(x) Args: - x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64. + x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64. 
name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` @@ -159,7 +159,10 @@ def log(x, name=None): return _C_ops.log(x) else: check_variable_and_dtype( - x, 'x', ['uint16', 'float16', 'float32', 'float64'], "log" + x, + 'x', + ['int32', 'int64', 'uint16', 'float16', 'float32', 'float64'], + "log", ) inputs = {'X': [x]} helper = LayerHelper('log', **locals()) @@ -2763,7 +2766,7 @@ def log1p(x, name=None): Out = \ln(x+1) Args: - x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64. + x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -2783,7 +2786,10 @@ def log1p(x, name=None): return _C_ops.log1p(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log1p" + x, + 'x', + ['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'], + "log1p", ) inputs = {'X': [x]} helper = LayerHelper('log1p', **locals()) @@ -2802,7 +2808,7 @@ def log2(x, name=None): Out = \log_2x Args: - x (Tensor): Input tensor must be one of the following types: float32, float64. + x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -2835,7 +2841,10 @@ def log2(x, name=None): return _C_ops.log2(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log2" + x, + 'x', + ['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'], + "log2", ) inputs = {'X': [x]} helper = LayerHelper('log2', **locals()) @@ -2854,7 +2863,7 @@ def log10(x, name=None): Out = \log_10_x Args: - x (Tensor): Input tensor must be one of the following types: float32, float64. + x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
@@ -2887,7 +2896,10 @@ def log10(x, name=None): return _C_ops.log10(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log10" + x, + 'x', + ['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'], + "log10", ) inputs = {'X': [x]} helper = LayerHelper('log10', **locals()) diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 0ecb3d1c1a0..2843a8db138 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -15,6 +15,7 @@ import os import unittest import warnings +from contextlib import contextmanager import numpy as np from eager_op_test import OpTest, convert_float_to_uint16 @@ -27,6 +28,15 @@ from paddle.fluid import Program, core, program_guard from paddle.fluid.layer_helper import LayerHelper +@contextmanager +def dynamic_guad(): + paddle.disable_static() + try: + yield + finally: + paddle.enable_static() + + class TestSqrtOpError(unittest.TestCase): def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -200,13 +210,17 @@ class TestExpm1API(unittest.TestCase): run(place) def test_dygraph_api(self): - def run(place): - X = paddle.to_tensor(self.x) - out = paddle.expm1(X) - np.testing.assert_allclose(self.out_ref, out.numpy(), rtol=1e-05) + with dynamic_guad(): - for place in self.place: - run(place) + def run(place): + X = paddle.to_tensor(self.x) + out = paddle.expm1(X) + np.testing.assert_allclose( + self.out_ref, out.numpy(), rtol=1e-05 + ) + + for place in self.place: + run(place) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -381,6 +395,7 @@ class TestSiluAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): + paddle.disable_static() x = paddle.to_tensor(self.x_np) out1 = F.silu(x) m = paddle.nn.Silu() @@ -388,6 +403,7 @@ class TestSiluAPI(unittest.TestCase): out_ref = self.x_np / (1 + np.exp(-self.x_np)) for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + paddle.enable_static() def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -457,6 +473,7 @@ class TestLogSigmoidAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): + paddle.disable_static() x = paddle.to_tensor(self.x_np) out1 = F.log_sigmoid(x) m = paddle.nn.LogSigmoid() @@ -464,6 +481,7 @@ class TestLogSigmoidAPI(unittest.TestCase): out_ref = np.log(1 / (1 + np.exp(-self.x_np))) for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + paddle.enable_static() def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -549,14 +567,15 @@ class TestTanhAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.tanh(x) - out2 = paddle.tanh(x) - th = paddle.nn.Tanh() - out3 = th(x) - out_ref = np.tanh(self.x_np) - for r in [out1, out2, out3]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.tanh(x) + out2 = paddle.tanh(x) + th = paddle.nn.Tanh() + out3 = th(x) + out_ref = np.tanh(self.x_np) + for r in [out1, out2, out3]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -869,13 +888,14 @@ class TestTanhshrinkAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = 
paddle.to_tensor(self.x_np) - out1 = F.tanhshrink(x) - tanhshrink = paddle.nn.Tanhshrink() - out2 = tanhshrink(x) - out_ref = ref_tanhshrink(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.tanhshrink(x) + tanhshrink = paddle.nn.Tanhshrink() + out2 = tanhshrink(x) + out_ref = ref_tanhshrink(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -969,20 +989,21 @@ class TestHardShrinkAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.hardshrink(x) - hd = paddle.nn.Hardshrink() - out2 = hd(x) - out_ref = ref_hardshrink(self.x_np, 0.5) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.hardshrink(x) + hd = paddle.nn.Hardshrink() + out2 = hd(x) + out_ref = ref_hardshrink(self.x_np, 0.5) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - out1 = F.hardshrink(x, 0.6) - hd = paddle.nn.Hardshrink(0.6) - out2 = hd(x) - out_ref = ref_hardshrink(self.x_np, 0.6) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + out1 = F.hardshrink(x, 0.6) + hd = paddle.nn.Hardshrink(0.6) + out2 = hd(x) + out_ref = ref_hardshrink(self.x_np, 0.6) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -1034,20 +1055,21 @@ class TestHardtanhAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.hardtanh(x) - m = paddle.nn.Hardtanh() - out2 = m(x) - out_ref = ref_hardtanh(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.hardtanh(x) + m = paddle.nn.Hardtanh() + out2 = m(x) + out_ref = ref_hardtanh(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - out1 = F.hardtanh(x, -2.0, 2.0) - m = paddle.nn.Hardtanh(-2.0, 2.0) - out2 = m(x) - out_ref = ref_hardtanh(self.x_np, -2.0, 2.0) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + out1 = F.hardtanh(x, -2.0, 2.0) + m = paddle.nn.Hardtanh(-2.0, 2.0) + out2 = m(x) + out_ref = ref_hardtanh(self.x_np, -2.0, 2.0) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -1129,13 +1151,14 @@ class TestSoftshrinkAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.softshrink(x, self.threshold) - softshrink = paddle.nn.Softshrink(self.threshold) - out2 = softshrink(x) - out_ref = ref_softshrink(self.x_np, self.threshold) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.softshrink(x, self.threshold) + softshrink = paddle.nn.Softshrink(self.threshold) + out2 = softshrink(x) + out_ref = ref_softshrink(self.x_np, self.threshold) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def 
test_errors(self): with paddle.fluid.framework._static_guard(): @@ -1574,10 +1597,11 @@ class TestTanAPI(unittest.TestCase): ) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out_test = paddle.tan(x) - out_ref = np.tan(self.x_np) - np.testing.assert_allclose(out_ref, out_test.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out_test = paddle.tan(x) + out_ref = np.tan(self.x_np) + np.testing.assert_allclose(out_ref, out_test.numpy(), rtol=1e-05) def test_static_api(self): with paddle.fluid.framework._static_guard(): @@ -1875,13 +1899,14 @@ class TestReluAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - m = paddle.nn.ReLU() - out1 = m(x) - out2 = self.relu(x) - out_ref = np.maximum(self.x_np, 0) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + m = paddle.nn.ReLU() + out1 = m(x) + out2 = self.relu(x) + out_ref = np.maximum(self.x_np, 0) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -1998,20 +2023,21 @@ class TestLeakyReluAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.leaky_relu(x) - m = paddle.nn.LeakyReLU() - out2 = m(x) - out_ref = ref_leaky_relu(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.leaky_relu(x) + m = paddle.nn.LeakyReLU() + out2 = m(x) + out_ref = ref_leaky_relu(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - out1 = F.leaky_relu(x, 0.6) - m = paddle.nn.LeakyReLU(0.6) - out2 = m(x) - out_ref = ref_leaky_relu(self.x_np, 0.6) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + out1 = F.leaky_relu(x, 0.6) + m = paddle.nn.LeakyReLU(0.6) + out2 = m(x) + out_ref = ref_leaky_relu(self.x_np, 0.6) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -2153,20 +2179,21 @@ class TestGELUAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.gelu(x) - m = paddle.nn.GELU() - out2 = m(x) - out_ref = gelu(self.x_np, False) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.gelu(x) + m = paddle.nn.GELU() + out2 = m(x) + out_ref = gelu(self.x_np, False) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - out1 = F.gelu(x, True) - m = paddle.nn.GELU(True) - out2 = m(x) - out_ref = gelu(self.x_np, True) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + out1 = F.gelu(x, True) + m = paddle.nn.GELU(True) + out2 = m(x) + out_ref = gelu(self.x_np, True) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -2278,13 +2305,14 @@ class TestRelu6API(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.relu6(x) - 
relu6 = paddle.nn.ReLU6() - out2 = relu6(x) - out_ref = ref_relu6(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.relu6(x) + relu6 = paddle.nn.ReLU6() + out2 = relu6(x) + out_ref = ref_relu6(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_fluid_api(self): with paddle.fluid.framework._static_guard(): @@ -2421,13 +2449,14 @@ class TestHardswishAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor([11648.0, 11448.0]) - out1 = F.hardswish(x) - m = paddle.nn.Hardswish() - out2 = m(x) - out_ref = [11648.0, 11448.0] - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor([11648.0, 11448.0]) + out1 = F.hardswish(x) + m = paddle.nn.Hardswish() + out2 = m(x) + out_ref = [11648.0, 11448.0] + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_fluid_api(self): with paddle.fluid.framework._static_guard(): @@ -2439,9 +2468,10 @@ class TestHardswishAPI(unittest.TestCase): out_ref = ref_hardswish(self.x_np) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) - x = paddle.to_tensor(self.x_np) - out = paddle.nn.functional.hardswish(x) - np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out = paddle.nn.functional.hardswish(x) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -2567,22 +2597,23 @@ class TestELUAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = self.elu(x) - x = paddle.to_tensor(self.x_np) - m = paddle.nn.ELU() - out2 = m(x) - out_ref = elu(self.x_np, 1.0) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = self.elu(x) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.ELU() + out2 = m(x) + out_ref = elu(self.x_np, 1.0) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - out1 = self.elu(x, 0.2) - x = paddle.to_tensor(self.x_np) - m = paddle.nn.ELU(0.2) - out2 = m(x) - out_ref = elu(self.x_np, 0.2) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + out1 = self.elu(x, 0.2) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.ELU(0.2) + out2 = m(x) + out_ref = elu(self.x_np, 0.2) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -2607,8 +2638,9 @@ class TestELUInplaceAPI(TestELUAPI): self.elu = F.elu_ def test_alpha_error(self): - x = paddle.to_tensor(self.x_np) - self.assertRaises(Exception, F.elu_, x, -0.2) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + self.assertRaises(Exception, F.elu_, x, -0.2) def celu(x, alpha): @@ -2676,22 +2708,23 @@ class TestCELUAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = self.celu(x, 1.5) - x = paddle.to_tensor(self.x_np) - m = paddle.nn.CELU(1.5) - out2 = m(x) - out_ref = celu(self.x_np, 1.5) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + 
with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = self.celu(x, 1.5) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.CELU(1.5) + out2 = m(x) + out_ref = celu(self.x_np, 1.5) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - out1 = self.celu(x, 0.2) - x = paddle.to_tensor(self.x_np) - m = paddle.nn.CELU(0.2) - out2 = m(x) - out_ref = celu(self.x_np, 0.2) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + out1 = self.celu(x, 0.2) + x = paddle.to_tensor(self.x_np) + m = paddle.nn.CELU(0.2) + out2 = m(x) + out_ref = celu(self.x_np, 0.2) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -2770,19 +2803,6 @@ class TestLog(TestActivation): return self.check_grad(['X'], 'Out', check_prim=True) - def test_error(self): - with paddle.fluid.framework._static_guard(): - with paddle.fluid.framework._static_guard(): - in1 = paddle.static.data( - name="in1", shape=[11, 17], dtype="int32" - ) - in2 = paddle.static.data( - name="in2", shape=[11, 17], dtype="int64" - ) - - self.assertRaises(TypeError, paddle.log, in1) - self.assertRaises(TypeError, paddle.log, in2) - class Test_Log_Op_Fp16(unittest.TestCase): def test_api_fp16(self): @@ -2798,6 +2818,31 @@ class Test_Log_Op_Fp16(unittest.TestCase): exe = paddle.static.Executor(place) (res,) = exe.run(fetch_list=[out]) + def test_api_bf16(self): + with paddle.fluid.framework._static_guard(): + with static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = [[2, 3, 4], [7, 8, 9]] + x = paddle.to_tensor(x, dtype='bfloat16') + out = paddle.log(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) + + +class Test_Log_Op_Int(unittest.TestCase): + def test_api_int(self): + paddle.disable_static() + for dtype in ('int32', 'int64', 'float16'): + np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype) + x = paddle.to_tensor(np_x, dtype=dtype) + y = paddle.log(x) + x_expect = np.log(np_x) + np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) + paddle.enable_static() + class TestLog_ZeroDim(TestLog): def init_shape(self): @@ -2823,14 +2868,6 @@ class TestLog2(TestActivation): return self.check_grad(['X'], 'Out') - def test_error(self): - with paddle.fluid.framework._static_guard(): - in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") - in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64") - - self.assertRaises(TypeError, paddle.log2, in1) - self.assertRaises(TypeError, paddle.log2, in2) - def test_api(self): with paddle.fluid.framework._static_guard(): with paddle.static.program_guard( @@ -2867,6 +2904,31 @@ class TestLog2_ZeroDim(TestLog2): self.shape = [] +class TestLog2_Op_Int(unittest.TestCase): + def test_api_int(self): + paddle.disable_static() + for dtype in ['int32', 'int64', 'float16']: + np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype) + x = paddle.to_tensor(np_x, dtype=dtype) + y = paddle.log2(x) + x_expect = np.log2(np_x) + np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) + paddle.enable_static() + + def test_api_bf16(self): + with paddle.fluid.framework._static_guard(): + with static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = [[2, 3, 4], [7, 8, 9]] + x = paddle.to_tensor(x, dtype='bfloat16') + out = paddle.log2(x) + if core.is_compiled_with_cuda(): + place = 
paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) + + class TestLog10(TestActivation): def setUp(self): self.op_type = "log10" @@ -2892,15 +2954,32 @@ class TestLog10_ZeroDim(TestLog10): self.shape = [] -class TestLog10API(unittest.TestCase): - def test_error(self): +class TestLog10_Op_Int(unittest.TestCase): + def test_api_int(self): + paddle.disable_static() + for dtype in ['int32', 'int64', 'float16']: + np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype) + x = paddle.to_tensor(np_x, dtype=dtype) + y = paddle.log10(x) + x_expect = np.log10(np_x) + np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) + paddle.enable_static() + + def test_api_bf16(self): with paddle.fluid.framework._static_guard(): - in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32") - in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64") + with static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = [[2, 3, 4], [7, 8, 9]] + x = paddle.to_tensor(x, dtype='bfloat16') + out = paddle.log10(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) - self.assertRaises(TypeError, paddle.log10, in1) - self.assertRaises(TypeError, paddle.log10, in2) +class TestLog10API(unittest.TestCase): def test_api(self): with paddle.fluid.framework._static_guard(): with paddle.static.program_guard( @@ -2968,6 +3047,31 @@ class Test_Log1p_Op_Fp16(unittest.TestCase): (res,) = exe.run(fetch_list=[out]) +class TestLog1p_Op_Int(unittest.TestCase): + def test_api_int(self): + paddle.disable_static() + for dtype in ['int32', 'int64', 'float16']: + np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype) + x = paddle.to_tensor(np_x, dtype=dtype) + y = paddle.log1p(x) + x_expect = np.log1p(np_x) + np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3) + paddle.enable_static() + + def test_api_bf16(self): + with paddle.fluid.framework._static_guard(): + with static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = [[2, 3, 4], [7, 8, 9]] + x = paddle.to_tensor(x, dtype='bfloat16') + out = paddle.log1p(x) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) + + class TestLog1p_ZeroDim(TestLog1p): def init_shape(self): self.shape = [] @@ -3239,11 +3343,12 @@ class TestSTanhAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out = paddle.stanh(x, self.scale_a, self.scale_b) - out_ref = ref_stanh(self.x_np, self.scale_a, self.scale_b) - for r in [out]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out = paddle.stanh(x, self.scale_a, self.scale_b) + out_ref = ref_stanh(self.x_np, self.scale_a, self.scale_b) + for r in [out]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_fluid_api(self): with paddle.fluid.framework._static_guard(): @@ -3381,13 +3486,14 @@ class TestSoftplusAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.softplus(x, self.beta, self.threshold) - softplus = paddle.nn.Softplus(self.beta, self.threshold) - out2 = softplus(x) - out_ref = ref_softplus(self.x_np, self.beta, self.threshold) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), 
rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.softplus(x, self.beta, self.threshold) + softplus = paddle.nn.Softplus(self.beta, self.threshold) + out2 = softplus(x) + out_ref = ref_softplus(self.x_np, self.beta, self.threshold) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -3466,13 +3572,14 @@ class TestSoftsignAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.softsign(x) - softsign = paddle.nn.Softsign() - out2 = softsign(x) - out_ref = ref_softsign(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.softsign(x) + softsign = paddle.nn.Softsign() + out2 = softsign(x) + out_ref = ref_softsign(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -3556,14 +3663,14 @@ class TestThresholdedReluAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - paddle.disable_static() - x = paddle.to_tensor(self.x_np) - out1 = F.thresholded_relu(x, self.threshold) - thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold) - out2 = thresholded_relu(x) - out_ref = ref_thresholded_relu(self.x_np, self.threshold) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.thresholded_relu(x, self.threshold) + thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold) + out2 = thresholded_relu(x) + out_ref = ref_thresholded_relu(self.x_np, self.threshold) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_errors(self): with paddle.fluid.framework._static_guard(): @@ -3660,13 +3767,14 @@ class TestHardsigmoidAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.hardsigmoid(x) - m = paddle.nn.Hardsigmoid() - out2 = m(x) - out_ref = ref_hardsigmoid(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.hardsigmoid(x) + m = paddle.nn.Hardsigmoid() + out2 = m(x) + out_ref = ref_hardsigmoid(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_fluid_api(self): with paddle.fluid.framework._static_guard(): @@ -3763,13 +3871,14 @@ class TestSwishAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = paddle.to_tensor(self.x_np) - out1 = F.swish(x) - swish = paddle.nn.Swish() - out2 = swish(x) - out_ref = ref_swish(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.swish(x) + swish = paddle.nn.Swish() + out2 = swish(x) + out_ref = ref_swish(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_fluid_api(self): with paddle.fluid.framework._static_guard(): @@ -3862,13 +3971,14 @@ class TestMishAPI(unittest.TestCase): np.testing.assert_allclose(out_ref, r, rtol=1e-05) def test_dygraph_api(self): - x = 
paddle.to_tensor(self.x_np) - out1 = F.mish(x) - mish = paddle.nn.Mish() - out2 = mish(x) - out_ref = ref_mish(self.x_np) - for r in [out1, out2]: - np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + with dynamic_guad(): + x = paddle.to_tensor(self.x_np) + out1 = F.mish(x) + mish = paddle.nn.Mish() + out2 = mish(x) + out_ref = ref_mish(self.x_np) + for r in [out1, out2]: + np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) def test_fluid_api(self): with paddle.fluid.framework._static_guard(): -- GitLab
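
For reference, a minimal usage sketch of the forward support this patch adds, mirroring the new tests above. The tensor names and sample values are illustrative, and the expectation that integer inputs yield a floating-point result is inferred from the int-in/float-out kernel macros and the cast inside the functors; treat it as a sketch, not part of the patch itself.

    import numpy as np
    import paddle

    paddle.disable_static()

    # int32/int64 inputs are now accepted; the kernel casts them to float
    # internally, so the result is a floating-point tensor.
    x_int = paddle.to_tensor(np.array([[2, 3, 4], [7, 8, 9]]), dtype='int64')
    y = paddle.log(x_int)
    np.testing.assert_allclose(y.numpy(), np.log([[2, 3, 4], [7, 8, 9]]), rtol=1e-3)

    # float16 (and bfloat16, per the kernel registrations) forward is now
    # registered on both CPU and GPU for all four ops.
    x_fp16 = paddle.to_tensor([[2.0, 3.0, 4.0]], dtype='float16')
    for op in (paddle.log, paddle.log2, paddle.log10, paddle.log1p):
        print(op(x_fp16))

    paddle.enable_static()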