Unverified commit 9c98ee3e authored by Leo Chen, committed by GitHub

fix proto consistency bug (#45017)

* fix proto bug

* add ut

* reset need_update for var_desc

* refine code

* fix var desc order issue
Parent 81d6fa6c
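The unit test added at the bottom of this diff exercises the new `ProgramDesc::NeedUpdate` path end to end. As a minimal standalone sketch (assuming static-graph mode and the Python binding `program.desc.need_update()` that this commit adds), the consistency guarantee looks like this: mutating a `VarDesc` field marks the program dirty, and the next serialization reflects the edit instead of returning stale proto bytes.

```python
import paddle

paddle.enable_static()

# Build a small static program, mirroring build_program() in the unit test.
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[3, 2, 1])
    out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)

before = main_program.desc.serialize_to_string()
# Mutate a VarDesc field that is written to the proto.
main_program.current_block().var("x").desc.set_stop_gradient(False)
# The dirty flag propagates from the VarDesc up through the BlockDesc ...
assert main_program.desc.need_update()
# ... and re-serialization picks up the change.
after = main_program.desc.serialize_to_string()
assert before != after
```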
......@@ -176,19 +176,48 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
}
void BlockDesc::Flush() {
auto need_update = NeedUpdate(true);
for (auto &op_desc : ops_) {
op_desc->Flush();
}
if (need_update_) {
// no flush for var_desc? or is op_desc flush really needed?
VLOG(10) << "Flush " << NeedUpdate(true) << " " << need_update << std::endl;
if (need_update) {
this->desc_->mutable_ops()->Clear();
for (auto &op_desc : ops_) {
this->desc_->mutable_ops()->Add()->CopyFrom(*op_desc->Proto());
// op_desc's need_update is set to false in op_desc->Flush();
}
std::vector<std::string> var_names;
std::set<std::string> var_names_set;
// keep order
for (const auto &var : this->desc_->vars()) {
var_names.emplace_back(var.name());
var_names_set.insert(var.name());
}
this->desc_->mutable_vars()->Clear();
for (const auto &name : var_names) {
if (vars_.count(name)) {
this->desc_->mutable_vars()->Add()->CopyFrom(*vars_[name]->Proto());
vars_[name]->SetNeedUpdate(false);
}
}
for (auto &var_desc : vars_) {
if (var_names_set.count(var_desc.first) != 1) {
this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
var_desc.second->SetNeedUpdate(false);
}
}
// this->desc_->mutable_vars()->Clear();
// for (auto &var_desc : vars_) {
// this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
// var_desc.second->SetNeedUpdate(false);
// }
need_update_ = false;
}
}
......@@ -207,6 +236,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
for (const proto::VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDesc(var_desc));
}
for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDesc(op_desc, this));
}
......@@ -304,5 +334,24 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
block->Flush();
}
bool BlockDesc::NeedUpdate(bool include_subs) {
bool need = need_update_;
if (include_subs) {
for (const auto &op : ops_) {
if (op->NeedUpdate()) {
need = true;
break;
}
}
for (const auto &pair : vars_) {
if (pair.second->NeedUpdate()) {
need = true;
break;
}
}
}
return need;
}
} // namespace framework
} // namespace paddle
......@@ -113,10 +113,13 @@ class BlockDesc {
void MoveFrom(BlockDesc *block);
bool NeedUpdate(bool include_subs = true);
private:
ProgramDesc *prog_; // not_own
proto::BlockDesc *desc_; // not_own
bool need_update_;
bool need_update_; // whether the block itself needs update; not
// aware of its ops_ and vars_
std::deque<std::unique_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
......
......@@ -883,6 +883,8 @@ struct SetAttrDescVisitor {
};
void OpDesc::Flush() {
VLOG(4) << "Flush "
<< " " << Type() << " " << need_update_;
if (need_update_) {
this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
......
......@@ -174,6 +174,8 @@ class OpDesc {
uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
bool NeedUpdate() const { return need_update_; }
private:
friend class ProgramDesc;
// Find VarDesc from OpDesc located Block into global Block
......@@ -198,7 +200,7 @@ class OpDesc {
// Must start from one
return ++uid;
}
// is it really needed? or just maintain a ptr from block?
proto::OpDesc desc_;
BlockDesc *block_{nullptr}; // not_own
// input arg name => input variable names
......
......@@ -249,5 +249,16 @@ void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
fetch_holder->SetPersistable(true);
}
bool ProgramDesc::NeedUpdate() const {
bool need = false;
for (auto &block : blocks_) {
if (block->NeedUpdate()) {
need = true;
break;
}
}
return need;
}
} // namespace framework
} // namespace paddle
......@@ -85,6 +85,8 @@ class ProgramDesc {
// This function is used to change or unify the fetch_holder variables' name.
void SetFetchHolderName(const std::string &fetch_holder_name);
bool NeedUpdate() const;
private:
void InitFromProto();
......
......@@ -25,10 +25,12 @@ proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
void VarDesc::SetType(proto::VarType::Type type) {
desc_.mutable_type()->set_type(type);
need_updated_ = true;
}
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
need_updated_ = true;
}
void VarDesc::SetTensorDescNum(size_t num) {
......@@ -48,6 +50,7 @@ void VarDesc::SetTensorDescNum(size_t num) {
"supported by the %s type variable.",
this->Name()));
}
need_updated_ = true;
}
size_t VarDesc::GetTensorDescNum() const {
......@@ -76,6 +79,7 @@ void VarDesc::SetShapes(
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
}
need_updated_ = true;
}
std::vector<int64_t> VarDesc::GetShape() const {
......@@ -94,6 +98,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
void VarDesc::SetDataType(proto::VarType::Type data_type) {
mutable_tensor_desc()->set_data_type(data_type);
need_updated_ = true;
}
void VarDesc::SetDataTypes(
......@@ -111,6 +116,7 @@ void VarDesc::SetDataTypes(
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
}
need_updated_ = true;
}
proto::VarType::Type VarDesc::GetDataType() const {
......@@ -144,6 +150,7 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
"Setting 'lod_level' is not supported by the %s type variable.",
this->Name()));
}
need_updated_ = true;
}
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
......@@ -168,6 +175,7 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
"Setting 'lod_levels' is not supported by the %s type variable",
this->Name()));
}
need_updated_ = true;
}
int32_t VarDesc::GetLoDLevel() const {
......@@ -273,6 +281,7 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
"supported by the %s type variable.",
this->Name()));
}
need_updated_ = true;
}
std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
......@@ -298,6 +307,7 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
"Getting 'tensor_descs' is not supported by the %s type variable.",
this->Name()));
}
need_updated_ = true;
}
std::vector<std::string> VarDesc::AttrNames() const {
......
......@@ -65,9 +65,12 @@ class VarDesc {
desc_.set_name(name);
// TODO(paddle-dev): Why default to lodtensor.
desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
need_updated_ = true;
}
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {
// need_updated_ = true;
}
// Explicitly implement the copy constructor for auto parallel
VarDesc(const VarDesc &other)
......@@ -78,16 +81,23 @@ class VarDesc {
desc_ = other.desc_;
attrs_ = other.attrs_;
original_id_ = other.original_id_;
need_updated_ = true;
return *this;
}
proto::VarDesc *Proto() { return &desc_; }
proto::VarDesc *Proto() {
need_updated_ = true;
return &desc_;
}
const proto::VarDesc *Proto() const { return &desc_; }
std::string Name() const { return desc_.name(); }
void SetName(std::string name) { desc_.set_name(name); }
void SetName(std::string name) {
desc_.set_name(name);
need_updated_ = true;
}
void SetTensorDescNum(size_t num);
......@@ -126,15 +136,22 @@ class VarDesc {
bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
void SetPersistable(bool persistable) {
desc_.set_persistable(persistable);
need_updated_ = true;
}
bool IsParameter() const { return desc_.is_parameter(); }
void SetIsParameter(bool is_parameter) {
desc_.set_is_parameter(is_parameter);
need_updated_ = true;
}
void ClearIsParameter() { desc_.clear_is_parameter(); }
void ClearIsParameter() {
desc_.clear_is_parameter();
need_updated_ = true;
}
bool HasIsParameter() const { return desc_.has_is_parameter(); }
......@@ -142,9 +159,13 @@ class VarDesc {
void SetStopGradient(bool stop_gradient) {
desc_.set_stop_gradient(stop_gradient);
need_updated_ = true;
}
void ClearStopGradient() { desc_.clear_stop_gradient(); }
void ClearStopGradient() {
desc_.clear_stop_gradient();
need_updated_ = true;
}
bool HasStopGradient() const { return desc_.has_stop_gradient(); }
......@@ -152,6 +173,7 @@ class VarDesc {
void SetNeedCheckFeed(bool need_check_feed) {
desc_.set_need_check_feed(need_check_feed);
need_updated_ = true;
}
bool HasAttr(const std::string &name) const {
......@@ -168,7 +190,13 @@ class VarDesc {
// The Id() and OriginalId() are only used for auto parallel.
uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
void SetOriginalId(uint64_t original_id) {
original_id_ = original_id;
need_updated_ = true;
}
bool NeedUpdate() const { return need_updated_; }
void SetNeedUpdate(bool need) { need_updated_ = need; }
private:
const proto::VarType::TensorDesc &tensor_desc() const;
......@@ -183,9 +211,12 @@ class VarDesc {
return ++uid;
}
// is it really needed? or just maintain a ptr from block?
proto::VarDesc desc_;
AttributeMap attrs_;
bool need_updated_{false};
// Note: the id_ is unique for all VarDesc (only for auto parallel).
uint64_t id_ = GenerateId();
// Note: the original_id_ is used for referring to the original VarDesc
......
......@@ -84,6 +84,7 @@ void BindProgramDesc(pybind11::module *m) {
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
.def("need_update", &pd::ProgramDesc::NeedUpdate)
.def("parse_from_string",
[](pd::ProgramDesc &program_desc, const std::string &data) {
pd::proto::ProgramDesc *desc = program_desc.Proto();
......
......@@ -202,5 +202,44 @@ class TestProgram(unittest.TestCase):
self.assertFalse(var.has_stop_gradient())
def build_program():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=[3, 2, 1])
out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
return main_program
class TestProgramProto(unittest.TestCase):
def test_update_op(self):
program = build_program()
a = program.desc.serialize_to_string()
program.current_block().ops[0]._set_attr('use_mkldnn', True)
self.assertTrue(program.desc.need_update())
b = program.desc.serialize_to_string()
self.assertFalse(a == b)
def test_update_var(self):
program = build_program()
a = program.desc.serialize_to_string()
program.current_block().var("x").desc.set_stop_gradient(False)
self.assertTrue(program.desc.need_update())
b = program.desc.serialize_to_string()
self.assertFalse(a == b)
# it seems the attrs of framework::VarDesc are not written to the proto,
# except for persistable/need_check_feed/is_parameter/stop_gradient
def test_update_var_attr(self):
program = build_program()
a = program.desc.serialize_to_string()
program.current_block().var("x").desc._set_attr("a", 1)
self.assertFalse(program.desc.need_update())
b = program.desc.serialize_to_string()
self.assertTrue(a == b) # not affected
if __name__ == '__main__':
unittest.main()