未验证 提交 32b90b1c 编写于 作者: J joejiong 提交者: GitHub

add log10 (#28576)

Add new operator log10
上级 f0806bda
......@@ -310,6 +310,15 @@ logarithm of x base to 2.
)DOC";
UNUSED constexpr char Log10Doc[] = R"DOC(
Log10 Activation Operator.
$$out = \log_10_x$$
logarithm of x base to 10.
)DOC";
UNUSED constexpr char Log1pDoc[] = R"DOC(
Log Activation Operator.
......@@ -707,6 +716,7 @@ REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Log2, Log2Doc);
REGISTER_ACTIVATION_OP_MAKER(Log10, Log10Doc);
REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
......
......@@ -841,6 +841,27 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// log10(x) = logarithm to the base 10 of the elements of x
template <typename T>
struct Log10Functor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log() / static_cast<T>(log(10));
}
};
// the gradient of log10(x) is 1/(x*ln(10))
template <typename T>
struct Log10GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// log1p(x) = natural logarithm of x+1
template <typename T>
struct Log1pFunctor : public BaseActivationFunctor<T> {
......@@ -1930,6 +1951,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
__macro(log2, Log2, Log2Functor, Log2GradFunctor); \
__macro(log10, Log10, Log10Functor, Log10GradFunctor); \
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
......
......@@ -152,6 +152,7 @@ from .tensor.math import floor #DEFINE_ALIAS
from .tensor.math import increment #DEFINE_ALIAS
from .tensor.math import log #DEFINE_ALIAS
from .tensor.math import log2 #DEFINE_ALIAS
from .tensor.math import log10 #DEFINE_ALIAS
from .tensor.math import multiplex #DEFINE_ALIAS
from .tensor.math import pow #DEFINE_ALIAS
from .tensor.math import reciprocal #DEFINE_ALIAS
......
......@@ -1698,6 +1698,55 @@ class TestLog2(TestActivation):
self.assertTrue(np.allclose(np_z, z_expected))
class TestLog10(TestActivation):
def setUp(self):
self.op_type = "log10"
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.log10(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
def test_error(self):
in1 = paddle.static.data(name="in1", shape=[11, 17], dtype="int32")
in2 = paddle.static.data(name="in2", shape=[11, 17], dtype="int64")
self.assertRaises(TypeError, paddle.log10, in1)
self.assertRaises(TypeError, paddle.log10, in2)
def test_api(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
input_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64")
data_x = paddle.static.data(
name="data_x", shape=[11, 17], dtype="float64")
out1 = paddle.log10(data_x)
exe = paddle.static.Executor(place=paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
res1 = exe.run(paddle.static.default_main_program(),
feed={"data_x": input_x},
fetch_list=[out1])
expected_res = np.log10(input_x)
self.assertTrue(np.allclose(res1, expected_res))
# dygraph
with fluid.dygraph.guard():
np_x = np.random.uniform(0.1, 1, [11, 17]).astype("float64")
data_x = paddle.to_tensor(np_x)
z = paddle.log10(data_x)
np_z = z.numpy()
z_expected = np.array(np.log10(np_x))
self.assertTrue(np.allclose(np_z, z_expected))
class TestLog1p(TestActivation):
def setUp(self):
self.op_type = "log1p"
......@@ -2432,6 +2481,7 @@ create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestReciprocal)
create_test_act_fp16_class(TestLog)
create_test_act_fp16_class(TestLog2, atol=5e-2)
create_test_act_fp16_class(TestLog10, atol=5e-2)
create_test_act_fp16_class(TestLog1p, grad_atol=0.9)
create_test_act_fp16_class(TestSquare)
create_test_act_fp16_class(TestPow, atol=5e-2)
......
......@@ -152,6 +152,7 @@ from .math import atan #DEFINE_ALIAS
from .math import logsumexp #DEFINE_ALIAS
from .math import inverse #DEFINE_ALIAS
from .math import log2 #DEFINE_ALIAS
from .math import log10 #DEFINE_ALIAS
from .math import log1p #DEFINE_ALIAS
from .math import erf #DEFINE_ALIAS
# from .math import addcmul #DEFINE_ALIAS
......
......@@ -80,6 +80,7 @@ __all__ = [
'increment',
'log',
'log2',
'log10',
'logsumexp',
'mul',
'multiplex',
......@@ -1362,6 +1363,57 @@ def log2(x, name=None):
helper.append_op(type="log2", inputs={"X": x}, outputs={"Out": out})
return out
def log10(x, name=None):
"""
Calculates the log to the base 10 of the given input tensor, element-wise.
.. math::
Out = \\log_10_x
Args:
x (Tensor): Input tensor must be one of the following types: float32, float64.
name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Tensor: The log to the base 10 of the input Tensor computed element-wise.
Examples:
.. code-block:: python
import paddle
# example 1: x is a float
x_i = paddle.to_tensor([[1.0], [10.0]])
res = paddle.log10(x_i) # [[0.], [1.0]]
# example 2: x is float32
x_i = paddle.full(shape=[1], fill_value=10, dtype='float32')
paddle.to_tensor(x_i)
res = paddle.log10(x_i)
print(res) # [1.0]
# example 3: x is float64
x_i = paddle.full(shape=[1], fill_value=10, dtype='float64')
paddle.to_tensor(x_i)
res = paddle.log10(x_i)
print(res) # [1.0]
"""
if in_dygraph_mode():
return core.ops.log10(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], "log10")
inputs = {'X': [x]}
helper = LayerHelper('log10', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="log10", inputs={"X": x}, outputs={"Out": out})
return out
def addcmul(input, tensor1, tensor2, value=1.0, name=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册