Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
3663bd88
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3663bd88
编写于
8月 24, 2017
作者:
Q
qingqing01
提交者:
GitHub
8月 24, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3620 from qingqing01/lookup_table
Add a lookup table op and a CUDA helper.
上级
e5cbeb02
aafeff0f
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
363 addition
and
5 deletion
+363
-5
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+3
-2
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+72
-0
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+116
-0
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+75
-0
paddle/platform/cuda_helper.h
paddle/platform/cuda_helper.h
+51
-0
paddle/pybind/CMakeLists.txt
paddle/pybind/CMakeLists.txt
+1
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-0
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+1
-0
python/paddle/v2/framework/tests/gradient_checker.py
python/paddle/v2/framework/tests/gradient_checker.py
+11
-2
python/paddle/v2/framework/tests/test_lookup_table.py
python/paddle/v2/framework/tests/test_lookup_table.py
+31
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
3663bd88
...
@@ -42,6 +42,7 @@ function(op_library TARGET)
...
@@ -42,6 +42,7 @@ function(op_library TARGET)
endfunction
()
endfunction
()
add_subdirectory
(
math
)
add_subdirectory
(
math
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
op_library
(
gather_op SRCS gather_op.cc gather_op.cu
)
op_library
(
gather_op SRCS gather_op.cc gather_op.cu
)
...
@@ -67,7 +68,7 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
...
@@ -67,7 +68,7 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op
)
DEPS framework_proto tensor op_registry operator net_op
)
op_library
(
uniform_random_op
op_library
(
uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu
)
SRCS uniform_random_op.cc uniform_random
_op.cu
)
op_library
(
lookup_table_op SRCS lookup_table_op.cc lookup_table
_op.cu
)
op_library
(
scale_op SRCS scale_op.cc scale_op.cu DEPS net_op
)
op_library
(
scale_op SRCS scale_op.cc scale_op.cu DEPS net_op
)
op_library
(
minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op
)
op_library
(
minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op
)
paddle/operators/fill_zeros_like_op.h
浏览文件 @
3663bd88
...
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
...
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Dst"
);
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Dst"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
T
(
0
));
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
}
}
};
};
...
...
paddle/operators/lookup_table_op.cc
0 → 100644
浏览文件 @
3663bd88
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lookup_table_op.h"
namespace
paddle
{
namespace
operators
{
class
LookupTableOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
table_t
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
output_t
->
Resize
({
ids_t
->
dims
()[
0
],
table_t
->
dims
()[
1
]});
}
};
class
LookupTableOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
LookupTableOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"W"
,
"An input represents embedding tensors,"
" which is a learnable parameter."
);
AddInput
(
"Ids"
,
"An input with type int32 or int64"
"contains the ids to be looked up in W."
);
AddOutput
(
"Out"
,
"The lookup results, which have the same type with W."
);
AddComment
(
"This operator is used to perform lookups on the parameter W,"
"then concatenated into a dense tensor."
);
}
};
class
LookupTableOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
table
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
d_table
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
d_table
->
Resize
(
table
->
dims
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
lookup_table
,
ops
::
LookupTableOp
,
ops
::
LookupTableOpMaker
,
lookup_table_grad
,
ops
::
LookupTableOpGrad
);
REGISTER_OP_CPU_KERNEL
(
lookup_table
,
ops
::
LookupTableKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradKernel
<
float
>
);
paddle/operators/lookup_table_op.cu
0 → 100644
浏览文件 @
3663bd88
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
out
[
i
]
=
tab
[
i
];
}
idy
+=
BlockDimY
*
GridDimX
;
}
}
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
idy
+=
BlockDimY
*
GridDimX
;
}
}
template
<
typename
T
>
class
LookupTableCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
table_t
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
K
=
product
(
ids_t
->
dims
());
auto
ids
=
ids_t
->
data
<
int32_t
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
}
};
template
<
typename
T
>
class
LookupTableGradCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
d_output_t
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
d_table_t
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
int
N
=
d_table_t
->
dims
()[
0
];
int
D
=
d_table_t
->
dims
()[
1
];
int
K
=
product
(
ids_t
->
dims
());
const
int32_t
*
ids
=
ids_t
->
data
<
int32_t
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
t
.
device
(
context
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
lookup_table
,
ops
::
LookupTableCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradCUDAKernel
<
float
>
);
paddle/operators/lookup_table_op.h
0 → 100644
浏览文件 @
3663bd88
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
LookupTableKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
table_t
=
context
.
Input
<
Tensor
>
(
"W"
);
// float tensor
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
// int tensor
auto
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
auto
ids
=
ids_t
->
data
<
int32_t
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
};
template
<
typename
T
>
class
LookupTableGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
d_output_t
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
d_table_t
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
size_t
N
=
d_table_t
->
dims
()[
0
];
size_t
D
=
d_table_t
->
dims
()[
1
];
auto
ids
=
ids_t
->
data
<
int32_t
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
t
.
device
(
context
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
for
(
size_t
j
=
0
;
j
<
D
;
++
j
)
{
d_table
[
ids
[
i
]
*
D
+
j
]
+=
d_output
[
i
*
D
+
j
];
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/platform/cuda_helper.h
0 → 100644
浏览文件 @
3663bd88
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace
paddle
{
namespace
platform
{
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
// For atomicAdd.
USE_CUDA_ATOMIC
(
Add
,
float
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC
(
Add
,
double
);
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
double
)
{
unsigned
long
long
int
*
address_as_ull
=
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
);
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
val
+
__longlong_as_double
(
assumed
)));
// Note: uses integer comparison to avoid hang in case of NaN
}
while
(
assumed
!=
old
);
return
__longlong_as_double
(
old
);
}
#endif
}
// namespace platform
}
// namespace paddle
paddle/pybind/CMakeLists.txt
浏览文件 @
3663bd88
...
@@ -15,6 +15,7 @@ cc_library(paddle_pybind SHARED
...
@@ -15,6 +15,7 @@ cc_library(paddle_pybind SHARED
uniform_random_op
uniform_random_op
gaussian_random_op
gaussian_random_op
fill_zeros_like_op
fill_zeros_like_op
lookup_table_op
scale_op
scale_op
minus_op
)
minus_op
)
endif
(
WITH_PYTHON
)
endif
(
WITH_PYTHON
)
paddle/pybind/pybind.cc
浏览文件 @
3663bd88
...
@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
...
@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF
(
recurrent_op
);
USE_OP_ITSELF
(
recurrent_op
);
USE_OP
(
gaussian_random
);
USE_OP
(
gaussian_random
);
USE_OP
(
uniform_random
);
USE_OP
(
uniform_random
);
USE_OP
(
lookup_table
);
USE_OP
(
scale
);
USE_OP
(
scale
);
USE_OP_ITSELF
(
identity
);
USE_OP_ITSELF
(
identity
);
USE_OP
(
minus
);
USE_OP
(
minus
);
...
...
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
3663bd88
...
@@ -28,5 +28,6 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
...
@@ -28,5 +28,6 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_recurrent_op SRCS test_recurrent_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_sgd_op SRCS test_sgd_op.py
)
py_test
(
test_gradient_checker SRCS test_gradient_checker.py
)
py_test
(
test_gradient_checker SRCS test_gradient_checker.py
)
py_test
(
test_lookup_table SRCS test_lookup_table.py
)
py_test
(
test_scale_and_identity_op SRCS test_scale_and_identity_op.py
)
py_test
(
test_scale_and_identity_op SRCS test_scale_and_identity_op.py
)
py_test
(
mnist SRCS mnist.py
)
py_test
(
mnist SRCS mnist.py
)
python/paddle/v2/framework/tests/gradient_checker.py
浏览文件 @
3663bd88
...
@@ -23,6 +23,10 @@ def grad_var_name(var_name):
...
@@ -23,6 +23,10 @@ def grad_var_name(var_name):
return
var_name
+
"@GRAD"
return
var_name
+
"@GRAD"
def
empty_var_name
():
return
"@EMPTY@"
def
get_numeric_gradient
(
op
,
def
get_numeric_gradient
(
op
,
input_values
,
input_values
,
output_name
,
output_name
,
...
@@ -176,7 +180,7 @@ class GradientChecker(unittest.TestCase):
...
@@ -176,7 +180,7 @@ class GradientChecker(unittest.TestCase):
]
]
return
outs
return
outs
def
compare_grad
(
self
,
forward_op
,
input_value
):
def
compare_grad
(
self
,
forward_op
,
input_value
,
no_grad_set
=
None
):
""" Compare the input gradients between CPU and GPU for the given forward
""" Compare the input gradients between CPU and GPU for the given forward
operator.
operator.
...
@@ -184,15 +188,20 @@ class GradientChecker(unittest.TestCase):
...
@@ -184,15 +188,20 @@ class GradientChecker(unittest.TestCase):
:type forward_op: Operator
:type forward_op: Operator
:param input_value: input values.
:param input_value: input values.
:type input_value: dict{string:numpy.array}
:type input_value: dict{string:numpy.array}
:param no_grad_set: the set of variables names without gradients.
:type no_grad_set: a set of string
:raises: AssertionError, there is different gradient value.
:raises: AssertionError, there is different gradient value.
"""
"""
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
set
())
if
no_grad_set
is
None
:
no_grad_set
=
set
()
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
no_grad_set
)
# return if not compile with GPU or not implementing GPU kernel
# return if not compile with GPU or not implementing GPU kernel
if
not
(
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
()):
if
not
(
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
()):
return
return
outputs
=
backward_op
.
outputs
()
outputs
=
backward_op
.
outputs
()
out_names
=
[
item
for
k
in
outputs
for
item
in
outputs
[
k
]]
out_names
=
[
item
for
k
in
outputs
for
item
in
outputs
[
k
]]
out_names
=
filter
(
lambda
x
:
x
!=
empty_var_name
(),
out_names
)
cpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
cpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
out_names
,
core
.
CPUPlace
())
out_names
,
core
.
CPUPlace
())
gpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
gpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
...
...
python/paddle/v2/framework/tests/test_lookup_table.py
0 → 100644
浏览文件 @
3663bd88
import
unittest
import
numpy
as
np
from
op_test_util
import
OpTestMeta
from
gradient_checker
import
GradientChecker
,
create_op
class
TestSigmoidOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
'lookup_table'
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
'int32'
)
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
self
.
outputs
=
{
'Out'
:
table
[
ids
]}
class
TestSigmoidGradOp
(
GradientChecker
):
def
test_grad
(
self
):
op
=
create_op
(
'lookup_table'
)
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
'int32'
)
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
# comapre gradients
self
.
compare_grad
(
op
,
inputs
,
set
([
'Ids'
]))
# check gradients
self
.
check_grad
(
op
,
inputs
,
set
(
'W'
),
'Out'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录