提交 f9fab14c 编写于 作者: Y Yu Yang

Fix compile error

上级 9475972b
......@@ -49,11 +49,9 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
return net_op;
}
static void DeDuplicate(NetOp* net, std::unordered_se)
static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
static std::shared_ptr<OperatorBase> BackwardImpl(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
no_grad_names)) {
return EmptyOp();
......@@ -73,13 +71,16 @@ static void DeDuplicate(NetOp* net, std::unordered_se)
if (forwardOp.IsNetOp()) {
//! TODO(dzh)
std::unordered_map<std::string, int> dup_output;
std::unordered_map<std::string std::vector<int>> dup_output_ops;
const unsigned uniq_id_local = uniq_id;
unsigned op_id_offset = 0;
for (auto& fwd : forwardOp) {
auto bwd = Backward(fwd, no_grad_names);
std::unordered_map<std::string, std::vector<int>> dup_output_ops;
// const unsigned uniq_id_local = uniq_id;
int op_id_offset = 0;
// Because it is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
for (auto& fwd : forwardNet.ops_) {
auto bwd = Backward(*fwd, no_grad_names);
net->AddOp(bwd);
for (size_t i = 0; i < bwd.outputs_; ++i) {
for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
dup_output[bwd->inputs_[i]] = 1;
......@@ -138,7 +139,7 @@ extern std::shared_ptr<OperatorBase> Backward(
for (auto& name : no_grad_vars) {
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
}
int uid = 0;
size_t uid = 0;
return BackwardImpl(forwardOp, no_grad_names, uid);
}
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册