Commit 95358d2d authored by Alexandre Passos, committed by TensorFlower Gardener

Changing the copy-on-write semantics of resource variables.

A variable now has a bit which, when turned on, makes that variable act as
copy-on-read instead of copy-on-write. This
allows sparse writes to happen concurrently while only holding a shared
lock, mimicking the use_locking behavior of ref variables.

PiperOrigin-RevId: 224855851
Parent 4e7564ef
......@@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext(
}
}
void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device,
Tensor* output_tensor,
StatusCallback done) const {
done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
......
......@@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override;
xla::LocalClient* client() const { return client_; }
se::Stream* stream() const { return stream_.get(); }
......
......@@ -37,6 +37,14 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
}
void GPUDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device,
Tensor* output_tensor,
StatusCallback done) const {
GPUUtil::CopyGPUTensorToSameGPU(device, this, input_tensor, output_tensor,
done);
}
Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
std::function<void()> func) {
const DeviceBase::GpuDeviceInfo* gpu_info =
......
......@@ -57,6 +57,10 @@ class GPUDeviceContext : public DeviceContext {
Device* device, Tensor* cpu_tensor,
StatusCallback done) override;
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override;
void MaintainLifetimeOnStream(const Tensor* t,
se::Stream* stream) const override {}
......
......@@ -82,6 +82,13 @@ class DeviceContext : public core::RefCounted {
done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
}
// Copies a tensor in this device.
virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
Device* device, Tensor* output_tensor,
StatusCallback done) const {
done(errors::Unimplemented("Copy in same device not implemented."));
}
// "device_tensor" is a tensor on a non-CPU device. Copies
// device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
// to be of the same size as "device_tensor".
......
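The new DeviceContext::CopyTensorInSameDevice hook above is asynchronous: an implementation fills output_tensor and then invokes the StatusCallback with the result. As a rough illustration of that contract, a hypothetical device context whose memory is ordinary host memory could implement it as in the sketch below. HostDeviceContext is an invented name and tensor::DeepCopy (from tensor_util.h) is assumed to be an acceptable way to copy here; this is a sketch, not part of the commit.

// Hypothetical example only; not part of this commit.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor_util.h"

namespace tensorflow {

class HostDeviceContext : public DeviceContext {
 public:
  void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
                              Tensor* output_tensor,
                              StatusCallback done) const override {
    // Memory is plain host memory, so a synchronous deep copy suffices; the
    // callback reports completion (here, before this call returns).
    *output_tensor = tensor::DeepCopy(*input_tensor);
    done(Status::OK());
  }
};

}  // namespace tensorflow

Device-specific implementations, such as the GPU and XLA ones changed later in this diff, instead delegate to their own copy utilities or report Unimplemented.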
......@@ -278,6 +278,12 @@ class DummyDeviceContext : public DeviceContext {
~DummyDeviceContext() override {}
int stream_id() const { return stream_id_; }
void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
Tensor* output_tensor,
StatusCallback done) const override {
done(Status::OK());
}
private:
const int stream_id_;
};
......
......@@ -20,14 +20,46 @@ limitations under the License.
namespace tensorflow {
// Resource stored by variables in the resource manager
// (new, resource-style version).
// Resource stored by variables in the resource manager (new, resource-style
// version).
//
// These variables have a mixed access mode: they can operate on copy-on-write
// mode (the default) or copy-on-read mode (used only for sparse access).
//
// When copy-on-write mode is enabled, reading the value of the variable
// involves grabbing its mutex in shared mode and aliasing the internal tensor
// as the output of the read operation, increasing its reference count.
// Writing, conversely, works under an exclusive lock: the reference count is
// used to detect whether there are outstanding aliases of the tensor, the
// tensor is copied if any exist, and the write goes to whichever buffer (the
// original or the copy) has no outstanding aliases. Sparse operations are not
// supported in copy-on-write mode.
//
// When a variable is accessed sparsely it switches to copy-on-read mode. To
// switch we need to grab an exclusive lock and might (if there are aliases)
// need to copy the entire tensor. Once copy-on-read mode is enabled, no tensor
// is allowed to alias the variable's internal tensor. This means dense reads
// must return a copy of the variable, done while holding a shared lock. Dense
// writes do not need to check whether aliases exist, and can always write
// directly to the buffer without making a copy, while holding an exclusive
// lock. Sparse reads and sparse writes, on the other hand, can be done under
// either a shared or an exclusive mutex (the damage from writes under a shared
// mutex is limited, since no other buffer is allowed to alias the variable's
// buffer). Using an exclusive mutex disallows concurrent writes and concurrent
// sparse reads, providing some extra safety at the expense of performance,
// while a shared mutex allows for "hogwild" behavior. Even under a shared
// mutex, sparse writes are still prevented from overlapping with dense writes,
// which is necessary because dense writes can change the shape of the tensor.
//
// Transitioning a variable from copy-on-read mode to copy-on-write mode is
// currently not supported. To upgrade a variable from copy-on-write to
// copy-on-read use `EnsureSparseVariableAccess()`, and then grab the variable's
// mutex as desired. To access the variable in dense mode grab the mutex either
// directly or via `MaybeLockVariableInputMutexesInOrder` on all variables being
// modified and then call `PrepareToUpdateVariable` on them in any order.
class Var : public ResourceBase {
public:
explicit Var(DataType dtype) : tensor_(dtype) {}
// Not copyable or movable.
Var(const Var&) = delete;
Var& operator=(const Var&) = delete;
// When locking multiple variables, the locks must be acquired in order of
// increasing mu() address.
......@@ -48,11 +80,19 @@ class Var : public ResourceBase {
bool is_initialized = false; // GUARDED_BY(mu_) but annotalysis doesn't like
// it.
// Also fake-guarded by mu_. Should be set to True whenever any sparse
// operation uses the variable. Once this is true no tensor is allowed to
// alias the memory of the variable, and we always copy the variable on
// reads. This allows sparse operations to happen with only a shared lock if
// so desired.
std::atomic<bool> copy_on_read_mode{false};
private:
mutex mu_;
Tensor tensor_;
~Var() override {}
TF_DISALLOW_COPY_AND_ASSIGN(Var);
};
} // end namespace tensorflow
......
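The comment above describes the two access modes as a small protocol: copy-on-write aliases the buffer on dense reads and copies on dense writes when aliases exist, while copy-on-read copies on dense reads and keeps the buffer un-aliased so sparse writes only need a shared lock. The following is a self-contained toy model of those rules, using std::shared_mutex and std::shared_ptr use counts in place of TensorFlow's mutex and Tensor buffer reference counts; ToyVar and the helper functions are invented for illustration only.

// Toy model of the access modes described above; illustration only.
#include <atomic>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <vector>

struct ToyVar {
  std::shared_mutex mu;
  // Stand-in for the variable's internal tensor buffer; the shared_ptr use
  // count plays the role of the Tensor buffer refcount.
  std::shared_ptr<std::vector<float>> buf =
      std::make_shared<std::vector<float>>();
  std::atomic<bool> copy_on_read_mode{false};
};

// Dense read: in copy-on-write mode the caller may alias the buffer; in
// copy-on-read mode it must receive a private copy.
std::shared_ptr<std::vector<float>> DenseRead(ToyVar& v) {
  std::shared_lock<std::shared_mutex> l(v.mu);
  if (v.copy_on_read_mode.load()) {
    return std::make_shared<std::vector<float>>(*v.buf);  // copy
  }
  return v.buf;  // alias (bumps the use count)
}

// Dense write: in copy-on-write mode, copy first if aliases are outstanding;
// in copy-on-read mode, always write in place. Either way, exclusive lock.
void DenseWrite(ToyVar& v, const std::vector<float>& value) {
  std::unique_lock<std::shared_mutex> l(v.mu);
  if (!v.copy_on_read_mode.load() && v.buf.use_count() > 1) {
    v.buf = std::make_shared<std::vector<float>>(*v.buf);  // break aliases
  }
  *v.buf = value;
}

// Sparse access first switches the variable to copy-on-read mode (copying if
// aliases exist), after which element writes only need a shared lock.
void EnsureSparseAccess(ToyVar& v) {
  if (v.copy_on_read_mode.load()) return;
  std::unique_lock<std::shared_mutex> l(v.mu);
  if (v.buf.use_count() > 1) {
    v.buf = std::make_shared<std::vector<float>>(*v.buf);
  }
  v.copy_on_read_mode.store(true);
}

void SparseWrite(ToyVar& v, int index, float value) {
  EnsureSparseAccess(v);
  std::shared_lock<std::shared_mutex> l(v.mu);  // "hogwild"-style element write
  (*v.buf)[index] = value;
}

The real transitions are implemented by EnsureSparseVariableAccess() and PrepareToUpdateVariable() in training_op_helpers.h, shown later in this diff.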
......@@ -45,6 +45,7 @@ class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
class Var;
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
......@@ -581,11 +582,16 @@ class Tensor {
friend class XlaTensor; // For access to RefCountIsOne().
friend class XlaTensorBuffer; // For access to the private constructor taking
// the buffer
friend class Var;
template <typename Device, typename T>
friend class AssignVariableOp; // For access to RefCountIsOne().
template <typename Device, typename T>
friend Status PrepareToUpdateVariable(
OpKernelContext* ctx, Tensor* tensor); // For access to RefCountIsOne().
OpKernelContext* ctx, Tensor* tensor,
bool copy_on_read_mode); // For access to RefCountIsOne().
template <typename Device, typename T>
friend Status EnsureSparseVariableAccess(
OpKernelContext* ctx, Var* var); // For access to RefCountIsOne().
friend Status batch_util::CopyElementToSlice(
Tensor element, Tensor* parent,
int64 index); // For access to RefCountIsOne().
......
......@@ -2196,6 +2196,7 @@ tf_kernel_library(
":state",
":training_op_helpers",
":variable_ops",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
......
......@@ -55,6 +55,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
......@@ -84,6 +85,47 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}
namespace {
Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
Tensor* output;
Notification n;
Status status;
AllocatorAttributes attr;
if (t->dtype() == DT_VARIANT) {
attr.set_on_host(true);
}
TF_RETURN_IF_ERROR(
ctx->allocate_output(output_idx, t->shape(), &output, attr));
if (t->dtype() == DT_VARIANT) {
output->flat<Variant>() = t->flat<Variant>();
} else if (ctx->op_device_context() != nullptr) {
// TODO(apassos): remove the down_cast by just returning Device* from
// OpKernelContext
Device* device = static_cast<Device*>(ctx->device());
ctx->op_device_context()->CopyTensorInSameDevice(
t, device, output, [&n, &status](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
} else {
switch (t->dtype()) {
#define HANDLER(type) \
case DataTypeToEnum<type>::value: \
output->flat<type>() = t->flat<type>(); \
break;
TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
default:
return errors::Internal("Unsupported dtype", t->dtype());
}
}
return Status::OK();
}
} // namespace
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
const ResourceHandle& handle = HandleFromInput(ctx, 0);
......@@ -100,12 +142,16 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
// holding a shared lock to guarantee ordering of reads and
// writes.
tf_shared_lock ml(*variable->mu());
const Tensor& t = *variable->tensor();
OP_REQUIRES(ctx, dtype_ == t.dtype(),
const Tensor* t = variable->tensor();
OP_REQUIRES(ctx, dtype_ == t->dtype(),
errors::InvalidArgument(
"Trying to read variable with wrong dtype. Expected ",
DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
ctx->set_output(0, t);
DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
if (variable->copy_on_read_mode.load()) {
OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
} else {
ctx->set_output(0, *t);
}
}
ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
......@@ -146,14 +192,18 @@ void ReadVariablesOp::Compute(OpKernelContext* ctx) {
// holding a shared lock to guarantee ordering of reads and
// writes.
tf_shared_lock ml(*variables[i]->mu());
const Tensor& t = *variables[i]->tensor();
OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
errors::InvalidArgument(
"Trying to read variable ", handles[i]->name(),
" from Container: ", handles[i]->container(),
" with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
" got ", DataTypeString(t.dtype())));
ctx->set_output(i, t);
" got ", DataTypeString(variables[i]->tensor()->dtype())));
if (variables[i]->copy_on_read_mode.load()) {
OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
} else {
const Tensor& t = *variables[i]->tensor();
ctx->set_output(i, t);
}
}
}
......@@ -308,8 +358,23 @@ class AssignVariableOp : public OpKernel {
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
if (variable->copy_on_read_mode.load()) {
PersistentTensor unused;
Tensor* tmp;
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
OP_REQUIRES_OK(context,
context->allocate_persistent(value.dtype(), value.shape(),
&unused, &tmp, attr));
functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
value.flat<T>());
*variable->tensor() = *tmp;
} else {
*variable->tensor() = value;
}
variable->is_initialized = true;
*variable->tensor() = value;
}
private:
......@@ -442,8 +507,9 @@ class AssignUpdateVariableOp : public OpKernel {
" using a Tensor with shape ",
value.shape().DebugString(),
", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, var_tensor));
OP_REQUIRES_OK(
context, PrepareToUpdateVariable<Device, T>(
context, var_tensor, variable->copy_on_read_mode.load()));
functor::DenseUpdate<Device, T, Op> update_functor;
update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
value.flat<T>());
......@@ -524,6 +590,7 @@ class ResourceGatherOp : public OpKernel {
Var* v = nullptr;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref su(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
// NOTE: We hold the lock for the whole gather operation instead
// of increasing the reference count of v->tensor() to avoid a
// situation where a write to the same variable will see a
......@@ -639,9 +706,9 @@ class ResourceScatterUpdateOp : public OpKernel {
Var* v = nullptr;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref unref_v(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
tf_shared_lock ml(*v->mu());
Tensor* params = v->tensor();
OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, params));
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
......
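The ResourceScatterUpdateOp change above is the core usage pattern for the new mode: switch the variable to copy-on-read with EnsureSparseVariableAccess, then perform the element-wise update while holding only a shared lock. Below is a sketch of a hypothetical kernel following the same pattern; MyScatterAddOp is invented for illustration and the scatter arithmetic itself is elided.

// Hypothetical kernel, for illustration only; mirrors the locking pattern
// introduced above for ResourceScatterUpdateOp.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/training_op_helpers.h"

namespace tensorflow {

template <typename Device, typename T>
class MyScatterAddOp : public OpKernel {
 public:
  explicit MyScatterAddOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    // Switch to copy-on-read mode, copying the buffer once if any reads still
    // alias it.  Afterwards nothing aliases the variable's buffer, so
    // element-wise writes are safe under a shared lock and may overlap with
    // other sparse reads and writes.
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
    tf_shared_lock ml(*v->mu());
    Tensor* params = v->tensor();
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    // ... apply the scatter-add to params using indices and updates ...
  }
};

}  // namespace tensorflow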
......@@ -231,6 +231,7 @@ class ScatterNdUpdateOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
core::ScopedUnref scoped_unref(v);
OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
......@@ -258,7 +259,6 @@ class ScatterNdUpdateOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
Tensor* t = v->tensor();
OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
params = *t;
params_shape = params.shape();
} else if (IsRefType(c->input_dtype(0))) {
......
......@@ -307,9 +307,9 @@ class StridedSliceAssignOp : public OpKernel {
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
EnsureSparseVariableAccess<Device, T>(context, v));
mutex_lock ml(*v->mu());
old_lhs = v->tensor();
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
......
......@@ -19,70 +19,6 @@ limitations under the License.
namespace tensorflow {
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
Var** maybe_resource) {
*maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return ctx->input_ref_mutex(input);
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. Safe to pass duplicates - will
// only lock each distinct mutex once. If do_lock is false, returns
// immediately. Note that this silently doesn't lock mutexes for invalid
// variable references; in all usages this is followed by GetInputTensor which
// will signal a failure.
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
if (ctx->input_dtype(i) == DT_RESOURCE) {
any_resource = true;
break;
}
}
if (!do_lock && !any_resource) {
return VariableInputLockHolder({}, {});
}
std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
Var* var;
mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
std::unique_ptr<std::vector<mutex_lock>> locks =
MakeUnique<std::vector<mutex_lock>>();
locks->reserve(acquire_order.size());
for (auto input : acquire_order) {
Var* var;
mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
locks->emplace_back(*mu);
}
}
return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output) {
......
......@@ -17,30 +17,72 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and the
// caller will be responsible for calling `Unref()` on that resource.
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
Var** maybe_resource);
// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
if (var->copy_on_read_mode.load()) {
return Status::OK();
}
mutex_lock ml(*var->mu());
// If the refcount is 1 there are no concurrent reads aliasing the buffer, so
// we can turn copy-on-read mode on without copying. (Once copy-on-read mode
// is true, the refcount is guaranteed to stay at 1.)
if (var->tensor()->RefCountIsOne()) {
var->copy_on_read_mode.store(true);
return Status::OK();
}
PersistentTensor unused;
Tensor* tmp;
if (std::is_same<T, Variant>::value) {
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
const auto elements_in = var->tensor()->flat<Variant>();
auto elements_out = tmp->flat<Variant>();
for (int64 i = 0; i < elements_in.size(); ++i) {
elements_out(i) = elements_in(i);
}
} else {
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
attr.set_nic_compatible(true);
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr));
functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
const_cast<const Tensor*>(var->tensor())->flat<T>());
}
*var->tensor() = *tmp;
var->copy_on_read_mode.store(true);
return Status::OK();
}
// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
public:
VariableInputLockHolder(std::vector<Var*> vars,
std::unique_ptr<std::vector<mutex_lock>> locks)
: vars_(std::move(vars)), locks_(std::move(locks)) {}
VariableInputLockHolder(
std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
: vars_(std::move(vars)),
locks_(std::move(locks)),
shared_locks_(std::move(shared_locks)) {}
VariableInputLockHolder(VariableInputLockHolder&& other)
: vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
: vars_(std::move(other.vars_)),
locks_(std::move(other.locks_)),
shared_locks_(std::move(other.shared_locks_)) {}
~VariableInputLockHolder() {
// Release the locks before unreffing the Vars, because each lock
......@@ -56,10 +98,95 @@ struct VariableInputLockHolder {
// NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
// because a `std::vector<mutex_lock>` is not movable on all platforms.
std::unique_ptr<std::vector<mutex_lock>> locks_;
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};
// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and the
// caller will be responsible for calling `Unref()` on that resource.
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
Var** maybe_resource) {
*maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
if (sparse) {
EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource);
}
return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
return nullptr;
}
}
return ctx->input_ref_mutex(input);
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. Safe to pass duplicates - will
// only lock each distinct mutex once. If sparse is true, it will ensure the
// variable gets switched to copy-on-read mode before trying to acquire the
// locks. If do_lock is false, it returns immediately for reference variables.
// For resource variables in copy-on-read mode it will grab a shared lock if
// do_lock is false, and an exclusive lock otherwise. Note that this silently
// doesn't lock
// mutexes for invalid variable references; in all usages this is followed by
// GetInputTensor which will signal a failure.
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
OpKernelContext* ctx, bool do_lock, bool sparse,
const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
if (ctx->input_dtype(i) == DT_RESOURCE) {
any_resource = true;
break;
}
}
if (!do_lock && !any_resource) {
return VariableInputLockHolder({}, {}, {});
}
std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
Var* var;
mutex* mutex =
GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
mutexes.push_back(mutex);
}
}
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
std::unique_ptr<std::vector<mutex_lock>> locks =
absl::make_unique<std::vector<mutex_lock>>();
std::unique_ptr<std::vector<tf_shared_lock>> shared_locks =
absl::make_unique<std::vector<tf_shared_lock>>();
locks->reserve(acquire_order.size());
for (auto input : acquire_order) {
Var* var;
mutex* mu = GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
if (do_lock) {
locks->emplace_back(*mu);
} else {
shared_locks->emplace_back(*mu);
}
}
}
return VariableInputLockHolder(std::move(vars), std::move(locks),
std::move(shared_locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
int output);
......@@ -68,8 +195,9 @@ void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
if (!tensor->RefCountIsOne()) {
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
bool copy_on_read_mode) {
if (copy_on_read_mode || !tensor->RefCountIsOne()) {
// Tensor's buffer is in use by some read, so we need to copy before
// updating.
PersistentTensor unused;
......@@ -100,12 +228,14 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
return Status::OK();
}
// This gives you `*out`, a tensor you can update, corresponding to a
// variable passed as input index `input`. This handles the
// differences between reference and resource variables. For resource
// variables, we ensure `*out` has a reference count of 1 (using
// PrepareToUpdateVariable() to copy if necessary) unless
// sparse && !lock_held, in which case it never copies.
// This gives you `*out`, a tensor you can update, corresponding to a variable
// passed as input index `input`. This handles the differences between
// reference and resource variables. For reference variables we can just grab
// the tensor, grabbing the lock if lock_held is False.
//
// For resource variables we, if sparse is true, ensure it's in copy-on-read
// mode, and then, regardless of the value of sparse, ensure its refcount is 1
// (by potentially copying its contents). In this case lock_held is ignored.
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
bool lock_held, bool sparse, Tensor* out) {
......@@ -113,7 +243,13 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
Var* var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
core::ScopedUnref unref_var(var);
TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
if (sparse) {
TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var));
*out = *var->tensor();
return Status::OK();
}
TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
ctx, var->tensor(), var->copy_on_read_mode.load()));
*out = *var->tensor();
return Status::OK();
}
......
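For the dense path, the helpers above are typically used together: acquire the variable mutexes in address order, then obtain a writable tensor whose buffer is guaranteed not to be aliased. A hypothetical dense training kernel using the updated (templated, sparse-aware) signatures might look like the sketch below; MyApplyDeltaOp and its attribute handling are invented for illustration.

// Hypothetical dense training kernel, for illustration only.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/training_op_helpers.h"

namespace tensorflow {

template <typename Device, typename T>
class MyApplyDeltaOp : public OpKernel {
 public:
  explicit MyApplyDeltaOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    // Dense update: sparse=false, so the variable is not switched to
    // copy-on-read mode and, when requested, an exclusive lock is taken.
    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
        ctx, use_exclusive_lock_, /*sparse=*/false, {0});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, /*sparse=*/false,
                            &var));
    const Tensor& delta = ctx->input(1);
    // ... apply the dense update to `var` using `delta` ...
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

}  // namespace tensorflow

A sparse kernel would instead pass sparse=true in both calls, which routes through EnsureSparseVariableAccess as shown above and takes shared locks when use_locking is false.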
......@@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
......@@ -953,6 +954,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_sub(v, [1], [3])
self.assertAllEqual([1.0, -1.0], v.numpy())
def testScatterUpdateVariant(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([
list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=[])
])
v.scatter_update(
ops.IndexedSlices(
list_ops.tensor_list_from_tensor([1., 2.], element_shape=[]), 0))
self.assertAllEqual(
list_ops.tensor_list_get_item(v[0], 0, element_dtype=dtypes.float32),
1.)
def testScatterNdAddStateOps(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable(
......