diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index cbfc1bfab0cc24045dc51626530079f9c142e708..8bd2bc590272256fed79f4ab38ad52b470e87012 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -19,8 +19,6 @@ permissions and limitations under the License. */ namespace paddle { namespace framework { -using VarIndexMap = std::unordered_map; - typedef std::vector Ints; enum class OpArgType { IN, OUT }; @@ -91,21 +89,27 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { } std::vector grad_inputs, grad_outputs; - std::unordered_map grad_idxs; + + using VarIndexMap = std::unordered_map; + VarIndexMap* grad_idxs = new VarIndexMap; int in_idx = 0; int out_idx = 0; - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, + TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, "input_format", "input_format", in_idx, false); // I - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, + TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, "output_format", "input_format", in_idx, false); // G - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, + TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, "output_format", "input_format", in_idx, true); // OG - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, + TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, "input_format", "output_format", out_idx, true); // IG OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); - // TODO(yi): Set data member of grad_op. + grad_op->type_ = grad_op_type; + grad_op->inputs_ = grad_inputs; + grad_op->outputs_ = grad_outputs; + grad_op->attrs_ = grad_attrs; + grad_op->in_out_idxs_.reset(grad_idxs); return grad_op; }