Unverified commit 9c98ee3e, authored by Leo Chen, committed by GitHub

fix proto consistency bug (#45017)

* fix proto bug

* add ut

* reset need_update for var_desc

* refine code

* fix var desc order issue
Parent commit: 81d6fa6c
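The consistency issue in question: mutating only a VarDesc (with no OpDesc change) previously left the block's need_update flag unset, so the cached proto could be re-serialized without the change. A minimal Python sketch of the user-visible behavior this commit targets, based on the new unit test further below (paddle.enable_static() and the single-argument program_guard are illustrative choices, not part of this diff):

    import paddle

    paddle.enable_static()

    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program):
        x = paddle.static.data(name='x', shape=[3, 2, 1])
        out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)

    before = main_program.desc.serialize_to_string()
    # Touch only a VarDesc; no OpDesc is modified.
    main_program.current_block().var('x').desc.set_stop_gradient(False)

    # With this fix the dirty flag propagates VarDesc -> BlockDesc -> ProgramDesc,
    # so Flush() rewrites the vars and the serialized proto reflects the change.
    assert main_program.desc.need_update()
    after = main_program.desc.serialize_to_string()
    assert before != after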
@@ -176,19 +176,48 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
 }
 
 void BlockDesc::Flush() {
+  auto need_update = NeedUpdate(true);
   for (auto &op_desc : ops_) {
     op_desc->Flush();
   }
-  if (need_update_) {
+  // no flush for var_desc? or is op_desc flush really needed?
+  VLOG(10) << "Flush " << NeedUpdate(true) << " " << need_update << std::endl;
+  if (need_update) {
     this->desc_->mutable_ops()->Clear();
     for (auto &op_desc : ops_) {
       this->desc_->mutable_ops()->Add()->CopyFrom(*op_desc->Proto());
+      // op_desc's need_update is set to false in op_desc->Flush();
     }
+
+    std::vector<std::string> var_names;
+    std::set<std::string> var_names_set;
+    // keep order
+    for (const auto &var : this->desc_->vars()) {
+      var_names.emplace_back(var.name());
+      var_names_set.insert(var.name());
+    }
+
     this->desc_->mutable_vars()->Clear();
+    for (const auto &name : var_names) {
+      if (vars_.count(name)) {
+        this->desc_->mutable_vars()->Add()->CopyFrom(*vars_[name]->Proto());
+        vars_[name]->SetNeedUpdate(false);
+      }
+    }
+
     for (auto &var_desc : vars_) {
-      this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
+      if (var_names_set.count(var_desc.first) != 1) {
+        this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
+        var_desc.second->SetNeedUpdate(false);
+      }
     }
+
+    // this->desc_->mutable_vars()->Clear();
+    // for (auto &var_desc : vars_) {
+    //   this->desc_->mutable_vars()->Add()->CopyFrom(*var_desc.second->Proto());
+    //   var_desc.second->SetNeedUpdate(false);
+    // }
+
     need_update_ = false;
   }
 }
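The var rewrite above first replays the names already present in the serialized proto (preserving their order) and then appends any VarDescs created since the last flush. A rough Python sketch of that merge, using illustrative names that are not Paddle API:

    def rebuild_vars(serialized_names, live_vars):
        """Rebuild the var list: keep the previously serialized order,
        then append vars added since the last Flush()."""
        known = set(serialized_names)
        # vars already present in the proto, in their original order
        ordered = [live_vars[n] for n in serialized_names if n in live_vars]
        # vars created since the last flush go to the end
        ordered += [v for n, v in live_vars.items() if n not in known]
        return ordered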
@@ -207,6 +236,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
   for (const proto::VarDesc &var_desc : desc_->vars()) {
     vars_[var_desc.name()].reset(new VarDesc(var_desc));
   }
+
   for (const proto::OpDesc &op_desc : desc_->ops()) {
     ops_.emplace_back(new OpDesc(op_desc, this));
   }
@@ -304,5 +334,24 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
   block->Flush();
 }
 
+bool BlockDesc::NeedUpdate(bool include_subs) {
+  bool need = need_update_;
+  if (include_subs) {
+    for (const auto &op : ops_) {
+      if (op->NeedUpdate()) {
+        need = true;
+        break;
+      }
+    }
+    for (const auto &pair : vars_) {
+      if (pair.second->NeedUpdate()) {
+        need = true;
+        break;
+      }
+    }
+  }
+  return need;
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -113,10 +113,13 @@ class BlockDesc {
   void MoveFrom(BlockDesc *block);
 
+  bool NeedUpdate(bool include_subs = true);
+
  private:
   ProgramDesc *prog_;       // not_own
   proto::BlockDesc *desc_;  // not_own
-  bool need_update_;
+  bool need_update_;  // block itself need_update, not aware of its ops_ and
+                      // vars_
 
   std::deque<std::unique_ptr<OpDesc>> ops_;
   std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
...
@@ -883,6 +883,8 @@ struct SetAttrDescVisitor {
 };
 
 void OpDesc::Flush() {
+  VLOG(4) << "Flush "
+          << " " << Type() << " " << need_update_;
   if (need_update_) {
     this->desc_.mutable_inputs()->Clear();
     for (auto &ipt : inputs_) {
...
@@ -174,6 +174,8 @@ class OpDesc {
   uint64_t OriginalId() const { return original_id_; }
   void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
 
+  bool NeedUpdate() const { return need_update_; }
+
  private:
   friend class ProgramDesc;
   // Find VarDesc from OpDesc located Block into global Block
@@ -198,7 +200,7 @@ class OpDesc {
     // Must start from one
     return ++uid;
   }
 
+  // it it really needed? or just mantain a ptr from block?
   proto::OpDesc desc_;
   BlockDesc *block_{nullptr};  // not_own
 
   // input arg name => input variable names
...
@@ -249,5 +249,16 @@ void ProgramDesc::SetFetchHolderName(const std::string &fetch_holder_name) {
   fetch_holder->SetPersistable(true);
 }
 
+bool ProgramDesc::NeedUpdate() const {
+  bool need = false;
+  for (auto &block : blocks_) {
+    if (block->NeedUpdate()) {
+      need = true;
+      break;
+    }
+  }
+  return need;
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -85,6 +85,8 @@ class ProgramDesc {
   // This function is used to change or unify the fetch_holder variables' name.
   void SetFetchHolderName(const std::string &fetch_holder_name);
 
+  bool NeedUpdate() const;
+
  private:
   void InitFromProto();
...
@@ -25,10 +25,12 @@ proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }
 void VarDesc::SetType(proto::VarType::Type type) {
   desc_.mutable_type()->set_type(type);
+  need_updated_ = true;
 }
 
 void VarDesc::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
+  need_updated_ = true;
 }
 
 void VarDesc::SetTensorDescNum(size_t num) {
@@ -48,6 +50,7 @@ void VarDesc::SetTensorDescNum(size_t num) {
                                        "supported by the %s type variable.",
                                        this->Name()));
   }
+  need_updated_ = true;
 }
 
 size_t VarDesc::GetTensorDescNum() const {
@@ -76,6 +79,7 @@ void VarDesc::SetShapes(
   for (size_t i = 0; i < multiple_dims.size(); ++i) {
     VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
   }
+  need_updated_ = true;
 }
 
 std::vector<int64_t> VarDesc::GetShape() const {
@@ -94,6 +98,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
 void VarDesc::SetDataType(proto::VarType::Type data_type) {
   mutable_tensor_desc()->set_data_type(data_type);
+  need_updated_ = true;
 }
 
 void VarDesc::SetDataTypes(
@@ -111,6 +116,7 @@ void VarDesc::SetDataTypes(
   for (size_t i = 0; i < multiple_data_type.size(); ++i) {
     tensor_descs[i]->set_data_type(multiple_data_type[i]);
   }
+  need_updated_ = true;
 }
 
 proto::VarType::Type VarDesc::GetDataType() const {
@@ -144,6 +150,7 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
         "Setting 'lod_level' is not supported by the %s type variable.",
         this->Name()));
   }
+  need_updated_ = true;
 }
 
 void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
@@ -168,6 +175,7 @@ void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
         "Setting 'lod_levels' is not supported by the %s type variable",
         this->Name()));
   }
+  need_updated_ = true;
 }
 
 int32_t VarDesc::GetLoDLevel() const {
@@ -273,6 +281,7 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
                                         "supported by the %s type variable.",
                                         this->Name()));
   }
+  need_updated_ = true;
 }
 
 std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
@@ -298,6 +307,7 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
         "Getting 'tensor_descs' is not supported by the %s type variable.",
         this->Name()));
   }
+  need_updated_ = true;
 }
 
 std::vector<std::string> VarDesc::AttrNames() const {
...
@@ -65,9 +65,12 @@ class VarDesc {
     desc_.set_name(name);
     // TODO(paddle-dev): Why default to lodtensor.
     desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    need_updated_ = true;
   }
 
-  explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
+  explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {
+    // need_updated_ = true;
+  }
 
   // Explicitly implement the copy constructor for auto parallel
   VarDesc(const VarDesc &other)
@@ -78,16 +81,23 @@ class VarDesc {
     desc_ = other.desc_;
     attrs_ = other.attrs_;
     original_id_ = other.original_id_;
+    need_updated_ = true;
     return *this;
   }
 
-  proto::VarDesc *Proto() { return &desc_; }
+  proto::VarDesc *Proto() {
+    return &desc_;
+    need_updated_ = true;
+  }
+
   const proto::VarDesc *Proto() const { return &desc_; }
 
   std::string Name() const { return desc_.name(); }
 
-  void SetName(std::string name) { desc_.set_name(name); }
+  void SetName(std::string name) {
+    desc_.set_name(name);
+    need_updated_ = true;
+  }
 
   void SetTensorDescNum(size_t num);
@@ -126,15 +136,22 @@ class VarDesc {
   bool Persistable() const { return desc_.persistable(); }
 
-  void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
+  void SetPersistable(bool persistable) {
+    desc_.set_persistable(persistable);
+    need_updated_ = true;
+  }
 
   bool IsParameter() const { return desc_.is_parameter(); }
 
   void SetIsParameter(bool is_parameter) {
     desc_.set_is_parameter(is_parameter);
+    need_updated_ = true;
   }
 
-  void ClearIsParameter() { desc_.clear_is_parameter(); }
+  void ClearIsParameter() {
+    desc_.clear_is_parameter();
+    need_updated_ = true;
+  }
 
   bool HasIsParameter() const { return desc_.has_is_parameter(); }
@@ -142,9 +159,13 @@ class VarDesc {
   void SetStopGradient(bool stop_gradient) {
     desc_.set_stop_gradient(stop_gradient);
+    need_updated_ = true;
   }
 
-  void ClearStopGradient() { desc_.clear_stop_gradient(); }
+  void ClearStopGradient() {
+    desc_.clear_stop_gradient();
+    need_updated_ = true;
+  }
 
   bool HasStopGradient() const { return desc_.has_stop_gradient(); }
@@ -152,6 +173,7 @@ class VarDesc {
   void SetNeedCheckFeed(bool need_check_feed) {
     desc_.set_need_check_feed(need_check_feed);
+    need_updated_ = true;
   }
 
   bool HasAttr(const std::string &name) const {
@@ -168,7 +190,13 @@ class VarDesc {
   // The Id() and OriginalId() are only used for auto parallel.
   uint64_t Id() const { return id_; }
   uint64_t OriginalId() const { return original_id_; }
-  void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
+  void SetOriginalId(uint64_t original_id) {
+    original_id_ = original_id;
+    need_updated_ = true;
+  }
+
+  bool NeedUpdate() const { return need_updated_; }
+  void SetNeedUpdate(bool need) { need_updated_ = need; }
 
  private:
   const proto::VarType::TensorDesc &tensor_desc() const;
@@ -183,9 +211,12 @@ class VarDesc {
     return ++uid;
   }
 
+  // it it really needed? or just mantain a ptr from block?
   proto::VarDesc desc_;
   AttributeMap attrs_;
 
+  bool need_updated_{false};
+
   // Note: the id_ is unique for all VarDesc (only for auto parallel).
   uint64_t id_ = GenerateId();
   // Note: the orignal_id_ is used for referring to the original VarDesc
...
@@ -84,6 +84,7 @@ void BindProgramDesc(pybind11::module *m) {
       .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
       .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
      .def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
+      .def("need_update", &pd::ProgramDesc::NeedUpdate)
       .def("parse_from_string",
            [](pd::ProgramDesc &program_desc, const std::string &data) {
              pd::proto::ProgramDesc *desc = program_desc.Proto();
...
@@ -202,5 +202,44 @@ class TestProgram(unittest.TestCase):
         self.assertFalse(var.has_stop_gradient())
 
 
+def build_program():
+    main_program = paddle.static.Program()
+    startuo_program = paddle.static.Program()
+    with paddle.utils.unique_name.guard():
+        with paddle.static.program_guard(main_program, startuo_program):
+            x = paddle.static.data(name='x', shape=[3, 2, 1])
+            out = paddle.static.nn.fc(x=x, size=1, num_flatten_dims=2)
+    return main_program
+
+
+class TestProgramProto(unittest.TestCase):
+
+    def test_update_op(self):
+        program = build_program()
+        a = program.desc.serialize_to_string()
+        program.current_block().ops[0]._set_attr('use_mkldnn', True)
+        self.assertTrue(program.desc.need_update())
+        b = program.desc.serialize_to_string()
+        self.assertFalse(a == b)
+
+    def test_update_var(self):
+        program = build_program()
+        a = program.desc.serialize_to_string()
+        program.current_block().var("x").desc.set_stop_gradient(False)
+        self.assertTrue(program.desc.need_update())
+        b = program.desc.serialize_to_string()
+        self.assertFalse(a == b)
+
+    # it seems the attrs of framework::VarDesc is not write to proto,
+    # except for persistable/need_check_feed/is_parameter/stop_gradient
+    def test_update_var_attr(self):
+        program = build_program()
+        a = program.desc.serialize_to_string()
+        program.current_block().var("x").desc._set_attr("a", 1)
+        self.assertFalse(program.desc.need_update())
+        b = program.desc.serialize_to_string()
+        self.assertTrue(a == b)  # not affected
+
+
 if __name__ == '__main__':
     unittest.main()