提交 f5b00de1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3042 fix sens tensor check issue

Merge pull request !3042 from wangqiuliang/fix-sens-tensor-check-issue
...@@ -351,13 +351,13 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ...@@ -351,13 +351,13 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
for (size_t i = 0; i < op_inputs.size(); i++) { for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i]; py::object input = op_inputs[i];
if (py::hasattr(input, "__parameter__")) { if (py::hasattr(input, "__parameter__")) {
result[i] = py::getattr(input, "data"); input = py::getattr(input, "data");
} else { }
auto tensor = py::cast<tensor::TensorPtr>(input); auto tensor = py::cast<tensor::TensorPtr>(input);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
new_tensor->set_device_address(tensor->device_address()); new_tensor->set_device_address(tensor->device_address());
new_tensor->set_dirty(tensor->is_dirty()); new_tensor->set_dirty(tensor->is_dirty());
result[i] = new_tensor; result[i] = new_tensor;
} }
} }
*status = PYNATIVE_SUCCESS; *status = PYNATIVE_SUCCESS;
......
...@@ -120,6 +120,9 @@ class GradOperation(GradOperation_): ...@@ -120,6 +120,9 @@ class GradOperation(GradOperation_):
""" Pynative forward run to build grad graph. """ """ Pynative forward run to build grad graph. """
if self.sens_param: if self.sens_param:
args = args[:-1] args = args[:-1]
for arg in args:
if not isinstance(arg, Tensor):
raise TypeError("grad inputs should be tensor in pynative mode")
if isinstance(fn, FunctionType): if isinstance(fn, FunctionType):
_pynative_exec.set_grad_flag(True) _pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(fn, *args) _pynative_exec.new_graph(fn, *args)
...@@ -150,9 +153,6 @@ class GradOperation(GradOperation_): ...@@ -150,9 +153,6 @@ class GradOperation(GradOperation_):
else: else:
@_wrap_func @_wrap_func
def after_grad(*args): def after_grad(*args):
for arg in args:
if not isinstance(arg, Tensor):
raise TypeError("grad inputs should be tensor in pynative mode")
self._pynative_forward_run(args, fn) self._pynative_forward_run(args, fn)
_pynative_exec.grad(grad_, fn, weights, *args) _pynative_exec.grad(grad_, fn, weights, *args)
out = _pynative_exec(*args) out = _pynative_exec(*args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册