Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
1f8936bb
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1f8936bb
编写于
11月 01, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
11月 01, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Basic slow versions of resource-based variable ops which have the right semantics.
Change: 137863207
上级
b5c790e5
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
275 addition
and
10 deletion
+275
-10
tensorflow/core/framework/tensor.h
tensorflow/core/framework/tensor.h
+1
-0
tensorflow/core/kernels/resource_variable_ops.cc
tensorflow/core/kernels/resource_variable_ops.cc
+146
-8
tensorflow/core/ops/resource_variable_ops.cc
tensorflow/core/ops/resource_variable_ops.cc
+88
-2
tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/kernel_tests/resource_variable_ops_test.py
+37
-0
tensorflow/python/ops/resource_variable_ops.py
tensorflow/python/ops/resource_variable_ops.py
+3
-0
未找到文件。
tensorflow/core/framework/tensor.h
浏览文件 @
1f8936bb
...
...
@@ -435,6 +435,7 @@ class Tensor {
friend
class
VariableOp
;
// For access to set_shape
friend
class
AutoReloadVariableOp
;
// For access to set_shape
friend
class
TensorTestHelper
;
// For access to set_shape
template
<
typename
Device
,
typename
T
>
friend
class
CreateVariableOp
;
// Creates a tensor with the input datatype, shape and buf.
...
...
tensorflow/core/kernels/resource_variable_ops.cc
浏览文件 @
1f8936bb
...
...
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
...
...
@@ -25,25 +28,160 @@ namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL
(
Var
);
template
<
typename
Device
,
typename
T
>
class
CreateVariableOp
:
public
OpKernel
{
public:
CreateVariableOp
(
OpKernelConstruction
*
c
)
:
OpKernel
(
c
)
{
OP_REQUIRES_OK
(
c
,
c
->
GetAttr
(
"dtype"
,
&
dtype_
));
OP_REQUIRES
(
c
,
DataTypeToEnum
<
T
>::
value
==
dtype_
,
errors
::
InvalidArgument
(
"Dtypes don't match; expected "
,
DataTypeString
(
dtype_
),
" got "
,
DataTypeString
(
DataTypeToEnum
<
T
>::
value
)));
}
void
Compute
(
OpKernelContext
*
c
)
override
{
void
Compute
(
OpKernelContext
*
c
ontext
)
override
{
Var
*
var
=
new
Var
(
dtype_
);
var
->
Ref
();
core
::
ScopedUnref
ur
(
var
);
OP_REQUIRES_OK
(
c
,
CreateResource
<
Var
>
(
c
,
HandleFromInput
(
c
,
0
),
var
));
// TODO(apassos): this currently does not initialize the tensor, so it's
// pointless, other than checking construction in tests. Fix this.
AllocatorAttributes
attr
;
attr
.
set_gpu_compatible
(
true
);
attr
.
set_nic_compatible
(
true
);
PersistentTensor
copy
;
Tensor
value
=
context
->
input
(
1
);
// TODO(apassos): allocating and copying is unnecessary if we are the last
// user of the value tensor. This should essentially always be the case, yet
// the refcount is usually 2 instead of 1. Figure out what needs to change
// in the code to make this not be the case, so we can safely take
// ownership.
Tensor
*
tmp_copy
=
nullptr
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_persistent
(
dtype_
,
value
.
shape
(),
&
copy
,
&
tmp_copy
,
attr
));
*
var
->
tensor
()
=
*
tmp_copy
;
var
->
tensor
()
->
flat
<
T
>
().
device
(
context
->
eigen_device
<
Device
>
())
=
value
.
flat
<
T
>
();
OP_REQUIRES_OK
(
context
,
CreateResource
<
Var
>
(
context
,
HandleFromInput
(
context
,
0
),
var
));
}
private:
DataType
dtype_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"CreateVariableOp"
).
Device
(
DEVICE_CPU
),
CreateVariableOp
);
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
CreateVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES
(
REGISTER_KERNELS
);
TF_CALL_QUANTIZED_TYPES
(
REGISTER_KERNELS
);
#undef REGISTER_KERNELS
template
<
typename
Device
,
typename
T
>
class
ReadVariableOp
:
public
OpKernel
{
public:
ReadVariableOp
(
OpKernelConstruction
*
c
)
:
OpKernel
(
c
)
{}
void
Compute
(
OpKernelContext
*
ctx
)
{
Var
*
variable
=
nullptr
;
OP_REQUIRES_OK
(
ctx
,
LookupResource
(
ctx
,
HandleFromInput
(
ctx
,
0
),
&
variable
));
core
::
ScopedUnref
s
(
variable
);
// TODO(apassos): It's possible to do copy-on-write here instead of always
// copying by coordinating with the writing code. Do this. This will also
// obviate the need to hold a lock here.
mutex_lock
ml
(
*
variable
->
mu
());
Tensor
*
out
=
nullptr
;
OP_REQUIRES_OK
(
ctx
,
ctx
->
allocate_output
(
0
,
variable
->
tensor
()
->
shape
(),
&
out
));
out
->
flat
<
T
>
().
device
(
ctx
->
eigen_device
<
Device
>
())
=
variable
->
tensor
()
->
flat
<
T
>
();
}
};
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE_CPU).TypeConstraint<type>("dtype"), \
ReadVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES
(
REGISTER_KERNELS
);
TF_CALL_QUANTIZED_TYPES
(
REGISTER_KERNELS
);
#undef REGISTER_KERNELS
template
<
typename
Device
,
typename
T
>
class
AssignVariableOp
:
public
OpKernel
{
public:
AssignVariableOp
(
OpKernelConstruction
*
c
)
:
OpKernel
(
c
)
{}
void
Compute
(
OpKernelContext
*
context
)
override
{
Var
*
variable
=
nullptr
;
OP_REQUIRES_OK
(
context
,
LookupResource
(
context
,
HandleFromInput
(
context
,
0
),
&
variable
));
core
::
ScopedUnref
s
(
variable
);
// TODO(apassos): holding a lock and copying is unnecessary if we are the
// last user of the value tensor. This should essentially always be the
// case, yet the refcount is usually 2 instead of 1. Figure out what needs
// to change in the code to make this not be the case, so we can safely take
// ownership.
mutex_lock
ml
(
*
variable
->
mu
());
Tensor
value
=
context
->
input
(
1
);
variable
->
tensor
()
->
flat
<
T
>
().
device
(
context
->
eigen_device
<
Device
>
())
=
value
.
flat
<
T
>
();
}
};
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_ALL_TYPES
(
REGISTER_KERNELS
);
TF_CALL_QUANTIZED_TYPES
(
REGISTER_KERNELS
);
#undef REGISTER_KERNELS
template
<
typename
Device
,
typename
T
>
class
AssignAddVariableOp
:
public
OpKernel
{
public:
AssignAddVariableOp
(
OpKernelConstruction
*
c
)
:
OpKernel
(
c
)
{}
void
Compute
(
OpKernelContext
*
context
)
override
{
Var
*
variable
=
nullptr
;
OP_REQUIRES_OK
(
context
,
LookupResource
(
context
,
HandleFromInput
(
context
,
0
),
&
variable
));
core
::
ScopedUnref
s
(
variable
);
// TODO(apassos): holding a lock and copying is unnecessary if we are the
// last user of the value tensor. This should essentially always be the
// case, yet the refcount is usually 2 instead of 1. Figure out what needs
// to change in the code to make this not be the case, so we can safely take
// ownership.
mutex_lock
ml
(
*
variable
->
mu
());
Tensor
value
=
context
->
input
(
1
);
variable
->
tensor
()
->
flat
<
T
>
().
device
(
context
->
eigen_device
<
Device
>
())
+=
value
.
flat
<
T
>
();
// TODO(apassos): this read can also be implemented efficiently so it is
// free if no one uses the resulting tensor.
Tensor
*
out
=
nullptr
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
0
,
variable
->
tensor
()
->
shape
(),
&
out
));
out
->
flat
<
T
>
().
device
(
context
->
eigen_device
<
Device
>
())
=
variable
->
tensor
()
->
flat
<
T
>
();
}
};
// TODO(apassos) register for the GPU as well.
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
AssignAddVariableOp<Eigen::ThreadPoolDevice, type>);
TF_CALL_NUMBER_TYPES
(
REGISTER_KERNELS
);
#undef REGISTER_KERNELS
}
// namespace tensorflow
tensorflow/core/ops/resource_variable_ops.cc
浏览文件 @
1f8936bb
...
...
@@ -49,10 +49,88 @@ dtype: the type of this variable. Must agree with the dtypes
shape: The (possibly partially specified) shape of this variable.
)"
);
Status
CreateAssignShapeFn
(
shape_inference
::
InferenceContext
*
c
)
{
DataType
handle_dtype
=
c
->
input_handle_dtype
(
0
);
DataType
value_dtype
;
c
->
GetAttr
(
"dtype"
,
&
value_dtype
);
if
(
handle_dtype
!=
value_dtype
)
{
return
errors
::
InvalidArgument
(
"Trying to initialize handle for variable with wrong dtype. "
"Expected "
,
handle_dtype
,
" got "
,
value_dtype
);
}
shape_inference
::
ShapeHandle
s
=
c
->
input_handle_shape
(
0
);
shape_inference
::
ShapeHandle
value_shape
=
c
->
input
(
1
);
shape_inference
::
ShapeHandle
unused
;
TF_RETURN_IF_ERROR
(
c
->
Merge
(
s
,
value_shape
,
&
unused
));
return
Status
::
OK
();
}
REGISTER_OP
(
"CreateVariableOp"
)
.
Input
(
"resource: resource"
)
.
Input
(
"value: dtype"
)
.
Attr
(
"dtype: type"
)
.
SetShapeFn
(
CreateAssignShapeFn
)
.
Doc
(
R"(
Creates a variable resource.
resource: handle to the resource in which to store the variable.
value: the value to set the new tensor to use.
dtype: the dtype of the value.
)"
);
REGISTER_OP
(
"ReadVariableOp"
)
.
Input
(
"resource: resource"
)
.
Output
(
"value: dtype"
)
.
Attr
(
"dtype: type"
)
.
SetShapeFn
([](
shape_inference
::
InferenceContext
*
c
)
{
DataType
handle_dtype
=
c
->
input_handle_dtype
(
0
);
DataType
value_dtype
;
c
->
GetAttr
(
"dtype"
,
&
value_dtype
);
if
(
handle_dtype
!=
value_dtype
)
{
return
errors
::
InvalidArgument
(
"Trying to read variable with wrong dtype. "
"Expected "
,
handle_dtype
,
" got "
,
value_dtype
);
}
c
->
set_output
(
0
,
c
->
input_handle_shape
(
0
));
return
Status
::
OK
();
})
.
Doc
(
R"(
Reads the value of a variable.
The tensor returned by this operation is immutable.
The value returned by this operation is guaranteed to be influenced by all the
writes on which this operation depends directly or indirectly, and to not be
influenced by any of the writes which depend directly or indirectly on this
operation.
resource: handle to the resource in which to store the variable.
dtype: the dtype of the value.
)"
);
REGISTER_OP
(
"AssignVariableOp"
)
.
Input
(
"resource: resource"
)
.
Input
(
"value: dtype"
)
.
Attr
(
"dtype: type"
)
.
SetShapeFn
(
CreateAssignShapeFn
)
.
Doc
(
R"(
Assigns a new value to a variable.
Any ReadVariableOp with a control dependency on this op is guaranteed to return
this value or a subsequent newer value of the variable.
resource: handle to the resource in which to store the variable.
value: the value to set the new tensor to use.
dtype: the dtype of the value.
)"
);
REGISTER_OP
(
"AssignAddVariableOp"
)
.
Input
(
"resource: resource"
)
.
Input
(
"value: dtype"
)
.
Output
(
"new_value: dtype"
)
.
Attr
(
"dtype: type"
)
.
SetShapeFn
([](
shape_inference
::
InferenceContext
*
c
)
{
DataType
handle_dtype
=
c
->
input_handle_dtype
(
0
);
DataType
value_dtype
;
...
...
@@ -67,13 +145,21 @@ REGISTER_OP("CreateVariableOp")
shape_inference
::
ShapeHandle
value_shape
=
c
->
input
(
1
);
shape_inference
::
ShapeHandle
unused
;
TF_RETURN_IF_ERROR
(
c
->
Merge
(
s
,
value_shape
,
&
unused
));
c
->
set_output
(
0
,
value_shape
);
return
Status
::
OK
();
})
.
Doc
(
R"(
Creates a variable resource.
Adds a value to the current value of a variable.
Any ReadVariableOp which depends directly or indirectly on this assign is
guaranteed to see the incremented value or a subsequent newer one.
Outputs the incremented value, which can be used to totally order the
increments to this variable.
resource: handle to the resource in which to store the variable.
value: the value to set the new tensor to use.
value: the value by which the variable will be incremented.
new_value: the new value of the variable.
dtype: the dtype of the value.
)"
);
...
...
tensorflow/python/kernel_tests/resource_variable_ops_test.py
浏览文件 @
1f8936bb
...
...
@@ -19,6 +19,7 @@ from __future__ import print_function
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
resource_variable_ops
...
...
@@ -46,6 +47,42 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops
.
create_variable_op
(
id_handle
,
constant_op
.
constant
(
0
,
dtype
=
dtypes
.
int32
)).
run
()
def
testCreateRead
(
self
):
with
self
.
test_session
():
handle
=
resource_variable_ops
.
var_handle_op
(
dtype
=
dtypes
.
int32
,
shape
=
[])
resource_variable_ops
.
create_variable_op
(
handle
,
constant_op
.
constant
(
1
,
dtype
=
dtypes
.
int32
)).
run
()
value
=
resource_variable_ops
.
read_variable_op
(
handle
,
dtype
=
dtypes
.
int32
).
eval
()
self
.
assertAllEqual
(
1
,
value
)
def
testManyAssigns
(
self
):
with
self
.
test_session
()
as
session
:
handle
=
resource_variable_ops
.
var_handle_op
(
dtype
=
dtypes
.
int32
,
shape
=
[])
create
=
resource_variable_ops
.
create_variable_op
(
handle
,
constant_op
.
constant
(
1
,
dtype
=
dtypes
.
int32
))
with
ops
.
control_dependencies
([
create
]):
first_read
=
resource_variable_ops
.
read_variable_op
(
handle
,
dtype
=
dtypes
.
int32
)
with
ops
.
control_dependencies
([
first_read
]):
write
=
resource_variable_ops
.
assign_variable_op
(
handle
,
constant_op
.
constant
(
2
,
dtype
=
dtypes
.
int32
))
with
ops
.
control_dependencies
([
write
]):
second_read
=
resource_variable_ops
.
read_variable_op
(
handle
,
dtype
=
dtypes
.
int32
)
f
,
s
=
session
.
run
([
first_read
,
second_read
])
self
.
assertEqual
(
f
,
1
)
self
.
assertEqual
(
s
,
2
)
def
testAssignAdd
(
self
):
with
self
.
test_session
():
handle
=
resource_variable_ops
.
var_handle_op
(
dtype
=
dtypes
.
int32
,
shape
=
[])
resource_variable_ops
.
create_variable_op
(
handle
,
constant_op
.
constant
(
1
,
dtype
=
dtypes
.
int32
)).
run
()
assign_add
=
resource_variable_ops
.
assign_add_variable_op
(
handle
,
constant_op
.
constant
(
1
,
dtype
=
dtypes
.
int32
))
self
.
assertEqual
(
assign_add
.
eval
(),
2
)
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/ops/resource_variable_ops.py
浏览文件 @
1f8936bb
...
...
@@ -28,3 +28,6 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
ops
.
RegisterShape
(
"VarHandleOp"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"CreateVariableOp"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"ReadVariableOp"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"AssignVariableOp"
)(
common_shapes
.
call_cpp_shape_fn
)
ops
.
RegisterShape
(
"AssignAddVariableOp"
)(
common_shapes
.
call_cpp_shape_fn
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录