提交 5381a6ee 编写于 作者: Y Yi Wang

Update

上级 717fe549
...@@ -19,8 +19,6 @@ permissions and limitations under the License. */ ...@@ -19,8 +19,6 @@ permissions and limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using VarIndexMap = std::unordered_map<std::string, int>;
typedef std::vector<int> Ints; typedef std::vector<int> Ints;
enum class OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
...@@ -91,21 +89,27 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -91,21 +89,27 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
} }
std::vector<std::string> grad_inputs, grad_outputs; std::vector<std::string> grad_inputs, grad_outputs;
std::unordered_map<std::string, int> grad_idxs;
using VarIndexMap = std::unordered_map<std::string, int>;
VarIndexMap* grad_idxs = new VarIndexMap;
int in_idx = 0; int in_idx = 0;
int out_idx = 0; int out_idx = 0;
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"input_format", "input_format", in_idx, false); // I "input_format", "input_format", in_idx, false); // I
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, false); // G "output_format", "input_format", in_idx, false); // G
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, true); // OG "output_format", "input_format", in_idx, true); // OG
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, grad_idxs, TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"input_format", "output_format", out_idx, true); // IG "input_format", "output_format", out_idx, true); // IG
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
// TODO(yi): Set data member of grad_op. grad_op->type_ = grad_op_type;
grad_op->inputs_ = grad_inputs;
grad_op->outputs_ = grad_outputs;
grad_op->attrs_ = grad_attrs;
grad_op->in_out_idxs_.reset(grad_idxs);
return grad_op; return grad_op;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册