提交 ec702337 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Modify Hash() function of HloComputation and HloInstruction to prevent...

Modify Hash() function of HloComputation and HloInstruction to prevent non-termination from infinite recursive calls.

PiperOrigin-RevId: 225412890
上级 d501a62a
...@@ -711,8 +711,6 @@ bool HloComputation::operator==(const HloComputation& other) const { ...@@ -711,8 +711,6 @@ bool HloComputation::operator==(const HloComputation& other) const {
return eq(root_instruction(), other.root_instruction()); return eq(root_instruction(), other.root_instruction());
} }
uint64 HloComputation::Hash() const { return root_instruction()->Hash(); }
Status HloComputation::ReplaceWithNewInstruction( Status HloComputation::ReplaceWithNewInstruction(
HloInstruction* old_instruction, HloInstruction* old_instruction,
std::unique_ptr<HloInstruction> new_instruction) { std::unique_ptr<HloInstruction> new_instruction) {
......
...@@ -264,12 +264,6 @@ class HloComputation { ...@@ -264,12 +264,6 @@ class HloComputation {
// Return whether `*this` and `other` are functionally equivalent. // Return whether `*this` and `other` are functionally equivalent.
bool operator==(const HloComputation& other) const; bool operator==(const HloComputation& other) const;
// Generates a hash value of an HLO computation. Hash considers
// information on opcode, shape, operands, and typically a root instruction.
// This function returns the same hash value for equivalent HLO computations,
// with respect to HloInstruction::Identical() method.
uint64 Hash() const;
// Replaces old instruction with newly created instruction. Removes old // Replaces old instruction with newly created instruction. Removes old
// instruction from computation. Updates uses and root instruction. // instruction from computation. Updates uses and root instruction.
Status ReplaceWithNewInstruction( Status ReplaceWithNewInstruction(
......
...@@ -1761,7 +1761,12 @@ bool HloInstruction::IdenticalSlowPath( ...@@ -1761,7 +1761,12 @@ bool HloInstruction::IdenticalSlowPath(
return false; return false;
} }
uint64 HloInstruction::Hash() const { static uint64 HashOperand(const HloInstruction* hlo) {
return ShapeUtil::Hash(hlo->shape());
}
uint64 HloInstruction::Hash(
const std::function<uint64(const HloInstruction*)>& hash_operand) const {
using tensorflow::Hash64Combine; using tensorflow::Hash64Combine;
uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode())); uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
...@@ -1770,7 +1775,7 @@ uint64 HloInstruction::Hash() const { ...@@ -1770,7 +1775,7 @@ uint64 HloInstruction::Hash() const {
if (!IsCrossModuleAllReduce()) { if (!IsCrossModuleAllReduce()) {
if (!operands().empty()) { if (!operands().empty()) {
for (size_t i = 0; i < operands().size(); ++i) { for (size_t i = 0; i < operands().size(); ++i) {
hash_value = Hash64Combine(hash_value, operand(i)->Hash()); hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
} }
} }
} }
...@@ -1779,6 +1784,11 @@ uint64 HloInstruction::Hash() const { ...@@ -1779,6 +1784,11 @@ uint64 HloInstruction::Hash() const {
return hash_value; return hash_value;
} }
uint64 HloInstruction::Hash() const {
// Use HashOperand as an argument to prevent non-termination.
return Hash(HashOperand);
}
uint64 HloInstruction::InnerHash() const { return 13; } uint64 HloInstruction::InnerHash() const { return 13; }
void HloInstruction::RemoveUser(HloInstruction* user) { void HloInstruction::RemoveUser(HloInstruction* user) {
......
...@@ -909,6 +909,14 @@ class HloInstruction { ...@@ -909,6 +909,14 @@ class HloInstruction {
// information on opcode, shape, operands, and typically a root instruction. // information on opcode, shape, operands, and typically a root instruction.
// This function returns the same hash value for equivalent HLO instructions, // This function returns the same hash value for equivalent HLO instructions,
// with respect to HloInstruction::Identical() method. // with respect to HloInstruction::Identical() method.
//
// Uses hash_operand function to compute hash values of its operands.
// At the very top level, hash_operand should be non-recursive to prevent
// non-termination.
uint64 Hash(
const std::function<uint64(const HloInstruction*)>& hash_operand) const;
// Calls the above method with non-recursive hash_operand function.
uint64 Hash() const; uint64 Hash() const;
// Returns whether the instruction has a constant operand. // Returns whether the instruction has a constant operand.
......
...@@ -1372,8 +1372,14 @@ bool HloFusionInstruction::IdenticalSlowPath( ...@@ -1372,8 +1372,14 @@ bool HloFusionInstruction::IdenticalSlowPath(
other.fused_instructions_computation()); other.fused_instructions_computation());
} }
static uint64 HashOperandRecursive(const HloInstruction* hlo) {
return hlo->Hash(HashOperandRecursive);
}
uint64 HloFusionInstruction::InnerHash() const { uint64 HloFusionInstruction::InnerHash() const {
return fused_instructions_computation()->Hash(); // Use HashOperandRecursive to recursively compute hash on inner operands.
return fused_instructions_computation()->root_instruction()->Hash(
HashOperandRecursive);
} }
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl( std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
......
...@@ -136,7 +136,9 @@ class HloModule { ...@@ -136,7 +136,9 @@ class HloModule {
// information on opcode, shape, operands, and typically a root instruction. // information on opcode, shape, operands, and typically a root instruction.
// This function returns the same hash value for equivalent HLO modules, // This function returns the same hash value for equivalent HLO modules,
// with respect to HloInstruction::Identical() method. // with respect to HloInstruction::Identical() method.
uint64 Hash() const { return entry_computation()->Hash(); } uint64 Hash() const {
return entry_computation()->root_instruction()->Hash();
}
// Gets the computations in this module. // Gets the computations in this module.
// //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册