提交 52f0eeac 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[AutoSharding] Ensure that strategies are generated for custom call ops with...

[AutoSharding] Ensure that strategies are generated for custom call ops with user shardings. Previously, no shardings strategies were being generated for such ops.

PiperOrigin-RevId: 549414995
上级 7d00bce4
......@@ -1963,6 +1963,46 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
break;
}
case HloOpcode::kCustomCall: {
auto generate_non_following_strategies = [&](bool only_replicated) {
if (ins->shape().IsTuple()) {
if (only_replicated) {
strategies = CreateTupleStrategyVector(instruction_id);
strategies->childs.reserve(ins->shape().tuple_shapes_size());
for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) {
std::unique_ptr<StrategyVector> child_strategies =
CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i),
cluster_env, strategy_map,
child_strategies, replicated_penalty);
strategies->childs.push_back(std::move(child_strategies));
}
} else {
strategies = CreateAllStrategiesVector(
ins, ins->shape(), instruction_id,
leaf_strategies, cluster_env, strategy_map,
solver_option, replicated_penalty, batch_dim_map,
call_graph, only_allow_divisible, true)
.value();
}
} else {
if (only_replicated) {
strategies = CreateLeafStrategyVector(
instruction_id, ins, strategy_map, leaf_strategies);
AddReplicatedStrategy(ins, ins->shape(), cluster_env,
strategy_map, strategies,
replicated_penalty);
} else {
strategies = CreateAllStrategiesVector(
ins, ins->shape(), instruction_id,
leaf_strategies, cluster_env, strategy_map,
solver_option, replicated_penalty, batch_dim_map,
call_graph, only_allow_divisible, true)
.value();
}
}
};
if (IsCustomCallMarker(ins)) {
const HloInstruction* operand = ins->operand(0);
const StrategyVector* src_strategies = strategy_map.at(operand).get();
......@@ -1972,12 +2012,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
/* have_memory_cost= */ true, leaf_strategies, cluster_env,
pretrimmed_strategy_map);
} else if (ins->has_sharding()) {
if (ins->shape().IsTuple()) {
strategies = CreateTupleStrategyVector(instruction_id);
} else {
strategies = CreateLeafStrategyVector(
instruction_id, ins, strategy_map, leaf_strategies);
}
generate_non_following_strategies(false);
} else if (OutputInputSameShapes(ins)) {
auto* partitioner =
GetCustomCallPartitioner(ins->custom_call_target());
......@@ -1994,24 +2029,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
}
} else {
// TODO (b/258723035) Handle CustomCall ops for GPUs in a better way.
if (ins->shape().IsTuple()) {
strategies = CreateTupleStrategyVector(instruction_id);
strategies->childs.reserve(ins->shape().tuple_shapes_size());
for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) {
std::unique_ptr<StrategyVector> child_strategies =
CreateLeafStrategyVector(instruction_id, ins, strategy_map,
leaf_strategies);
AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i),
cluster_env, strategy_map, child_strategies,
replicated_penalty);
strategies->childs.push_back(std::move(child_strategies));
}
} else {
strategies = CreateLeafStrategyVector(
instruction_id, ins, strategy_map, leaf_strategies);
AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map,
strategies, replicated_penalty);
}
generate_non_following_strategies(true);
}
break;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册