提交 300a21bf 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[AutoSharding] Make sure that the 1D device mesh in cluster environment...

[AutoSharding] Make sure that the 1D device mesh in cluster environment matches the assumptions made by `ReshardingCostMixedMeshShape` in auto_sharding_utils.

PiperOrigin-RevId: 549428578
上级 038d2555
......@@ -1142,18 +1142,15 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
resharding_costs, input_shardings}));
}
CHECK_EQ(strategies->leaf_vector.size(), 1);
{
// If there is only one option for resharding, and the cost
// computed for that option is kInfinityCost, set the cost to
// zero. This is okay because there is only one option anyway, and
// having the costs set to kInfinityCost is problematic for the
// solver.
for (auto& operand_resharding_costs :
strategies->leaf_vector[0].resharding_costs) {
if (operand_resharding_costs.size() == 1 &&
operand_resharding_costs[0] >= kInfinityCost) {
operand_resharding_costs[0] = 0;
}
// If there is only one option for resharding, and the cost computed for
// that option is kInfinityCost, set the cost to zero. This is okay
// because there is only one option anyway, and having the costs set to
// kInfinityCost is problematic for the solver.
for (auto& operand_resharding_costs :
strategies->leaf_vector[0].resharding_costs) {
if (operand_resharding_costs.size() == 1 &&
operand_resharding_costs[0] >= kInfinityCost) {
operand_resharding_costs[0] = 0;
}
}
} else if (!strategies->following) {
......
......@@ -18,7 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <memory>
#include <optional>
#include <ostream>
......@@ -59,9 +59,23 @@ class ClusterEnvironment {
non_zero_mesh_dims_ =
VectorGreaterThanOneElementIndices(device_mesh.dimensions());
GenerateCachedReplicaGroups();
// TODO(yuemmawang) Find the largest dimension in original_device_mesh and
// create 1d mesh on that dimension.
device_mesh_1d_.Reshape({original_device_mesh.num_elements(), 1});
// We want to create a 1D mesh here such that the resharding costs between
// the original mesh and this 1D mesh are the least. This essentially means
// we create a 1D shape which stretches along the largest dimension of the
// original mesh. However, this will not work for asymmetric values of alpha
// and beta, I think.
// TODO(pratikf) Fix this for asymmetric alpha and beta values.
auto original_device_mesh_shape = original_device_mesh.dimensions();
auto max_dim_iterator = std::max_element(original_device_mesh_shape.begin(),
original_device_mesh_shape.end());
size_t largest_dim_idx =
std::distance(original_device_mesh_shape.begin(), max_dim_iterator);
std::vector<int64_t> device_mesh_1d_shape(
original_device_mesh.num_dimensions(), 1);
device_mesh_1d_shape[largest_dim_idx] = original_device_mesh.num_elements();
device_mesh_1d_.Reshape(device_mesh_1d_shape);
}
// Accessor for the stored device count: returns the current value of
// total_devices_ without modifying the object.
size_t NumDevices() const {
  return total_devices_;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册