Commit 1f8936bb authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Basic slow versions of resource-based variable ops which have the right semantics.

Change: 137863207
Parent b5c790e5
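For orientation, here is a minimal usage sketch of the ops this commit adds, written against the generated Python wrappers that the test file below exercises (var_handle_op, create_variable_op, read_variable_op, assign_add_variable_op). The explicit Session setup is illustrative and not part of the change.

# Illustrative sketch, not part of the commit: create a resource variable,
# read it, and increment it with the new ops.
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import resource_variable_ops

with session.Session() as sess:
  # A handle is a resource pointing at the variable's storage.
  handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
  # CreateVariableOp allocates the storage and copies the initial value in.
  sess.run(resource_variable_ops.create_variable_op(
      handle, constant_op.constant(3, dtype=dtypes.int32)))
  # ReadVariableOp returns an immutable snapshot of the current value.
  print(sess.run(resource_variable_ops.read_variable_op(
      handle, dtype=dtypes.int32)))  # 3
  # AssignAddVariableOp adds in place and outputs the incremented value.
  print(sess.run(resource_variable_ops.assign_add_variable_op(
      handle, constant_op.constant(2, dtype=dtypes.int32))))  # 5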
@@ -435,6 +435,7 @@ class Tensor {
  friend class VariableOp;            // For access to set_shape
  friend class AutoReloadVariableOp;  // For access to set_shape
  friend class TensorTestHelper;      // For access to set_shape
  template <typename Device, typename T>
  friend class CreateVariableOp;
  // Creates a tensor with the input datatype, shape and buf.
...
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
...@@ -25,25 +28,160 @@ namespace tensorflow { ...@@ -25,25 +28,160 @@ namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var); REGISTER_RESOURCE_HANDLE_KERNEL(Var);
template <typename Device, typename T>
class CreateVariableOp : public OpKernel {
 public:
  CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, DataTypeToEnum<T>::value == dtype_,
                errors::InvalidArgument(
                    "Dtypes don't match; expected ", DataTypeString(dtype_),
                    " got ", DataTypeString(DataTypeToEnum<T>::value)));
  }

  void Compute(OpKernelContext* context) override {
    Var* var = new Var(dtype_);
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    PersistentTensor copy;
    Tensor value = context->input(1);
    // TODO(apassos): allocating and copying is unnecessary if we are the last
    // user of the value tensor. This should essentially always be the case, yet
    // the refcount is usually 2 instead of 1. Figure out what needs to change
    // in the code to make this not be the case, so we can safely take
    // ownership.
    Tensor* tmp_copy = nullptr;
    OP_REQUIRES_OK(context, context->allocate_persistent(
                                dtype_, value.shape(), &copy, &tmp_copy, attr));
    *var->tensor() = *tmp_copy;
    var->tensor()->flat<T>().device(context->eigen_device<Device>()) =
        value.flat<T>();
    OP_REQUIRES_OK(context, CreateResource<Var>(
                                context, HandleFromInput(context, 0), var));
  }
 private:
  DataType dtype_;
};

// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
CreateVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <typename Device, typename T>
class ReadVariableOp : public OpKernel {
public:
ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
  void Compute(OpKernelContext* ctx) override {
Var* variable = nullptr;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
core::ScopedUnref s(variable);
// TODO(apassos): It's possible to do copy-on-write here instead of always
// copying by coordinating with the writing code. Do this. This will also
// obviate the need to hold a lock here.
mutex_lock ml(*variable->mu());
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, variable->tensor()->shape(), &out));
out->flat<T>().device(ctx->eigen_device<Device>()) =
variable->tensor()->flat<T>();
}
};
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE_CPU).TypeConstraint<type>("dtype"), \
ReadVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
public:
AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
core::ScopedUnref s(variable);
// TODO(apassos): holding a lock and copying is unnecessary if we are the
// last user of the value tensor. This should essentially always be the
// case, yet the refcount is usually 2 instead of 1. Figure out what needs
// to change in the code to make this not be the case, so we can safely take
// ownership.
mutex_lock ml(*variable->mu());
Tensor value = context->input(1);
variable->tensor()->flat<T>().device(context->eigen_device<Device>()) =
value.flat<T>();
}
};
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
template <typename Device, typename T>
class AssignAddVariableOp : public OpKernel {
public:
AssignAddVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&variable));
core::ScopedUnref s(variable);
// TODO(apassos): holding a lock and copying is unnecessary if we are the
// last user of the value tensor. This should essentially always be the
// case, yet the refcount is usually 2 instead of 1. Figure out what needs
// to change in the code to make this not be the case, so we can safely take
// ownership.
mutex_lock ml(*variable->mu());
Tensor value = context->input(1);
variable->tensor()->flat<T>().device(context->eigen_device<Device>()) +=
value.flat<T>();
// TODO(apassos): this read can also be implemented efficiently so it is
// free if no one uses the resulting tensor.
Tensor* out = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, variable->tensor()->shape(), &out));
out->flat<T>().device(context->eigen_device<Device>()) =
variable->tensor()->flat<T>();
}
};
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignAddVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
}  // namespace tensorflow
@@ -49,10 +49,88 @@ dtype: the type of this variable. Must agree with the dtypes
shape: The (possibly partially specified) shape of this variable.
)");
Status CreateAssignShapeFn(shape_inference::InferenceContext* c) {
DataType handle_dtype = c->input_handle_dtype(0);
DataType value_dtype;
c->GetAttr("dtype", &value_dtype);
if (handle_dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to initialize handle for variable with wrong dtype. "
"Expected ",
handle_dtype, " got ", value_dtype);
}
shape_inference::ShapeHandle s = c->input_handle_shape(0);
shape_inference::ShapeHandle value_shape = c->input(1);
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
return Status::OK();
}
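CreateAssignShapeFn merges the handle's declared shape with the incoming value's shape, so partially known dimensions are filled in and conflicting dimensions are rejected. The Python-side analogue is TensorShape.merge_with; a brief illustration, assuming equivalent semantics (not part of the commit):

from tensorflow.python.framework import tensor_shape

a = tensor_shape.TensorShape([None, 2])  # partially specified
b = tensor_shape.TensorShape([3, 2])
print(a.merge_with(b))                   # (3, 2): the unknown dim is filled in
# Merging incompatible shapes, e.g. [3, 2] with [4, 2], raises ValueError.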
REGISTER_OP("CreateVariableOp") REGISTER_OP("CreateVariableOp")
.Input("resource: resource") .Input("resource: resource")
.Input("value: dtype") .Input("value: dtype")
.Attr("dtype: type") .Attr("dtype: type")
.SetShapeFn(CreateAssignShapeFn)
.Doc(R"(
Creates a variable resource.
resource: handle to the resource in which to store the variable.
value: the value with which to initialize the variable.
dtype: the dtype of the value.
)");
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
DataType handle_dtype = c->input_handle_dtype(0);
DataType value_dtype;
c->GetAttr("dtype", &value_dtype);
if (handle_dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to read variable with wrong dtype. "
"Expected ",
handle_dtype, " got ", value_dtype);
}
c->set_output(0, c->input_handle_shape(0));
return Status::OK();
})
.Doc(R"(
Reads the value of a variable.
The tensor returned by this operation is immutable.
The value returned by this operation is guaranteed to be influenced by all the
writes on which this operation depends directly or indirectly, and to not be
influenced by any of the writes which depend directly or indirectly on this
operation.
resource: handle to the resource in which to store the variable.
dtype: the dtype of the value.
)");
REGISTER_OP("AssignVariableOp")
.Input("resource: resource")
.Input("value: dtype")
.Attr("dtype: type")
.SetShapeFn(CreateAssignShapeFn)
.Doc(R"(
Assigns a new value to a variable.
Any ReadVariableOp with a control dependency on this op is guaranteed to return
this value or a subsequent newer value of the variable.
resource: handle to the resource in which to store the variable.
value: the value to assign to the variable.
dtype: the dtype of the value.
)");
REGISTER_OP("AssignAddVariableOp")
.Input("resource: resource")
.Input("value: dtype")
.Output("new_value: dtype")
.Attr("dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType handle_dtype = c->input_handle_dtype(0);
      DataType value_dtype;
@@ -67,13 +145,21 @@ REGISTER_OP("CreateVariableOp")
      shape_inference::ShapeHandle value_shape = c->input(1);
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
      c->set_output(0, value_shape);
      return Status::OK();
    })
    .Doc(R"(
Adds a value to the current value of a variable.

Any ReadVariableOp which depends directly or indirectly on this assign is
guaranteed to see the incremented value or a subsequent newer one.

Outputs the incremented value, which can be used to totally order the
increments to this variable.

resource: handle to the resource in which to store the variable.
value: the value by which the variable will be incremented.
new_value: the new value of the variable.
dtype: the dtype of the value.
)");
...
@@ -19,6 +19,7 @@ from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@@ -46,6 +47,42 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
      resource_variable_ops.create_variable_op(
          id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()

  def testCreateRead(self):
    with self.test_session():
      handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
      resource_variable_ops.create_variable_op(
          handle, constant_op.constant(1, dtype=dtypes.int32)).run()
      value = resource_variable_ops.read_variable_op(
          handle, dtype=dtypes.int32).eval()
      self.assertAllEqual(1, value)

  def testManyAssigns(self):
    with self.test_session() as session:
      handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
      create = resource_variable_ops.create_variable_op(
          handle, constant_op.constant(1, dtype=dtypes.int32))
      with ops.control_dependencies([create]):
        first_read = resource_variable_ops.read_variable_op(
            handle, dtype=dtypes.int32)
      with ops.control_dependencies([first_read]):
        write = resource_variable_ops.assign_variable_op(
            handle, constant_op.constant(2, dtype=dtypes.int32))
      with ops.control_dependencies([write]):
        second_read = resource_variable_ops.read_variable_op(
            handle, dtype=dtypes.int32)
      f, s = session.run([first_read, second_read])
      self.assertEqual(f, 1)
      self.assertEqual(s, 2)

  def testAssignAdd(self):
    with self.test_session():
      handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
      resource_variable_ops.create_variable_op(
          handle, constant_op.constant(1, dtype=dtypes.int32)).run()
      assign_add = resource_variable_ops.assign_add_variable_op(
          handle, constant_op.constant(1, dtype=dtypes.int32))
      self.assertEqual(assign_add.eval(), 2)

if __name__ == "__main__":
  test.main()
@@ -28,3 +28,6 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReadVariableOp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("AssignVariableOp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("AssignAddVariableOp")(common_shapes.call_cpp_shape_fn)