未验证 提交 42754088 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Fix custom op error (#43463)

* fix custom op error

* fix code error
上级 90cf2299
...@@ -263,9 +263,9 @@ RunCustomOpNode::operator()( ...@@ -263,9 +263,9 @@ RunCustomOpNode::operator()(
trace_backward, &(ins_auto_grad_metas[i])); trace_backward, &(ins_auto_grad_metas[i]));
} }
if (require_any_grad) {
auto meta_info_map = egr::Controller::Instance().GetOpMetaInfoMap(); auto meta_info_map = egr::Controller::Instance().GetOpMetaInfoMap();
const auto& vec_map = meta_info_map.at(op_type_); const auto& vec_map = meta_info_map.at(op_type_);
if (require_any_grad && (vec_map.size() > 2)) {
paddle::platform::RecordEvent node_creation_record_event( paddle::platform::RecordEvent node_creation_record_event(
"Custom Op " + op_type_ + " double_grad node_creation", "Custom Op " + op_type_ + " double_grad node_creation",
paddle::platform::TracerEventType::OperatorInner, 1); paddle::platform::TracerEventType::OperatorInner, 1);
......
...@@ -384,7 +384,7 @@ static PyObject* eager_api_run_costum_op(PyObject* self, PyObject* args, ...@@ -384,7 +384,7 @@ static PyObject* eager_api_run_costum_op(PyObject* self, PyObject* args,
require_any_grad || egr::EagerUtils::ComputeRequireGrad( require_any_grad || egr::EagerUtils::ComputeRequireGrad(
trace_backward, &(ins_auto_grad_metas[i])); trace_backward, &(ins_auto_grad_metas[i]));
} }
if (require_any_grad) { if (require_any_grad && (vec_map.size() > 1)) {
VLOG(6) << " Construct Grad for Custom Op: " << op_type; VLOG(6) << " Construct Grad for Custom Op: " << op_type;
ConstructFwdAndBwdMap(vec_map, op_type); ConstructFwdAndBwdMap(vec_map, op_type);
for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) { for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册