Unverified commit 6baeb2d1, authored by zyfncg, committed by GitHub

Generate static graph code for some activation ops via YAML (#47382)

* generate static graph code for some activation ops

* fix example code of cosh
Parent 533f6cbd
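This commit replaces hand-written C++ OpMaker classes, operator registrations, and argument mappings for a batch of trigonometric and hyperbolic activation ops with entries in the operator YAML files (ops.yaml, backward.yaml, op_compat.yaml), from which the static-graph code is now generated. The intent is a pure refactor with no behavior change; a minimal sketch (not part of the commit) that checks one migrated op for eager/static parity:

    import numpy as np
    import paddle

    data = np.array([-0.4, -0.2, 0.1, 0.3], dtype='float32')
    eager_out = paddle.cos(paddle.to_tensor(data)).numpy()

    # Build and run the same op through the (now YAML-generated) static graph path.
    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[4], dtype='float32')
        out = paddle.cos(x)
    exe = paddle.static.Executor(paddle.CPUPlace())
    static_out = exe.run(main, feed={'x': data}, fetch_list=[out])[0]

    np.testing.assert_allclose(eager_out, static_out, rtol=1e-6)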
......@@ -46,7 +46,7 @@ repos:
- id: detect-private-key
- id: end-of-file-fixer
- id: sort-simple-yaml
files: (op|backward|op_[a-z_]+)\.yaml$
files: (ops|backward|op_[a-z_]+)\.yaml$
- id: trailing-whitespace
files: (.*\.(py|bzl|md|rst|c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps|cmake|yaml|yml|hook)|BUILD|.*\.BUILD|WORKSPACE|CMakeLists\.txt)$
- repo: local
......
......@@ -221,70 +221,6 @@ $$out = \\lfloor x \\rfloor$$
)DOC";
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Operator. Computes cosine of x element-wise.
Input range is `(-inf, inf)` and output range is `[-1,1]`.
.. math::
out = cos(x)
)DOC";
UNUSED constexpr char TanDoc[] = R"DOC(
Tangent Operator. Computes tangent of x element-wise.
Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.
$$out = tan(x)$$
)DOC";
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.
$$out = sin(x)$$
)DOC";
UNUSED constexpr char SinhDoc[] = R"DOC(
Sinh Activation Operator.
$$out = sinh(x)$$
)DOC";
UNUSED constexpr char CoshDoc[] = R"DOC(
Cosh Activation Operator.
Input range `(-inf, inf)`, output range `[1, +inf)`.
.. math::
out = \frac{e^{x} + e^{-x}}{2}
)DOC";
UNUSED constexpr char AsinhDoc[] = R"DOC(
Asinh Activation Operator.
$$out = asinh(x)$$
)DOC";
UNUSED constexpr char AcoshDoc[] = R"DOC(
Acosh Activation Operator.
$$out = acosh(x)$$
)DOC";
UNUSED constexpr char AtanhDoc[] = R"DOC(
Atanh Activation Operator.
$$out = atanh(x)$$
)DOC";
UNUSED constexpr char RoundDoc[] = R"DOC(
The OP rounds the values in the input to the nearest integer value.
......@@ -357,55 +293,6 @@ $$out = \\frac{x}{1 + \|x\|}$$
)DOC";
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of acos operator");
AddOutput("Out", "Tensor, same shape and dtype as input");
AddComment(R"DOC(
Arccosine Operator.
.. math::
out = \cos^{-1}(x)
)DOC");
}
};
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"Input of asin operator, an N-D Tensor, with data type float32, "
"float64 or float16.");
AddOutput("Out", "Tensor, same shape and dtype as input.");
AddComment(R"DOC(
Arcsine Operator.
.. math::
out = \sin^{-1}(x)
)DOC");
}
};
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"Input of atan operator, an N-D Tensor, with data type float32, "
"float64 or float16.");
AddOutput("Out", "Tensor, same shape and dtype as input x");
AddComment(R"DOC(
Arctangent Operator.
.. math::
out = \tan^{-1}(x)
)DOC");
}
};
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -807,14 +694,6 @@ REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Tan, TanDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Sinh, SinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Cosh, CoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Acosh, AcoshDoc);
REGISTER_ACTIVATION_OP_MAKER(Asinh, AsinhDoc);
REGISTER_ACTIVATION_OP_MAKER(Atanh, AtanhDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
......@@ -1388,17 +1267,6 @@ namespace plat = paddle::platform;
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
REGISTER_ACTIVATION_OP(cos, Cos, CosFunctor, CosGradFunctor)
REGISTER_ACTIVATION_OP(tan, Tan, TanFunctor, TanGradFunctor);
REGISTER_ACTIVATION_OP(acos, Acos, AcosFunctor, AcosGradFunctor);
REGISTER_ACTIVATION_OP(sin, Sin, SinFunctor, SinGradFunctor);
REGISTER_ACTIVATION_OP(asin, Asin, AsinFunctor, AsinGradFunctor);
REGISTER_ACTIVATION_OP(atan, Atan, AtanFunctor, AtanGradFunctor);
REGISTER_ACTIVATION_OP(sinh, Sinh, SinhFunctor, SinhGradFunctor);
REGISTER_ACTIVATION_OP(cosh, Cosh, CoshFunctor, CoshGradFunctor);
REGISTER_ACTIVATION_OP(asinh, Asinh, AsinhFunctor, AsinhGradFunctor);
REGISTER_ACTIVATION_OP(acosh, Acosh, AcoshFunctor, AcoshGradFunctor);
REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_OP(thresholded_relu,
ThresholdedRelu,
......
- backward_op : acos_grad
forward : acos (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : acos_grad
inplace : (out_grad -> x_grad)
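Each generated backward_op entry follows the same pattern: the gradient kernel takes the forward input x plus out_grad, the gradient's shape and dtype are inferred from x alone (UnchangedInferMeta with param [x]), and x_grad may reuse out_grad's buffer (the inplace mapping). For acos the kernel computes d/dx acos(x) = -1/sqrt(1-x^2); a small eager-mode check (a sketch, not part of the commit):

    import paddle

    x = paddle.to_tensor([0.3], stop_gradient=False)
    out = paddle.acos(x)
    (grad,) = paddle.grad(out, x)
    # Analytic gradient of acos for comparison.
    expected = -1.0 / paddle.sqrt(1.0 - x.detach() ** 2)
    print(grad, expected)  # should agree up to float precision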
- backward_op : acosh_grad
forward : acosh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : acosh_grad
inplace : (out_grad -> x_grad)
- backward_op : asin_grad
forward : asin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : asin_grad
inplace : (out_grad -> x_grad)
- backward_op : asinh_grad
forward : asinh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : asinh_grad
inplace : (out_grad -> x_grad)
- backward_op : atan2_grad
forward : atan2 (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad)
......@@ -8,6 +52,28 @@
kernel :
func : atan2_grad
- backward_op : atan_grad
forward : atan (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : atan_grad
inplace : (out_grad -> x_grad)
- backward_op : atanh_grad
forward : atanh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : atanh_grad
inplace : (out_grad -> x_grad)
- backward_op : cholesky_grad
forward : cholesky (Tensor x, bool upper) -> Tensor(out)
args : (Tensor out, Tensor out_grad, bool upper)
......@@ -28,6 +94,28 @@
kernel :
func : cholesky_solve_grad
- backward_op : cos_grad
forward : cos (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : cos_grad
inplace : (out_grad -> x_grad)
- backward_op : cosh_grad
forward : cosh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : cosh_grad
inplace : (out_grad -> x_grad)
- backward_op : cross_grad
forward : cross (Tensor x, Tensor y, int axis = 9) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis)
......@@ -205,6 +293,28 @@
kernel :
func : poisson_grad
- backward_op : sin_grad
forward : sin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sin_grad
inplace : (out_grad -> x_grad)
- backward_op : sinh_grad
forward : sinh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sinh_grad
inplace : (out_grad -> x_grad)
- backward_op : solve_grad
forward : solve (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
......@@ -215,6 +325,17 @@
kernel :
func : solve_grad
- backward_op : tan_grad
forward : tan (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : tan_grad
inplace : (out_grad -> x_grad)
- backward_op : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
......
......@@ -21,28 +21,6 @@
func : abs_grad
backward : abs_double_grad
- backward_op : acos_grad
forward : acos (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : acos_grad
inplace : (out_grad -> x_grad)
- backward_op : acosh_grad
forward : acosh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : acosh_grad
inplace : (out_grad -> x_grad)
- backward_op : add_double_grad
forward : add_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
......@@ -158,28 +136,6 @@
output : Tensor(x_grad)
invoke : as_complex(out_grad)
- backward_op : asin_grad
forward : asin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : asin_grad
inplace : (out_grad -> x_grad)
- backward_op : asinh_grad
forward : asinh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : asinh_grad
inplace : (out_grad -> x_grad)
- backward_op : assign_grad
forward : assign (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......@@ -196,28 +152,6 @@
func : assign
inplace : (out_grad -> x_grad)
- backward_op : atan_grad
forward : atan (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : atan_grad
inplace : (out_grad -> x_grad)
- backward_op : atanh_grad
forward : atanh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : atanh_grad
inplace : (out_grad -> x_grad)
- backward_op : batch_norm_double_grad
forward : batch_norm_grad (Tensor x, Tensor scale, Tensor bias, Tensor out_mean, Tensor out_variance, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor grad_out, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
args : (Tensor x, Tensor scale, Tensor out_mean, Tensor out_variance, Tensor saved_mean, Tensor saved_variance, Tensor grad_out, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
......@@ -500,28 +434,6 @@
func : conv3d_transpose_grad
use_gpudnn : true
- backward_op : cos_grad
forward : cos (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : cos_grad
inplace : (out_grad -> x_grad)
- backward_op : cosh_grad
forward : cosh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : cosh_grad
inplace : (out_grad -> x_grad)
- backward_op : crop_tensor_grad
forward : crop_tensor (Tensor x, IntArray shape, IntArray offsets) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray offsets)
......@@ -2106,28 +2018,6 @@
func : silu_grad
inplace : (out_grad -> x_grad)
- backward_op : sin_grad
forward : sin (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sin_grad
inplace : (out_grad -> x_grad)
- backward_op : sinh_grad
forward : sinh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sinh_grad
inplace : (out_grad -> x_grad)
- backward_op : slice_double_grad
forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input)
args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
......@@ -2408,17 +2298,6 @@
kernel :
func : take_along_axis_grad
- backward_op : tan_grad
forward : tan (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : tan_grad
inplace : (out_grad -> x_grad)
- backward_op : tanh_double_grad
forward : tanh_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor grad_out, Tensor grad_x_grad)
......
......@@ -19,24 +19,6 @@
func : accuracy
dtype : x
- op : acos
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : acos
backward : acos_grad
- op : acosh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : acosh
backward : acosh_grad
- op : adadelta_
args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, float rho, float epsilon)
output : Tensor(param_out), Tensor(moment_out), Tensor(inf_norm_out)
......@@ -236,24 +218,6 @@
func : as_real
backward : as_real_grad
- op : asin
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : asin
backward : asin_grad
- op : asinh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : asinh
backward : asinh_grad
- op : assign
args : (Tensor x)
output : Tensor
......@@ -288,24 +252,6 @@
data_type : dtype
backend : place > output
- op : atan
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : atan
backward : atan_grad
- op : atanh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : atanh
backward : atanh_grad
- op : auc
args : (Tensor x, Tensor label, Tensor stat_pos, Tensor stat_neg, Tensor ins_tag_weight, str curve, int num_thresholds, int slide_steps)
output : Tensor(auc), Tensor(stat_pos_out), Tensor(stat_neg_out)
......@@ -589,24 +535,6 @@
output : Tensor(out)
invoke : copy_to_impl(x, place, blocking)
- op : cos
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : cos
backward : cos_grad
- op : cosh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : cosh
backward : cosh_grad
- op : crop_tensor
args : (Tensor x, IntArray shape, IntArray offsets)
output : Tensor(out)
......@@ -1939,6 +1867,16 @@
kernel :
func : not_equal
- op : numel
args : (Tensor x)
output : Tensor(size)
infer_meta :
func : SizeInferMeta
kernel :
func : size
data_transform:
skip_transform : x
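numel (kernel size) only reads the input's metadata, which is presumably why skip_transform : x exempts the input from data transforms such as dtype, layout, or place conversion. For example:

    import paddle

    x = paddle.ones([3, 4], dtype='int32')
    print(paddle.numel(x))  # a Tensor holding 12; the values of x are never read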
- op : one_hot
args : (Tensor x, Scalar(int) num_classes)
output : Tensor(out)
......@@ -2402,34 +2340,6 @@
func : silu
backward : silu_grad
- op : sin
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : sin
backward : sin_grad
- op : sinh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : sinh
backward : sinh_grad
- op : numel
args : (Tensor x)
output : Tensor(size)
infer_meta :
func : SizeInferMeta
kernel :
func : size
data_transform:
skip_transform : x
- op : slice
args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
......@@ -2448,16 +2358,6 @@
func : slogdeterminant
backward : slogdet_grad
- op : softshrink
args : (Tensor x, float threshold)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : soft_shrink
backward : softshrink_grad
- op : softmax
args : (Tensor x, int axis)
output : Tensor(out)
......@@ -2479,6 +2379,16 @@
func : softplus
backward : softplus_grad
- op : softshrink
args : (Tensor x, float threshold)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : soft_shrink
backward : softshrink_grad
- op : softsign
args : (Tensor x)
output : Tensor
......@@ -2637,15 +2547,6 @@
data_type : arr
backward : take_along_axis_grad
- op : tan
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : tan
backward : tan_grad
- op : tanh
args : (Tensor x)
output : Tensor(out)
......@@ -2777,17 +2678,6 @@
backend : place
data_type : dtype
- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update)
output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps)
infer_meta :
func : UpdateLossScalingInferMeta
param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps]
kernel :
func : update_loss_scaling
data_type : x
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)
- op : unbind
args : (Tensor input, int axis)
output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]}
......@@ -2858,6 +2748,17 @@
func : unstack
backward : unstack_grad
- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update)
output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps)
infer_meta :
func : UpdateLossScalingInferMeta
param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps]
kernel :
func : update_loss_scaling
data_type : x
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)
- op : viterbi_decode
args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag)
output : Tensor(scores), Tensor(path)
......@@ -2926,6 +2827,15 @@
output : Tensor(out)
invoke : full_like(x, 0, dtype, place)
- op: bincount
args: (Tensor x, Tensor weights, Scalar minlength)
output: Tensor(out)
infer_meta:
func: BincountInferMeta
kernel:
func: bincount
optional: weights
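optional: weights marks the second input as omittable: without it bincount returns integer counts, with it each occurrence contributes its weight. A short usage sketch:

    import paddle

    x = paddle.to_tensor([0, 1, 1, 3])
    print(paddle.bincount(x))  # [1, 2, 0, 1]
    w = paddle.to_tensor([0.5, 0.5, 1.0, 2.0])
    print(paddle.bincount(x, weights=w))  # [0.5, 1.5, 0.0, 2.0]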
- op: broadcast_tensors
args: (Tensor[] input)
output: Tensor[]{input.size()}
......@@ -3015,12 +2925,3 @@
func: unpool3d
data_type: x
backward: unpool3d_grad
- op: bincount
args: (Tensor x, Tensor weights, Scalar minlength)
output: Tensor(out)
infer_meta:
func: BincountInferMeta
kernel:
func: bincount
optional: weights
......@@ -8,7 +8,17 @@
extra :
attrs : [bool use_mkldnn = false]
- op : acos
inputs :
x : X
outputs :
out : Out
- op : acosh
inputs :
x : X
outputs :
out : Out
backward : acosh_grad
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......@@ -34,11 +44,27 @@
extra :
attrs : [bool use_mkldnn = false]
- op : asin
inputs :
x : X
outputs :
out : Out
- op : asinh
backward : asinh_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : atan
inputs :
x : X
outputs :
out : Out
- op : atan2
inputs :
{x : X1, y : X2}
......@@ -47,6 +73,10 @@
- op : atanh
backward : atanh_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......@@ -145,11 +175,19 @@
- op : cos
backward : cos_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
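These op_compat entries map the new-style argument names (x, out) to the legacy static-graph slot names (X, Out) that the Python wrappers below still emit through helper.append_op. One way to observe the legacy slot names (a sketch, assuming the fluid Operator introspection attributes type, input_names, and output_names):

    import paddle

    paddle.enable_static()
    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        x = paddle.static.data(name='x', shape=[4], dtype='float32')
        out = paddle.cos(x)

    op = prog.global_block().ops[-1]
    print(op.type, op.input_names, op.output_names)  # e.g. cos ['X'] ['Out']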
- op : cosh
backward : cosh_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......@@ -271,14 +309,12 @@
- op : exp
backward : exp_grad
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : exp
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : expand (expand_v2)
backward : expand_grad (expand_v2_grad)
......@@ -670,11 +706,19 @@
- op : sin
backward : sin_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : sinh
backward : sinh_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......@@ -748,6 +792,10 @@
- op : tan
backward : tan_grad
inputs :
x : X
outputs :
out : Out
extra :
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......
- op : acos
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : acos
backward : acos_grad
- op : acosh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : acosh
backward : acosh_grad
- op : asin
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : asin
backward : asin_grad
- op : asinh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : asinh
backward : asinh_grad
- op : atan
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : atan
backward : atan_grad
- op : atan2
args : (Tensor x, Tensor y)
output : Tensor
......@@ -7,6 +52,15 @@
func : atan2
backward : atan2_grad
- op : atanh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : atanh
backward : atanh_grad
- op : bernoulli
args : (Tensor x)
output : Tensor(out)
......@@ -33,15 +87,23 @@
func : cholesky_solve
backward : cholesky_solve_grad
- op : exp
- op : cos
args : (Tensor x)
output : Tensor(out)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : exp
inplace : (x -> out)
backward : exp_grad
func : cos
backward : cos_grad
- op : cosh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : cosh
backward : cosh_grad
- op : cross
args : (Tensor x, Tensor y, int axis = 9)
......@@ -118,6 +180,16 @@
inplace : (x -> out)
backward : erfinv_grad
- op : exp
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : exp
inplace : (x -> out)
backward : exp_grad
- op : fft_c2c
args : (Tensor x, int64_t[] axes, str normalization, bool forward)
output : Tensor
......@@ -145,6 +217,15 @@
func : fft_r2c
backward : fft_r2c_grad
- op : flip
args : (Tensor x, int[] axis)
output : Tensor (out)
infer_meta :
func : FlipInferMeta
kernel :
func : flip
backward : flip_grad
- op : graph_send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
......@@ -182,6 +263,24 @@
func : poisson
backward : poisson_grad
- op : sin
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : sin
backward : sin_grad
- op : sinh
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : sinh
backward : sinh_grad
- op : solve
args : (Tensor x, Tensor y)
output : Tensor
......@@ -192,6 +291,15 @@
data_type : x
backward : solve_grad
- op : tan
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : tan
backward : tan_grad
- op : trace
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor
......@@ -209,12 +317,3 @@
kernel :
func : trunc
backward : trunc_grad
- op : flip
args : (Tensor x, int[] axis)
output : Tensor (out)
infer_meta :
func : FlipInferMeta
kernel :
func : flip
backward : flip_grad
......@@ -39,17 +39,6 @@ namespace phi {
#define comma ,
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cos, "cos", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Tan, "tan", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acos, "acos", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sin, "sin", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asin, "asin", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atan, "atan", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sinh, "sinh", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cosh, "cosh", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asinh, "asinh", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acosh, "acosh", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atanh, "atanh", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Square, "square", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(BRelu, "brelu", "t_min" comma "t_max");
......@@ -240,17 +229,6 @@ PD_REGISTER_BASE_KERNEL_NAME(rsqrt_grad_grad, rsqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(celu_grad_grad, celu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(square_grad_grad, square_double_grad);
PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(acos_grad, phi::AcosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sin_grad, phi::SinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(asin_grad, phi::AsinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(atan_grad, phi::AtanGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sinh_grad, phi::SinhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(cosh_grad, phi::CoshGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(asinh_grad, phi::AsinhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(acosh_grad, phi::AcoshGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(atanh_grad, phi::AtanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu_grad, phi::ReluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(expm1_grad, phi::Expm1GradOpArgumentMapping);
......
......@@ -40,25 +40,14 @@ __activations_noattr__ = [
__unary_func__ = [
'expm1',
'atan',
'sqrt',
'rsqrt',
'abs',
'ceil',
'floor',
'cos',
'tan',
'acos',
'sin',
'sinh',
'asin',
'cosh',
'round',
'reciprocal',
'square',
'acosh',
'asinh',
'atanh',
]
__inplace_unary_func__ = [
......@@ -191,22 +180,6 @@ Examples:
""",
)
add_sample_code(
globals()["atan"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atan(x)
print(out)
# [-0.38050638 -0.19739556 0.09966865 0.29145679]
""",
)
add_sample_code(
globals()["tanh_shrink"],
r"""
......@@ -305,23 +278,23 @@ Examples:
)
add_sample_code(
globals()["cos"],
globals()["round"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.cos(x)
x = paddle.to_tensor([-0.5, -0.2, 0.6, 1.5])
out = paddle.round(x)
print(out)
# [0.92106099 0.98006658 0.99500417 0.95533649]
# [-1. -0. 1. 2.]
""",
)
add_sample_code(
globals()["tan"],
globals()["reciprocal"],
r"""
Examples:
.. code-block:: python
......@@ -329,15 +302,15 @@ Examples:
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.tan(x)
out = paddle.reciprocal(x)
print(out)
# [-0.42279324, -0.20271005, 0.10033467, 0.30933627]
# [-2.5 -5. 10. 3.33333333]
""",
)
add_sample_code(
globals()["acos"],
globals()["square"],
r"""
Examples:
.. code-block:: python
......@@ -345,206 +318,346 @@ Examples:
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.acos(x)
out = paddle.square(x)
print(out)
# [1.98231317 1.77215425 1.47062891 1.26610367]
# [0.16 0.04 0.01 0.09]
""",
)
add_sample_code(
globals()["sin"],
globals()["softplus"],
r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.sin(x)
out = F.softplus(x)
print(out)
# [-0.38941834 -0.19866933 0.09983342 0.29552021]
# [0.513015, 0.598139, 0.744397, 0.854355]
""",
)
add_sample_code(
globals()["asin"],
globals()["softsign"],
r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asin(x)
out = F.softsign(x)
print(out)
# [-0.41151685 -0.20135792 0.10016742 0.30469265]
# [-0.285714, -0.166667, 0.0909091, 0.230769]
""",
)
add_sample_code(
globals()["cosh"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.cosh(x)
print(out)
# [1.08107237 1.02006676 1.00500417 1.04533851]
""",
)
add_sample_code(
globals()["sinh"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.sinh(x)
print(out)
# [-0.41075233 -0.201336 0.10016675 0.30452029]
""",
)
add_sample_code(
globals()["asinh"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asinh(x)
print(out)
# [-0.39003533, -0.19869010, 0.09983408, 0.29567307]
""",
)
add_sample_code(
globals()["acosh"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1., 3., 4., 5.])
out = paddle.acosh(x)
print(out)
# [0. , 1.76274729, 2.06343699, 2.29243159]
""",
)
add_sample_code(
globals()["atanh"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atanh(x)
print(out)
# [-0.42364895, -0.20273256, 0.10033535, 0.30951962]
""",
)
add_sample_code(
globals()["round"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.5, -0.2, 0.6, 1.5])
out = paddle.round(x)
print(out)
# [-1. -0. 1. 2.]
""",
)
add_sample_code(
globals()["reciprocal"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.reciprocal(x)
print(out)
# [-2.5 -5. 10. 3.33333333]
""",
)
add_sample_code(
globals()["square"],
r"""
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.square(x)
print(out)
# [0.16 0.04 0.01 0.09]
""",
)
add_sample_code(
globals()["softplus"],
r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.softplus(x)
print(out)
# [0.513015, 0.598139, 0.744397, 0.854355]
""",
)
add_sample_code(
globals()["softsign"],
r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.softsign(x)
print(out)
# [-0.285714, -0.166667, 0.0909091, 0.230769]
""",
)
def acos(x, name=None):
"""
Acos Activation Operator.
.. math::
out = cos^{-1}(x)
Args:
x (Tensor): Input of Acos operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Acos operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.acos(x)
print(out)
# [1.98231317 1.77215425 1.47062891 1.26610367]
"""
if in_dygraph_mode():
return _C_ops.acos(x)
if _in_legacy_dygraph():
return _legacy_C_ops.acos(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'acos')
helper = LayerHelper('acos', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='acos', inputs={"X": x}, outputs={"Out": out})
return out
def acosh(x, name=None):
"""
Acosh Activation Operator.
.. math::
out = acosh(x)
Args:
x (Tensor): Input of Acosh operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Acosh operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1., 3., 4., 5.])
out = paddle.acosh(x)
print(out)
# [0. , 1.76274729, 2.06343699, 2.29243159]
"""
if in_dygraph_mode():
return _C_ops.acosh(x)
if _in_legacy_dygraph():
return _legacy_C_ops.acosh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'acosh')
helper = LayerHelper('acosh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='acosh', inputs={"X": x}, outputs={"Out": out})
return out
def asin(x, name=None):
"""
Arcsine Operator.
.. math::
out = sin^{-1}(x)
Args:
x (Tensor): Input of Asin operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Same shape and dtype as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asin(x)
print(out)
# [-0.41151685 -0.20135792 0.10016742 0.30469265]
"""
if in_dygraph_mode():
return _C_ops.asin(x)
if _in_legacy_dygraph():
return _legacy_C_ops.asin(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'asin')
helper = LayerHelper('asin', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='asin', inputs={"X": x}, outputs={"Out": out})
return out
def asinh(x, name=None):
"""
Asinh Activation Operator.
.. math::
out = asinh(x)
Args:
x (Tensor): Input of Asinh operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Asinh operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.asinh(x)
print(out)
# [-0.39003533, -0.19869010, 0.09983408, 0.29567307]
"""
if in_dygraph_mode():
return _C_ops.asinh(x)
if _in_legacy_dygraph():
return _legacy_C_ops.asinh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'asinh')
helper = LayerHelper('asinh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='asinh', inputs={"X": x}, outputs={"Out": out})
return out
def atan(x, name=None):
"""
Arctangent Operator.
.. math::
out = tan^{-1}(x)
Args:
x (Tensor): Input of Atan operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Same shape and dtype as input x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atan(x)
print(out)
# [-0.38050638 -0.19739556 0.09966865 0.29145679]
"""
if in_dygraph_mode():
return _C_ops.atan(x)
if _in_legacy_dygraph():
return _legacy_C_ops.atan(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'atan')
helper = LayerHelper('atan', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='atan', inputs={"X": x}, outputs={"Out": out})
return out
def atanh(x, name=None):
"""
Atanh Activation Operator.
.. math::
out = atanh(x)
Args:
x (Tensor): Input of Atanh operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Atanh operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.atanh(x)
print(out)
# [-0.42364895, -0.20273256, 0.10033535, 0.30951962]
"""
if in_dygraph_mode():
return _C_ops.atanh(x)
if _in_legacy_dygraph():
return _legacy_C_ops.atanh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'atanh')
helper = LayerHelper('atanh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='atanh', inputs={"X": x}, outputs={"Out": out})
return out
def cos(x, name=None):
"""
Cosine Operator. Computes cosine of x element-wise.
Input range is `(-inf, inf)` and output range is `[-1,1]`.
.. math::
out = cos(x)
Args:
x (Tensor): Input of Cos operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Cos operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.cos(x)
print(out)
# [0.92106099 0.98006658 0.99500417 0.95533649]
"""
if in_dygraph_mode():
return _C_ops.cos(x)
if _in_legacy_dygraph():
return _legacy_C_ops.cos(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'cos')
helper = LayerHelper('cos', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='cos', inputs={"X": x}, outputs={"Out": out})
return out
def cosh(x, name=None):
"""
Cosh Activation Operator.
Input range `(-inf, inf)`, output range `[1, +inf)`.
.. math::
out = \frac{e^{x} + e^{-x}}{2}
Args:
x (Tensor): Input of Cosh operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Cosh operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.cosh(x)
print(out)
# [1.08107237 1.02006676 1.00500417 1.04533851]
"""
if in_dygraph_mode():
return _C_ops.cosh(x)
if _in_legacy_dygraph():
return _legacy_C_ops.cosh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'cosh')
helper = LayerHelper('cosh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='cosh', inputs={"X": x}, outputs={"Out": out})
return out
def exp(x, name=None):
......@@ -598,6 +711,119 @@ def exp(x, name=None):
return out
def sin(x, name=None):
"""
Sine Activation Operator.
.. math::
out = sin(x)
Args:
x (Tensor): Input of Sin operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Sin operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.sin(x)
print(out)
# [-0.38941834 -0.19866933 0.09983342 0.29552021]
"""
if in_dygraph_mode():
return _C_ops.sin(x)
if _in_legacy_dygraph():
return _legacy_C_ops.sin(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'sin')
helper = LayerHelper('sin', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='sin', inputs={"X": x}, outputs={"Out": out})
return out
def sinh(x, name=None):
"""
Sinh Activation Operator.
.. math::
out = sinh(x)
Args:
x (Tensor): Input of Sinh operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Sinh operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.sinh(x)
print(out)
# [-0.41075233 -0.201336 0.10016675 0.30452029]
"""
if in_dygraph_mode():
return _C_ops.sinh(x)
if _in_legacy_dygraph():
return _legacy_C_ops.sinh(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'sinh')
helper = LayerHelper('sinh', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='sinh', inputs={"X": x}, outputs={"Out": out})
return out
def tan(x, name=None):
"""
Tangent Operator. Computes tangent of x element-wise.
Input range is `(k*pi-pi/2, k*pi+pi/2)` and output range is `(-inf, inf)`.
.. math::
out = tan(x)
Args:
x (Tensor): Input of Tan operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor. Output of Tan operator, a Tensor with shape same as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = paddle.tan(x)
print(out)
# [-0.42279324, -0.20271005, 0.10033467, 0.30933627]
"""
if in_dygraph_mode():
return _C_ops.tan(x)
if _in_legacy_dygraph():
return _legacy_C_ops.tan(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'tan')
helper = LayerHelper('tan', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='tan', inputs={"X": x}, outputs={"Out": out})
return out
__all__ += ['erf']
_erf_ = generate_layer_fn('erf')
......