提交 7cdd5581 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2914 make AbstractRef can join with AbstractTensor

Merge pull request !2914 from xychow/fix-abstractref-join
...@@ -838,7 +838,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const { ...@@ -838,7 +838,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>(); auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) { if (other_ref == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); auto new_ref = ref_->Join(other);
return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
} }
if (*this == *other) { if (*this == *other) {
return shared_from_base<AbstractBase>(); return shared_from_base<AbstractBase>();
......
...@@ -22,7 +22,7 @@ from mindspore.common.parameter import ParameterTuple ...@@ -22,7 +22,7 @@ from mindspore.common.parameter import ParameterTuple
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def test_net_vargs_expand(): def test_net_vargs_expand():
...@@ -184,6 +184,27 @@ def test_grad_var_args_with_sens(): ...@@ -184,6 +184,27 @@ def test_grad_var_args_with_sens():
_ = grad_net(x, y, sens) _ = grad_net(x, y, sens)
def test_grad_with_param_sens():
""""test grad_with_sens parameter"""
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.weights = ParameterTuple(net.trainable_params())
self.net = net
self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), name='sens', requires_grad=False)
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
def construct(self, x, y):
return self.grad(self.net, self.weights)(x, y, self.sens)
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = SecondNet()
grad_net = GradNet(net)
_ = grad_net(x, y)
def test_var_args_grad(): def test_var_args_grad():
class VarNet(Cell): class VarNet(Cell):
def __init__(self, net): def __init__(self, net):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册