提交 9c805809 编写于 作者: J Justin Szaday 提交者: TensorFlower Gardener

Use inferred vs. actual shapes for multi-device output layouts.

PiperOrigin-RevId: 549417829
上级 52f0eeac
......@@ -1840,6 +1840,10 @@ void DTensorDevice::ExecuteMultiDeviceOperation(
for (int i = 0; i < num_output_layouts; i++) {
const Layout& output_layout = function.output_layouts[i];
const int num_devices = function.num_local_outputs[i];
ASSIGN_OR_RETURN_C_STATUS(
const std::vector<int64_t> local_output_shape,
GetTensorShapeAsVector(function.local_output_shapes[output_offset]),
status);
std::vector<TensorHandlePtr> layout_outputs;
for (int j = 0; j < num_devices; j++) {
const int output_idx = output_offset + j;
......@@ -1848,7 +1852,8 @@ void DTensorDevice::ExecuteMultiDeviceOperation(
output_offset += num_devices;
ASSIGN_OR_RETURN_C_STATUS(
auto local_output,
CreateTensorWithLayout(std::move(layout_outputs), output_layout),
CreateTensorWithLayout(std::move(layout_outputs), output_layout,
local_output_shape),
status);
outputs[i] = std::move(local_output);
}
......
......@@ -288,21 +288,26 @@ StatusOr<Layout> GetLayoutThroughIdentityOps(Node* op, int output_index) {
char TensorWithLayoutTf::ID = 0;
StatusOr<std::vector<int64_t>> GetTensorShapeAsVector(
const tensorflow::PartialTensorShape& shape) {
const int dims = shape.dims();
if (dims < 0) {
return absl::InvalidArgumentError("Unavailable tensor shape!");
}
std::vector<int64_t> result;
result.reserve(dims);
for (const TensorShapeDim& dim : shape) {
result.emplace_back(dim.size);
}
return result;
}
StatusOr<std::vector<int64_t>> GetTensorShapeAsVector(
TFE_TensorHandle* tensor) {
tensorflow::PartialTensorShape shape;
const Status status = tensorflow::unwrap(tensor)->Shape(&shape);
if (status.ok()) {
const int dims = shape.dims();
if (dims < 0) {
return absl::InvalidArgumentError("Unavailable tensor shape!");
}
std::vector<int64_t> result;
result.reserve(dims);
for (const TensorShapeDim& dim : shape) {
result.emplace_back(dim.size);
}
return result;
return GetTensorShapeAsVector(shape);
} else {
return status;
}
......
......@@ -431,6 +431,11 @@ std::unique_ptr<TensorWithLayoutTf> CreateDummyTensorWithLayout(
const std::vector<int64_t>& local_shape, TF_DataType dtype,
const Layout& layout);
// Creates a DTensor from one or more tensor handles and a compatible
// layout. Optionally accepts a `shape` argument that overrides the
// actual shape of the underlying tensors; this argument should be
// provided when there's a possibility of the inferred shape from
// differing from the actual shape (like when it is dynamic).
StatusOr<std::unique_ptr<TensorWithLayoutTf>> CreateTensorWithLayout(
std::vector<TensorHandlePtr>&& tensor, const Layout& layout,
std::optional<std::vector<int64_t>>&& shape = std::nullopt);
......@@ -568,6 +573,10 @@ class ExecutableManager : public tsl::core::WeakRefCounted {
} stats_;
};
// Returns the shape of a given tensor.
StatusOr<std::vector<int64_t>> GetTensorShapeAsVector(
const tensorflow::PartialTensorShape& shape);
// Returns the shape of a given tensor.
StatusOr<std::vector<int64_t>> GetTensorShapeAsVector(TFE_TensorHandle* tensor);
......
......@@ -170,7 +170,7 @@ bool IsDynamicSize(int64_t size) {
return mlir::ShapedType::isDynamic(size) || size == -1;
}
bool IsDynamicShape(const std::vector<int64_t>& shape) {
bool IsDynamicShape(absl::Span<const int64_t> shape) {
for (int64_t size : shape) {
if (IsDynamicSize(size)) return true;
}
......@@ -1129,7 +1129,11 @@ std::vector<int64_t> Layout::GlobalShapeFromLocalShape(
absl::Span<const int64_t> local_shape,
const std::vector<std::vector<int64_t>>* local_shapes) const {
if (IsSingleDevice() || IsFullyReplicated()) {
return std::vector<int64_t>(local_shape.begin(), local_shape.end());
if (IsDynamicShape(local_shape) && local_shapes) {
return local_shapes->at(0);
} else {
return std::vector<int64_t>(local_shape.begin(), local_shape.end());
}
}
std::vector<int64_t> stride_for_dim;
......
......@@ -50,7 +50,7 @@ bool IsDynamicSize(int64_t size);
// Returns true if `shape` is a dynamic shape based on either MLIR and TF
// standards.
bool IsDynamicShape(const std::vector<int64_t>& shape);
bool IsDynamicShape(absl::Span<const int64_t> shape);
// The location of a device in a mesh.
//
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册