Commit 39b0bee9
magicwindyyd / mindspore (forked from MindSpore / mindspore)
Authored on Aug 31, 2020 by mindspore-ci-bot; committed by Gitee on Aug 31, 2020.
!5314 Add implicit conversion description for some ops in API.
Merge pull request !5314 from liuxiao93/implicit-conversion-api
Parents: 8599a84a, cb67d7d5
Showing 2 changed files with 162 additions and 1 deletion:
mindspore/ops/operations/array_ops.py (+55, -0)
mindspore/ops/operations/nn_ops.py (+107, -1)
mindspore/ops/operations/array_ops.py
...
@@ -2528,6 +2528,11 @@ class ScatterUpdate(_ScatterOp):
     Using given values to update tensor value, along with the input indices.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: True.
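To make the added rule concrete, here is a minimal sketch (an illustration, not part of the commit) of the conversion for `ScatterUpdate`; it assumes PyNative mode so the primitive can be called directly, and illustrative shapes:

```python
import numpy as np
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

# `input_x` is float32 and `updates` is float16: float16 has the lower
# priority, so `updates` is implicitly converted to float32 before the update.
input_x = Parameter(Tensor(np.zeros((2, 3), np.float32)), name="input_x")
indices = Tensor(np.array([0, 1], np.int32))
updates = Tensor(np.ones((2, 3), np.float16))

scatter_update = P.ScatterUpdate()
output = scatter_update(input_x, indices, updates)
print(output.dtype)  # expected to follow `input_x`: Float32
```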
...
@@ -2569,6 +2574,11 @@ class ScatterNdUpdate(_ScatterNdOp):
     Using given values to update tensor value, along with the input indices.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: True.
...
@@ -2610,6 +2620,11 @@ class ScatterMax(_ScatterOp):
     Using given values to update tensor value through the max operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: True.
...
@@ -2645,6 +2660,11 @@ class ScatterMin(_ScatterOp):
     Using given values to update tensor value through the min operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2674,6 +2694,11 @@ class ScatterAdd(_ScatterOp):
     Using given values to update tensor value through the add operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2703,6 +2728,11 @@ class ScatterSub(_ScatterOp):
     Using given values to update tensor value through the sub operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2732,6 +2762,11 @@ class ScatterMul(_ScatterOp):
     Using given values to update tensor value through the mul operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2761,6 +2796,11 @@ class ScatterDiv(_ScatterOp):
     Using given values to update tensor value through the div operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2790,6 +2830,11 @@ class ScatterNdAdd(_ScatterNdOp):
     Using given values to update tensor value through the add operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2819,6 +2864,11 @@ class ScatterNdSub(_ScatterNdOp):
     Using given values to update tensor value through the sub operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to protect the assignment by a lock. Default: False.
...
@@ -2848,6 +2898,11 @@ class ScatterNonAliasingAdd(_ScatterNdOp):
     Using given values to update tensor value through the add operation, along with the input indices.
     This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
+    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **input_x** (Parameter) - The target parameter. The data type should be float16, float32 or int32.
         - **indices** (Tensor) - The index to do the add operation, whose data type should be mindspore.int32.
...
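The error branch of the rule can be sketched the same way (again illustrative, not part of the commit): when the Parameter itself holds the lower-priority type, making the types consistent would require rewriting the Parameter, which the added docstring says is rejected:

```python
import numpy as np
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

# `input_x` is a float16 Parameter while `updates` is float32, so the
# implicit conversion would have to change the Parameter's own data type.
input_x = Parameter(Tensor(np.zeros(8, np.float16)), name="input_x")
indices = Tensor(np.array([[2], [4], [1], [7]], np.int32))
updates = Tensor(np.ones(4, np.float32))

try:
    P.ScatterNonAliasingAdd()(input_x, indices, updates)
except RuntimeError as err:
    # Per the docstring, conversion of a Parameter raises RuntimeError.
    print("rejected:", err)
```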
mindspore/ops/operations/nn_ops.py
...
@@ -1694,6 +1694,12 @@ class ApplyMomentum(PrimitiveWithInfer):
     Refer to the paper `On the importance of initialization and momentum in deep
     learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
+    Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    Data type conversion of a Parameter is not supported; a RuntimeError exception will be thrown.
     Args:
         use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False.
         use_nesterov (bool): Enable Nesterov momentum. Default: False.
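A sketch of the three-input variant for `ApplyMomentum`, under the same assumptions as the earlier snippets (PyNative mode, illustrative shapes and hyper-parameters):

```python
import numpy as np
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

variable = Parameter(Tensor(np.ones(2, np.float32)), name="variable")
accumulation = Parameter(Tensor(np.zeros(2, np.float32)), name="accumulation")
gradient = Tensor(np.ones(2, np.float16))  # lower priority: cast to float32

apply_momentum = P.ApplyMomentum()
# Call order: variable, accumulation, learning_rate, gradient, momentum.
out = apply_momentum(variable, accumulation, 0.01, gradient, 0.9)
```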
...
@@ -3076,6 +3082,11 @@ class FusedSparseAdam(PrimitiveWithInfer):
     `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, :math:`\epsilon` represents
     `epsilon`.
+    All inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to enable a lock to protect updating variable tensors.
             If true, updates of the var, m, and v tensors will be protected by a lock.
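For the fused sparse optimizers the rule covers every input except `indices`. A hedged sketch wrapping `FusedSparseAdam` in a Cell, in the style of the era's examples (shapes and hyper-parameter values are illustrative assumptions):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P

class AdamNet(nn.Cell):
    def __init__(self):
        super(AdamNet, self).__init__()
        self.fused_sparse_adam = P.FusedSparseAdam()
        self.var = Parameter(Tensor(np.ones((3, 1, 2), np.float32)), name="var")
        self.m = Parameter(Tensor(np.ones((3, 1, 2), np.float32)), name="m")
        self.v = Parameter(Tensor(np.ones((3, 1, 2), np.float32)), name="v")

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
        # `indices` keeps its int32 type; the remaining inputs are unified
        # to the relatively highest-priority type among them.
        return self.fused_sparse_adam(self.var, self.m, self.v, beta1_power, beta2_power,
                                      lr, beta1, beta2, epsilon, grad, indices)

net = AdamNet()
grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
indices = Tensor(np.array([0, 1], np.int32))
out = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, grad, indices)
```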
...
@@ -3210,6 +3221,11 @@ class FusedSparseLazyAdam(PrimitiveWithInfer):
     `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, :math:`\epsilon` represents
     `epsilon`.
+    All inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): Whether to enable a lock to protect updating variable tensors.
             If true, updates of the var, m, and v tensors will be protected by a lock.
...
@@ -3325,6 +3341,11 @@ class FusedSparseFtrl(PrimitiveWithInfer):
     """
     Merge the duplicate value of the gradient and then update relevant entries according to the FTRL-proximal scheme.
+    All inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         lr (float): The learning rate value, must be positive.
         l1 (float): l1 regularization strength, must be greater than or equal to zero.
...
@@ -3423,6 +3444,11 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
     .. math::
         var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
+    All inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): If true, updates of the var and accum tensors will be protected. Default: False.
...
@@ -3669,6 +3695,12 @@ class ApplyAdaMax(PrimitiveWithInfer):
     :math:`beta_1^t` represents `beta1_power`, :math:`var` represents the variable to be updated,
     :math:`\epsilon` represents `epsilon`.
+    Inputs of `var`, `m`, `v` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
         - **m** (Parameter) - The 1st moment vector in the updating formula. Has the same shape and type as `var`.
...
@@ -3791,6 +3823,12 @@ class ApplyAdadelta(PrimitiveWithInfer):
     .. math::
         var -= lr * update
+    Inputs of `var`, `accum`, `accum_update` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **var** (Parameter) - Weights to be updated. With float32 or float16 data type.
         - **accum** (Parameter) - Accum to be updated, has the same shape and type as `var`.
...
@@ -3888,6 +3926,12 @@ class ApplyAdagrad(PrimitiveWithInfer):
     .. math::
         var -= lr * grad * \frac{1}{\sqrt{accum}}
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         update_slots (bool): If `True`, `accum` will be updated. Default: True.
...
@@ -3963,6 +4007,12 @@ class ApplyAdagradV2(PrimitiveWithInfer):
     .. math::
         var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         epsilon (float): A small value added for numerical stability.
         update_slots (bool): If `True`, `accum` will be updated. Default: True.
...
@@ -4040,6 +4090,12 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
     .. math::
         var -= lr * grad * (1 / sqrt(accum))
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         lr (float): Learning rate.
         update_slots (bool): If `True`, `accum` will be updated. Default: True.
...
@@ -4119,6 +4175,12 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
     .. math::
         var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         lr (float): Learning rate.
         epsilon (float): A small value added for numerical stability.
...
@@ -4202,6 +4264,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
     .. math::
         var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): If true, updates of the var and accum tensors will be protected. Default: False.
...
@@ -4298,6 +4366,12 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
     .. math::
         var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         use_locking (bool): If true, updates of the var and accum tensors will be protected. Default: False.
...
@@ -4311,7 +4385,6 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
         - **l2** (Union[Number, Tensor]) - l2 regularization strength, should be a float number or
           a scalar tensor with float16 or float32 data type.
         - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
-          The data type must be float16 or float32.
         - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
     Outputs:
...
@@ -4390,6 +4463,12 @@ class ApplyAddSign(PrimitiveWithInfer):
     :math:`t` represents the updating step, while :math:`m` represents the 1st moment vector and :math:`m_{t-1}`
     is the last moment of :math:`m_{t}`, :math:`lr` represents scaling factor `lr`, :math:`g` represents `grad`.
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
         - **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
...
@@ -4491,6 +4570,13 @@ class ApplyPowerSign(PrimitiveWithInfer):
     :math:`t` represents the updating step, while :math:`m` represents the 1st moment vector and :math:`m_{t-1}`
     is the last moment of :math:`m_{t}`, :math:`lr` represents scaling factor `lr`, :math:`g` represents `grad`.
+    All inputs comply with the implicit type conversion rules to make the data types consistent.
+    If `lr`, `logbase`, `sign_decay` or `beta` is a number, the number is automatically converted to a Tensor,
+    and the data type is consistent with the Tensor data type involved in the operation.
+    If the inputs are tensors and have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
           If data type of `var` is float16, all inputs must have the same data type as `var`.
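`ApplyPowerSign` adds the number-to-Tensor clause; a sketch under the same assumptions as the earlier snippets (PyNative mode, illustrative values; the positional argument order follows the docstring's Inputs list):

```python
import numpy as np
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

var = Parameter(Tensor(np.ones((2, 2), np.float32)), name="var")
m = Parameter(Tensor(np.zeros((2, 2), np.float32)), name="m")
grad = Tensor(np.ones((2, 2), np.float16))  # cast up to float32

# lr, logbase, sign_decay and beta are plain numbers: per the docstring they
# are converted to tensors whose dtype follows the tensors in the operation.
output = P.ApplyPowerSign()(var, m, 0.001, np.e, 0.99, 0.9, grad)
```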
...
@@ -4587,6 +4673,11 @@ class ApplyGradientDescent(PrimitiveWithInfer):
     .. math::
         var = var - \alpha * \delta
+    Inputs of `var` and `delta` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
         - **alpha** (Union[Number, Tensor]) - Scaling factor, should be a scalar. With float32 or float16 data type.
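The two-input form of the rule, sketched for `ApplyGradientDescent` (illustrative and not part of the commit, same assumptions as above):

```python
import numpy as np
import mindspore
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

var = Parameter(Tensor(np.ones((2, 2), np.float32)), name="var")
alpha = Tensor(0.001, mindspore.float32)   # scalar scaling factor
delta = Tensor(np.ones((2, 2), np.float16))  # lower priority: cast to float32

output = P.ApplyGradientDescent()(var, alpha, delta)
```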
...
@@ -4649,6 +4740,11 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer):
     .. math::
         var = \frac{sign(\text{prox_v})}{1 + \alpha * l2} * \max(\left| \text{prox_v} \right| - \alpha * l1, 0)
+    Inputs of `var` and `delta` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Inputs:
         - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
         - **alpha** (Union[Number, Tensor]) - Scaling factor, should be a scalar. With float32 or float16 data type.
...
@@ -4886,6 +4982,11 @@ class SparseApplyFtrl(PrimitiveWithInfer):
     """
     Update relevant entries according to the FTRL-proximal scheme.
+    All inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         lr (float): The learning rate value, must be positive.
         l1 (float): l1 regularization strength, must be greater than or equal to zero.
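A sketch of the `indices` exemption for `SparseApplyFtrl` (illustrative shapes and hyper-parameters, same assumptions as the earlier snippets):

```python
import numpy as np
from mindspore import Parameter, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

var = Parameter(Tensor(np.random.rand(3, 1).astype(np.float32)), name="var")
accum = Parameter(Tensor(np.random.rand(3, 1).astype(np.float32)), name="accum")
linear = Parameter(Tensor(np.random.rand(3, 1).astype(np.float32)), name="linear")
grad = Tensor(np.random.rand(2, 1).astype(np.float16))  # unified to float32
indices = Tensor(np.array([0, 2], np.int32))            # exempt from conversion

sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
out = sparse_apply_ftrl(var, accum, linear, grad, indices)
```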
...
@@ -4973,6 +5074,11 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
     """
     Update relevant entries according to the FTRL-proximal scheme.
+    All inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
+    If they have different data types, the lower-priority data type will be converted to the
+    relatively highest-priority data type.
+    A RuntimeError exception will be thrown when data type conversion of a Parameter is required.
     Args:
         lr (float): The learning rate value, must be positive.
         l1 (float): l1 regularization strength, must be greater than or equal to zero.
...