提交 b635af71 编写于 作者: F fengjiayi

Fix some compile error

上级 8a5ee462
...@@ -12,7 +12,8 @@ OperatorBase* GradOpCreator::Create() { ...@@ -12,7 +12,8 @@ OperatorBase* GradOpCreator::Create() {
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
const VarIndexMap& var_map, const VarIndexMap& var_map,
const vector<int>& format, InOutType type) { const std::vector<int>& format,
InOutType type) {
int idx = var_map.at(var.name()); int idx = var_map.at(var.name());
int begin_idx = format.empty() ? idx : format.at(idx); int begin_idx = format.empty() ? idx : format.at(idx);
int end_idx = format.empty() ? idx + 1 : format.at(idx + 1); int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
...@@ -23,11 +24,11 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, ...@@ -23,11 +24,11 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
void GradOpCreator::BuildOpInOutArgList() { void GradOpCreator::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type); const OpProto& op_proto = OpRegistry::protos().at(op_->type);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_)); const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_));
const vector<int>& in_format = const std::vector<int>& in_format =
op_->attrs_.count("input_format") op_->attrs_.count("input_format")
? op->GetAttr<std::vector<int>>("input_format") ? op->GetAttr<std::vector<int>>("input_format")
: std::vector<int>(); : std::vector<int>();
const vector<int>& out_format = const std::vector<int>& out_format =
op_->attrs_.count("output_format") op_->attrs_.count("output_format")
? op->GetAttr<std::vector<int>>("output_format") ? op->GetAttr<std::vector<int>>("output_format")
: std::vector<int>(); : std::vector<int>();
...@@ -41,10 +42,11 @@ void GradOpCreator::BuildOpInOutArgList() { ...@@ -41,10 +42,11 @@ void GradOpCreator::BuildOpInOutArgList() {
} }
} }
void GradOpCreator::PushArgIntoGradOp(const OpInOutArg* arg, void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
vector<std::string>& in_out, std::vector<std::string>& in_out,
vector<int>& format, VarIndexMap* varmap, std::vector<int>& format,
int& idx, bool is_grad) { VarIndexMap* varmap, int& idx,
bool is_grad) {
std::string var_name = arg->proto_name_; std::string var_name = arg->proto_name_;
if (is_grad) { if (is_grad) {
var_name += OperatorBase::GRAD_VAR_SUFFIX(); var_name += OperatorBase::GRAD_VAR_SUFFIX();
...@@ -70,22 +72,22 @@ void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const { ...@@ -70,22 +72,22 @@ void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
VarIndexMap* grad_varmap = new VarIndexMap(); VarIndexMap* grad_varmap = new VarIndexMap();
int in_idx = 0; int in_idx = 0;
int out_idx = 0; int out_idx = 0;
vector<int> in_format({0}); std::vector<int> in_format({0});
vector<int> out_format({0}); std::vector<int> out_format({0});
for (const auto& arg : arg_list_) { for (const auto& arg : arg_list_) {
// op_'s inputs_ and outputs_ // op_'s inputs_ and outputs_
if (arg->needed_in_grad_) { if (arg->needed_in_grad_) {
PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap, AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, false); in_idx, false);
} }
if (arg->type_ == IN) { if (arg->type_ == IN) {
// gradients of op_'s inputs_ // gradients of op_'s inputs_
PushArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap, AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
out_idx, true); out_idx, true);
} else { } else {
// gradients of op_'s outputs_ // gradients of op_'s outputs_
PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap, AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, true); in_idx, true);
} }
} }
grad_op->attrs_["input_format"] = in_format; grad_op->attrs_["input_format"] = in_format;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册