Commit 243cd332 authored by zhaozhenlong

adapt SparseApplyAdagradD adding output accum

Parent e70b162f
graphengine @ d345a800
-Subproject commit 43a715bc461fd70b7837051a2f47f0a1b19c5859
+Subproject commit d345a800a4f7c32eb768ea48667d1ce00b841748
@@ -1133,7 +1133,7 @@ INPUT_MAP(SparseApplyAdagradD) = {
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
+OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
// SparseApplyFtrlD
INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},
......
@@ -2433,7 +2433,10 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
Outputs:
-        Tensor, has the same shape and type as `var`.
+        Tuple of 2 Tensor, the updated parameters.
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
"""
@prim_attr_register
@@ -2448,13 +2451,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape
+        return var_shape, accum_shape
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
-        return var_type
+        return var_type, accum_type
class LARSUpdate(PrimitiveWithInfer):
......
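With both outputs mapped and the infer functions returning pairs, the primitive now yields a (var, accum) tuple instead of a single tensor. A minimal usage sketch of the updated contract follows; the class name AdagradNet, the learning rate, and the 3x3 shapes are illustrative assumptions, not part of this commit:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class AdagradNet(nn.Cell):
    """Wraps SparseApplyAdagrad; shapes and lr here are illustrative only."""
    def __init__(self):
        super(AdagradNet, self).__init__()
        self.sparse_apply_adagrad = P.SparseApplyAdagrad(lr=0.01)
        self.var = Parameter(Tensor(np.ones((3, 3), np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones((3, 3), np.float32)), name="accum")

    def construct(self, grad, indices):
        # After this commit the op returns both updated tensors,
        # not just `var`.
        var_out, accum_out = self.sparse_apply_adagrad(self.var, self.accum,
                                                       grad, indices)
        return var_out, accum_out

net = AdagradNet()
grad = Tensor(np.ones((3, 3), np.float32))
indices = Tensor(np.array([0, 1, 2], np.int32))  # rank 1, int32, len == grad.shape[0]
var_out, accum_out = net(grad, indices)

This two-output shape also explains the test change below: the backward "sens" descriptor must supply one entry per output.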
@@ -814,7 +814,7 @@ test_case_nn_ops = [
('SparseApplyAdagrad', {
'block': P.SparseApplyAdagrad(0.5),
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
-        'desc_bprop': [3, 3],
+        'desc_bprop': [[3, 3], [3, 3]],
'skip': ['backward']}),
('Flatten_1', {
'block': NetForFlatten(),
......