未验证 提交 0c1d68e1 编写于 作者: H HongyuJia 提交者: GitHub

fix custom operator backward=None (#48656)

上级 2cb07a1f
...@@ -217,18 +217,20 @@ RunCustomOpNode::operator()( ...@@ -217,18 +217,20 @@ RunCustomOpNode::operator()(
VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size(); VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size();
for (size_t i = 0; i < OutputMeta().size(); i++) { for (size_t i = 0; i < OutputMeta().size(); i++) {
if (map[0][0].find(i) != map[0][0].end()) { if (map[0][0].find(i) != map[0][0].end()) {
int grad_output_idx = map[0][0][i];
VLOG(7) << "Insert grad outputs: " << i VLOG(7) << "Insert grad outputs: " << i
<< " with size: " << OutputMeta()[i].size() << " with size: " << OutputMeta()[grad_output_idx].size()
<< " to tmp_outputs: " << map[0][0][i]; << " to tmp_outputs: " << grad_output_idx;
for (size_t j = 0; j < OutputMeta()[i].size(); j++) { for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */ outs[grad_output_idx]
std::make_shared<phi::DenseTensor>( .emplace_back(/* init it incase of copy nullptr of shared_ptr */
phi::DataType::UNDEFINED), std::make_shared<phi::DenseTensor>(
egr::Controller::Instance().GenerateUniqueName( phi::DataType::UNDEFINED),
"custom_tmp_grad")); egr::Controller::Instance().GenerateUniqueName(
egr::EagerUtils::autograd_meta(&(outs[i][j])); "custom_tmp_grad"));
egr::EagerUtils::autograd_meta(&(outs[grad_output_idx][j]));
} }
tmp_outs[map[0][0][i]] = outs[i]; tmp_outs[grad_output_idx] = outs[grad_output_idx];
} }
} }
for (size_t i = 0; i < tmp_outs.size(); i++) { for (size_t i = 0; i < tmp_outs.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册