提交 0e783982 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1945 [bug]fix bug in '=', use signature to support auto cast in assign.

Merge pull request !1945 from vlne-v1/I1JXUP-resnet50-thor-assign-fail
...@@ -27,7 +27,8 @@ namespace mindspore { ...@@ -27,7 +27,8 @@ namespace mindspore {
// namespace to support primitive operators // namespace to support primitive operators
namespace prim { namespace prim {
ValuePtr GetPythonOps(const std::string &op_name, ValuePtr GetPythonOps(const std::string &op_name,
const std::string &module_name = "mindspore._extends.parse.standard_method"); const std::string &module_name = "mindspore._extends.parse.standard_method",
bool use_signature = false);
// Arithmetic // Arithmetic
extern const PrimitivePtr kPrimScalarAdd; extern const PrimitivePtr kPrimScalarAdd;
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
namespace mindspore { namespace mindspore {
// namespace to support primitive operators // namespace to support primitive operators
namespace prim { namespace prim {
ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) { ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) {
py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
ValuePtr node = nullptr; ValuePtr node = nullptr;
bool succ = parse::ConvertData(obj, &node); bool succ = parse::ConvertData(obj, &node, use_signature);
if (!succ) { if (!succ) {
MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail"; MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
} }
......
...@@ -322,12 +322,10 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { ...@@ -322,12 +322,10 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
const std::string primitive_name("assign"); const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional"); const std::string module_name("mindspore.ops.functional");
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name)); ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
if (state_assign_.size() == 0 && auto_depends_.size() == 0) { if (state_assign_.size() == 0 && auto_depends_.size() == 0) {
return; return;
} }
...@@ -336,8 +334,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { ...@@ -336,8 +334,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
vec_states.emplace_back(make_tuple_op); vec_states.emplace_back(make_tuple_op);
for (auto &item : state_assign_) { for (auto &item : state_assign_) {
auto source = ReadVariable(item.second); auto source = ReadVariable(item.second);
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first}); auto assign = func_graph()->NewCNode({assign_op, item.first, source});
auto assign = func_graph()->NewCNode({assign_op, origin, source});
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
vec_states.emplace_back(assign); vec_states.emplace_back(assign);
} }
......
...@@ -47,7 +47,7 @@ class Net(nn.Cell): ...@@ -47,7 +47,7 @@ class Net(nn.Cell):
def test_assign_through_cell(): def test_assign_through_cell():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE)
net = Net() net = Net()
net.to_float(ms.float16) net.to_float(ms.float16)
net.add_flags_recursive(fp16=False) net.add_flags_recursive(fp16=False)
...@@ -57,6 +57,25 @@ def test_assign_through_cell(): ...@@ -57,6 +57,25 @@ def test_assign_through_cell():
net(None) net(None)
class AssignOp(nn.Cell):
def __init__(self):
super(AssignOp, self).__init__()
self.b = Parameter(initializer('ones', [5]), name='b')
def construct(self, w):
self.b = w
return w
def test_assign_by_operator():
context.set_context(mode=context.GRAPH_MODE)
net = AssignOp()
net.to_float(ms.float16)
input_data = Tensor(np.ones([5]).astype(np.float32))
net(input_data)
class NetScatterNdUpdate(nn.Cell): class NetScatterNdUpdate(nn.Cell):
def __init__(self): def __init__(self):
super(NetScatterNdUpdate, self).__init__() super(NetScatterNdUpdate, self).__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册