Unverified commit 2ddd0473 authored by Hui Zhang, committed by GitHub

log/log2/log10/log1p support int32/int64/float16/bfloat16 forward (#54089)

* fix for log xxx

* add int32/int64 for cpu/gpu; add float16/bfloat16 for cpu forward

* fix docstring

* fix bug

* fix bugs

* fix bugs

* fix bugs

* fix bugs

* fix bug

* using cast

* fix test

* fix api

* fix other bugs

* fix ci bug for not using dygraph guard

* add bfloat16 test

* fix ut

* bf16
Parent 685c0a49
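In user-facing terms: integer inputs to log/log2/log10/log1p are cast to float inside the kernel and produce a float output, and float16/bfloat16 inputs now have a CPU forward kernel. A minimal sketch of the intended behaviour, assuming a Paddle build that contains this commit (only the forward pass is covered):

    import paddle

    # int32/int64 input: the kernel maps integral T to a float output (U = float)
    x = paddle.to_tensor([1, 10, 100], dtype='int64')
    y = paddle.log(x)
    print(y.dtype)    # expected: paddle.float32
    print(y.numpy())  # expected: approximately [0., 2.3026, 4.6052]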
@@ -26,10 +26,22 @@ namespace phi {
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
funcs::functor_class<T> functor; \
ActivationImpl<T, Context, funcs::functor_class<T>>( \
ActivationImpl<T, T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
#define DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(name, \
functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
funcs::functor_class<T> functor; \
using U = \
typename std::conditional_t<std::is_integral<T>::value, float, T>; \
ActivationImpl<T, U, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
#define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
@@ -39,24 +51,24 @@ namespace phi {
funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr; \
ActivationImpl<T, Context, funcs::functor_class<T>>( \
ActivationImpl<T, T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \
name, functor_class, attr1, attr2) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr1, \
float attr2, \
DenseTensor* out) { \
funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \
ActivationImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \
name, functor_class, attr1, attr2) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr1, \
float attr2, \
DenseTensor* out) { \
funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \
ActivationImpl<T, T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
DEFINE_CPU_ACTIVATION_KERNEL(Sin, SinFunctor)
@@ -83,15 +95,16 @@ DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, RsqrtFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Softsign, SoftsignFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Round, RoundFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Negative, NegativeFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, Log1pFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
@@ -124,7 +137,7 @@ void HardSwishKernel(const Context& dev_ctx,
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationImpl<T, Context, funcs::HardSwishFunctor<T>>(
ActivationImpl<T, T, Context, funcs::HardSwishFunctor<T>>(
dev_ctx, x, out, functor);
}
@@ -178,10 +191,48 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_KERNEL(log,
CPU,
ALL_LAYOUT,
phi::LogKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log2,
CPU,
ALL_LAYOUT,
phi::Log2Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log10,
CPU,
ALL_LAYOUT,
phi::Log10Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log1p,
CPU,
ALL_LAYOUT,
phi::Log1pKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
......
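With the CPU registrations above, the float16/bfloat16 forward pass no longer needs a GPU. A small usage sketch, not taken from the patch or its tests (values are approximate):

    import paddle

    # float16 forward on the CPU, enabled by the new CPU kernel registrations
    x = paddle.to_tensor([0.5, 1.0, 4.0], dtype='float16', place=paddle.CPUPlace())
    print(paddle.log2(x))  # expected: roughly [-1., 0., 2.], dtype float16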
@@ -1996,12 +1996,33 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct Log {
HOSTDEVICE T operator()(const T& val) const { return std::log(val); }
};
template <>
struct Log<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log(static_cast<float>(val)));
}
};
template <>
struct Log<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log(static_cast<float>(val)));
}
};
// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log();
out.device(d) = x.template cast<U>().unaryExpr(Log<U>());
}
};
@@ -2019,12 +2040,33 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Log2 {
HOSTDEVICE T operator()(const T& val) const { return std::log2(val); }
};
template <>
struct Log2<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log2(static_cast<float>(val)));
}
};
template <>
struct Log2<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log2(static_cast<float>(val)));
}
};
// log2(x) = logarithm to the base 2 of the elements of x
template <typename T>
struct Log2Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(2));
out.device(d) = x.template cast<U>().unaryExpr(Log2<U>());
}
};
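The old Eigen expression computed log2 as x.log() / log(2); the refactored functor instead casts to U and applies std::log2 element-wise through unaryExpr. A quick numerical sanity check that the two formulations agree, written as a hypothetical Python snippet rather than anything from the patch:

    import math
    import paddle

    x = paddle.to_tensor([1.0, 2.0, 8.0], dtype='float32')
    # log2(x) should still match log(x) / log(2), the identity used before the refactor
    print(paddle.allclose(paddle.log2(x), paddle.log(x) / math.log(2)))  # expected: True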
@@ -2043,12 +2085,33 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Log10 {
HOSTDEVICE T operator()(const T& val) const { return std::log10(val); }
};
template <>
struct Log10<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log10(static_cast<float>(val)));
}
};
template <>
struct Log10<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log10(static_cast<float>(val)));
}
};
// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(10));
out.device(d) = x.template cast<U>().unaryExpr(Log10<U>());
}
};
@@ -2067,12 +2130,33 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Log1p {
HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); }
};
template <>
struct Log1p<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(std::log1p(static_cast<float>(val)));
}
};
template <>
struct Log1p<dtype::bfloat16> {
HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
return dtype::bfloat16(std::log1p(static_cast<float>(val)));
}
};
// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = (static_cast<T>(1) + x).log();
out.device(d) = x.template cast<U>().unaryExpr(Log1p<U>());
}
};
@@ -3665,14 +3749,35 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
log_local(T x) {
static_assert(!std::is_same<T, double>::value,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __logf fast approximation for peak bandwidth
return __logf(x);
#else
return ::log(x);
#endif
}
template <>
__device__ __forceinline__ double log_local<double>(double x) {
return ::log(x);
}
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log(x) = log(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(x));
return static_cast<U>(log_local(x));
}
};
@@ -3690,11 +3795,12 @@ template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log1p(x) = log(1 + x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log(one + x));
return static_cast<U>(log_local(one + x));
}
};
@@ -3710,14 +3816,35 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
log2_local(T x) {
static_assert(!std::is_same<T, double>::value,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __log2f fast approximation for peak bandwidth
return __log2f(x);
#else
return ::log2(x);
#endif
}
template <>
__device__ __forceinline__ double log2_local<double>(double x) {
return ::log2(x);
}
template <typename T>
struct CudaLog2Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log2(x) = log2(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log2(x));
return static_cast<U>(log2_local(x));
}
};
@@ -3734,14 +3861,35 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
log10_local(T x) {
static_assert(!std::is_same<T, double>::value,
"this template must be used with float or less precise type");
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
// use __log10f fast approximation for peak bandwidth
return __log10f(x);
#else
return ::log10(x);
#endif
}
template <>
__device__ __forceinline__ double log10_local(double x) {
return ::log10(x);
}
template <typename T>
struct CudaLog10Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
// log10(x) = log10(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
__device__ __forceinline__ U operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(log10(x));
return static_cast<U>(log10_local(x));
}
};
......
@@ -47,6 +47,18 @@ void ActivationGPUImpl(const Context& dev_ctx,
dev_ctx, x, out, functor); \
}
#define DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(name, \
functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
funcs::functor_class<T> functor; \
using U = \
typename std::conditional_t<std::is_integral<T>::value, float, T>; \
ActivationGPUImpl<U, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
}
#define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
@@ -100,14 +112,15 @@ DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, CudaRsqrtFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Softsign, CudaSoftsignFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Log1p, CudaLog1pFunctor)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
@@ -246,10 +259,6 @@ PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
@@ -258,6 +267,46 @@ PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)
PD_REGISTER_KERNEL(log,
GPU,
ALL_LAYOUT,
phi::LogKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log2,
GPU,
ALL_LAYOUT,
phi::Log2Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log10,
GPU,
ALL_LAYOUT,
phi::Log10Kernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(log1p,
GPU,
ALL_LAYOUT,
phi::Log1pKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(pow,
GPU,
ALL_LAYOUT,
......
@@ -23,17 +23,17 @@ namespace phi {
#define ToString(x) #x
template <typename T, typename Context, typename Functor>
template <typename T, typename U, typename Context, typename Functor>
void ActivationImpl(const Context& dev_ctx,
const DenseTensor& X,
DenseTensor* Out,
const Functor& functor) {
PADDLE_ENFORCE_NOT_NULL(Out,
errors::NotFound("Output Out should not be nullptr"));
dev_ctx.template Alloc<T>(Out);
dev_ctx.template Alloc<U>(Out);
auto x = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(&X, "Input", "X", "Activation"));
auto out = phi::EigenVector<T>::Flatten(
auto out = phi::EigenVector<U>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
auto* place = dev_ctx.eigen_device();
// use 32bit index to speed up computation
......
@@ -137,7 +137,7 @@ def log(x, name=None):
Out = \ln(x)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64.
x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
@@ -159,7 +159,10 @@ def log(x, name=None):
return _C_ops.log(x)
else:
check_variable_and_dtype(
x, 'x', ['uint16', 'float16', 'float32', 'float64'], "log"
x,
'x',
['int32', 'int64', 'uint16', 'float16', 'float32', 'float64'],
"log",
)
inputs = {'X': [x]}
helper = LayerHelper('log', **locals())
@@ -2763,7 +2766,7 @@ def log1p(x, name=None):
Out = \ln(x+1)
Args:
x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64.
x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
@@ -2783,7 +2786,10 @@ def log1p(x, name=None):
return _C_ops.log1p(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log1p"
x,
'x',
['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
"log1p",
)
inputs = {'X': [x]}
helper = LayerHelper('log1p', **locals())
@@ -2802,7 +2808,7 @@ def log2(x, name=None):
Out = \log_2 x
Args:
x (Tensor): Input tensor must be one of the following types: float32, float64.
x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -2835,7 +2841,10 @@ def log2(x, name=None):
return _C_ops.log2(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log2"
x,
'x',
['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
"log2",
)
inputs = {'X': [x]}
helper = LayerHelper('log2', **locals())
@@ -2854,7 +2863,7 @@ def log10(x, name=None):
Out = \log_{10} x
Args:
x (Tensor): Input tensor must be one of the following types: float32, float64.
x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -2887,7 +2896,10 @@ def log10(x, name=None):
return _C_ops.log10(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], "log10"
x,
'x',
['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
"log10",
)
inputs = {'X': [x]}
helper = LayerHelper('log10', **locals())
......
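On the static-graph side the only change is the dtype whitelist handed to check_variable_and_dtype. A hypothetical sketch of what now builds without a dtype TypeError (not taken from the patch or its tests):

    import paddle

    paddle.enable_static()
    x = paddle.static.data(name='x', shape=[4], dtype='int32')
    y = paddle.log1p(x)   # int32 was previously rejected by the whitelist
    print(y.name)
    paddle.disable_static()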