提交 e203503b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4283 add InsertGradientOf operator support in pynative mode

Merge pull request !4283 from wangqiuliang/add-insert-gradient-of-operator-support
......@@ -58,7 +58,8 @@ using mindspore::tensor::TensorPy;
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
"mixed_precision_cast"};
namespace mindspore {
namespace pynative {
......@@ -346,7 +347,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
auto &op_inputs = op_exec_info->op_inputs;
if (op_exec_info->op_name == "HookBackward") {
if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") {
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];
......
......@@ -238,10 +238,6 @@ class InsertGradientOf(PrimitiveWithInfer):
def __init__(self, f):
self.f = f
def __call__(self, x):
"""run in PyNative mode."""
return x
def infer_shape(self, x_shape):
return x_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册