提交 f9473e22 编写于 作者: C chengtbf 提交者: Jinhui Yuan

fix bug of op infer data id (#1054)

上级 098db839
......@@ -32,9 +32,8 @@ void AccuracyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> Get
BlobDesc* accuracy_blob_desc = GetBlobDesc4BnInOp("accuracy");
accuracy_blob_desc->mut_shape() = Shape({1});
accuracy_blob_desc->set_data_type(pred_blob_desc->data_type());
accuracy_blob_desc->set_has_data_id_field(pred_blob_desc->has_data_id_field());
}
REGISTER_OP(OperatorConf::kAccuracyConf, AccuracyOp);
} // namespace oneflow
\ No newline at end of file
} // namespace oneflow
......@@ -44,7 +44,6 @@ void LossOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlob
BlobDesc* reduction_blob_desc = GetBlobDesc4BnInOp("reduction_coefficient");
reduction_blob_desc->mut_shape() = Shape({1});
reduction_blob_desc->set_data_type(pred_blob_desc->data_type());
reduction_blob_desc->set_has_data_id_field(pred_blob_desc->has_data_id_field());
}
VirtualInferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx, buf_size);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册