提交 e11c2c28 编写于 作者: Y Yuefeng Zhou 提交者: TensorFlower Gardener

Record allocated sizes for tensors instead of actual tensor sizes.

Substract back temp memory for reduction op because its temp memory becomes output memory.
Change: 150130275
上级 fafd5b24
......@@ -207,7 +207,7 @@ class PersistentTensor {
int64 NumElements() const { return tensor_.NumElements(); }
int64 TotalBytes() const { return tensor_.TotalBytes(); }
int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }
private:
Tensor tensor_;
......
......@@ -730,6 +730,18 @@ size_t Tensor::TotalBytes() const {
return 0; // Makes compiler happy.
}
size_t Tensor::AllocatedBytes() const {
TensorDescription tensor_description;
FillDescription(&tensor_description);
if (tensor_description.has_allocation_description() &&
tensor_description.allocation_description().allocated_bytes() > 0) {
return tensor_description.allocation_description().allocated_bytes();
} else {
// Fall back to TotalBytes() if the allocator doesn't have its size.
return TotalBytes();
}
}
bool Tensor::CanUseDMA() const {
CASES(dtype(), return is_simple_type<T>::value);
return false; // Makes compiler happy.
......
......@@ -144,6 +144,9 @@ class Tensor {
/// Returns the estimated memory usage of this tensor.
size_t TotalBytes() const;
// Returns the size of sallocated memory for this tensor.
size_t AllocatedBytes() const;
/// Returns true iff this tensor is aligned.
bool IsAligned() const {
#if EIGEN_MAX_ALIGN_BYTES == 0
......
......@@ -235,6 +235,14 @@ class ReductionOp : public OpKernel {
if (!out.CopyFrom(tmp_out, helper.out_shape())) {
ctx->SetStatus(errors::Internal("Error during reduction copy."));
}
if (ctx->track_allocations()) {
// The temporary memory becomes the output memory.
if (ctx->allocate_on_host(alloc_attr)) {
ctx->record_host_temp_memory_size(-out.AllocatedBytes());
} else {
ctx->record_device_temp_memory_size(-out.AllocatedBytes());
}
}
ctx->set_output(0, out);
}
......
......@@ -84,7 +84,7 @@ int64 SizeOf(const std::deque<PersistentTensor>& sq) {
if (sq.empty()) {
return 0;
}
return sq.size() * sq.front().TotalBytes();
return sq.size() * sq.front().AllocatedBytes();
}
template <>
......@@ -92,7 +92,7 @@ int64 SizeOf(const std::vector<PersistentTensor>& sq) {
if (sq.empty()) {
return 0;
}
return sq.size() * sq.front().TotalBytes();
return sq.size() * sq.front().AllocatedBytes();
}
using TensorPair = std::pair<int64, PersistentTensor>;
......@@ -102,7 +102,7 @@ int64 SizeOf(const std::priority_queue<TensorPair, U, V>& sq) {
if (sq.empty()) {
return 0;
}
return sq.size() * (sizeof(TensorPair) + sq.top().second.TotalBytes());
return sq.size() * (sizeof(TensorPair) + sq.top().second.AllocatedBytes());
}
} // namespace
......
......@@ -157,10 +157,11 @@ class DestroyTemporaryVariableOp : public OpKernel {
context->step_container()->name(), var_name_));
if (context->track_allocations()) {
if (context->allocate_on_host(AllocatorAttributes())) {
context->record_host_persistent_memory_allocation(-tmpvar.TotalBytes());
context->record_host_persistent_memory_allocation(
-tmpvar.AllocatedBytes());
} else {
context->record_device_persistent_memory_allocation(
-tmpvar.TotalBytes());
-tmpvar.AllocatedBytes());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册