Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f188e22b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f188e22b
编写于
8月 23, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove set functor and add comapre_grad test
上级
a8d072c7
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
33 addition
and
148 deletion
+33
-148
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-2
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+1
-1
paddle/operators/functor/CMakeLists.txt
paddle/operators/functor/CMakeLists.txt
+0
-5
paddle/operators/functor/math_functor.cc
paddle/operators/functor/math_functor.cc
+0
-42
paddle/operators/functor/math_functor.cu
paddle/operators/functor/math_functor.cu
+0
-42
paddle/operators/functor/math_functor.h
paddle/operators/functor/math_functor.h
+0
-32
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+13
-13
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+5
-5
paddle/platform/cuda_helper.h
paddle/platform/cuda_helper.h
+0
-4
python/paddle/v2/framework/tests/gradient_checker.py
python/paddle/v2/framework/tests/gradient_checker.py
+11
-2
python/paddle/v2/framework/tests/test_lookup_table.py
python/paddle/v2/framework/tests/test_lookup_table.py
+2
-0
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
f188e22b
...
@@ -42,7 +42,6 @@ function(op_library TARGET)
...
@@ -42,7 +42,6 @@ function(op_library TARGET)
endfunction
()
endfunction
()
add_subdirectory
(
math
)
add_subdirectory
(
math
)
add_subdirectory
(
functor
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
...
@@ -69,4 +68,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
...
@@ -69,4 +68,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op
)
DEPS framework_proto tensor op_registry operator net_op
)
op_library
(
uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu
)
op_library
(
uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu
)
op_library
(
lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu
DEPS math_functor
)
op_library
(
lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu
)
paddle/operators/fill_zeros_like_op.h
浏览文件 @
f188e22b
...
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
...
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Dst"
);
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Dst"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
T
(
0
));
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
}
}
};
};
...
...
paddle/operators/functor/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
a8d072c7
if
(
WITH_GPU
)
nv_library
(
math_functor SRCS math_functor.cc math_functor.cu DEPS device_context
)
else
()
cc_library
(
math_functor SRCS math_functor.cc DEPS device_context
)
endif
()
paddle/operators/functor/math_functor.cc
已删除
100644 → 0
浏览文件 @
a8d072c7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/functor/math_functor.h"
#include "paddle/framework/eigen.h"
namespace
paddle
{
namespace
operators
{
namespace
functor
{
template
<
typename
T
>
struct
Set
<
platform
::
CPUPlace
,
T
>
{
void
operator
()(
const
T
alpha
,
framework
::
Tensor
*
Y
,
platform
::
DeviceContext
*
context
)
{
int
N
=
product
(
Y
->
dims
());
T
*
YData
=
Y
->
mutable_data
<
T
>
(
context
->
GetPlace
());
if
(
alpha
==
static_cast
<
T
>
(
0
))
{
memset
(
YData
,
0
,
N
*
sizeof
(
T
));
}
else
{
framework
::
EigenVector
<
T
,
Eigen
::
RowMajor
,
Eigen
::
DenseIndex
>::
Flatten
(
*
Y
)
.
setConstant
(
alpha
);
}
}
};
template
struct
Set
<
platform
::
CPUPlace
,
float
>;
template
struct
Set
<
platform
::
CPUPlace
,
double
>;
}
// namespace functor
}
// namespace operators
}
// namespace paddle
paddle/operators/functor/math_functor.cu
已删除
100644 → 0
浏览文件 @
a8d072c7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/functor/math_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
functor
{
template
<
typename
T
>
__global__
void
SetKernel
(
const
int
N
,
const
T
alpha
,
T
*
Y
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
N
)
{
Y
[
i
]
=
alpha
;
}
}
template
<
typename
T
>
struct
Set
<
platform
::
GPUPlace
,
T
>
{
void
operator
()(
const
T
alpha
,
framework
::
Tensor
*
Y
,
platform
::
DeviceContext
*
context
)
{
int
N
=
product
(
Y
->
dims
());
T
*
YData
=
Y
->
mutable_data
<
T
>
(
context
->
GetPlace
());
SetKernel
<<<
(
N
+
512
-
1
)
/
512
,
512
>>>
(
N
,
alpha
,
YData
);
}
};
template
struct
Set
<
platform
::
GPUPlace
,
float
>;
template
struct
Set
<
platform
::
GPUPlace
,
double
>;
}
// namespace functor
}
// namespace operators
}
// namespace paddle
paddle/operators/functor/math_functor.h
已删除
100644 → 0
浏览文件 @
a8d072c7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
namespace
functor
{
template
<
typename
Place
,
typename
T
>
struct
Set
{
void
operator
()(
const
T
alpha
,
paddle
::
framework
::
Tensor
*
Y
,
paddle
::
platform
::
DeviceContext
*
context
);
};
}
// namespace functor
}
// namespace operators
}
// namespace paddle
paddle/operators/lookup_table_op.cu
浏览文件 @
f188e22b
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/functor/math_functor.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/cuda_helper.h"
...
@@ -22,11 +22,11 @@ namespace operators {
...
@@ -22,11 +22,11 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
blockDimX
,
int
blockDimY
,
int
g
ridDimX
>
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
G
ridDimX
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int32_t
*
ids
,
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
g
ridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
G
ridDimX
;
while
(
idy
<
K
)
{
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
int
id
=
ids
[
idy
];
...
@@ -34,18 +34,18 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
...
@@ -34,18 +34,18 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
(
id
<
N
);
T
*
out
=
output
+
idy
*
D
;
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
b
lockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
B
lockDimX
)
{
out
[
i
]
=
tab
[
i
];
out
[
i
]
=
tab
[
i
];
}
}
idy
+=
blockDimY
*
g
ridDimX
;
idy
+=
BlockDimY
*
G
ridDimX
;
}
}
}
}
template
<
typename
T
,
int
blockDimX
,
int
blockDimY
,
int
g
ridDimX
>
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
G
ridDimX
>
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int32_t
*
ids
,
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
g
ridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
G
ridDimX
;
while
(
idy
<
K
)
{
while
(
idy
<
K
)
{
int
id
=
ids
[
idy
];
int
id
=
ids
[
idy
];
...
@@ -53,10 +53,10 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
...
@@ -53,10 +53,10 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
(
id
<
N
);
const
T
*
out
=
output
+
idy
*
D
;
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
b
lockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
B
lockDimX
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
}
idy
+=
blockDimY
*
g
ridDimX
;
idy
+=
BlockDimY
*
G
ridDimX
;
}
}
}
}
...
@@ -96,10 +96,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
...
@@ -96,10 +96,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
device_context
=
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
t
.
device
(
context
.
GetEigenDevice
<
platform
::
GPUPlace
>
())
=
functor
::
Set
<
paddle
::
platform
::
GPUPlace
,
T
>
()(
static_cast
<
T
>
(
0
),
d_table_t
,
t
.
constant
(
static_cast
<
T
>
(
0
));
device_context
);
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
...
...
paddle/operators/lookup_table_op.h
浏览文件 @
f188e22b
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
#pragma once
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/functor/math_functor.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -57,10 +57,10 @@ class LookupTableGradKernel : public framework::OpKernel {
...
@@ -57,10 +57,10 @@ class LookupTableGradKernel : public framework::OpKernel {
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
device_context
=
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
t
.
device
(
context
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
functor
::
Set
<
paddle
::
platform
::
CPUPlace
,
T
>
()(
static_cast
<
T
>
(
0
),
d_table_t
,
t
.
constant
(
static_cast
<
T
>
(
0
));
device_context
);
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
...
...
paddle/platform/cuda_helper.h
浏览文件 @
f188e22b
...
@@ -18,10 +18,6 @@ limitations under the License. */
...
@@ -18,10 +18,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
#define CUDA_ATOMIC_WRAPPER(op, T) \
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
...
...
python/paddle/v2/framework/tests/gradient_checker.py
浏览文件 @
f188e22b
...
@@ -23,6 +23,10 @@ def grad_var_name(var_name):
...
@@ -23,6 +23,10 @@ def grad_var_name(var_name):
return
var_name
+
"@GRAD"
return
var_name
+
"@GRAD"
def
empty_var_name
():
return
"@EMPTY@"
def
get_numeric_gradient
(
op
,
def
get_numeric_gradient
(
op
,
input_values
,
input_values
,
output_name
,
output_name
,
...
@@ -171,7 +175,7 @@ class GradientChecker(unittest.TestCase):
...
@@ -171,7 +175,7 @@ class GradientChecker(unittest.TestCase):
]
]
return
outs
return
outs
def
compare_grad
(
self
,
forward_op
,
input_value
):
def
compare_grad
(
self
,
forward_op
,
input_value
,
no_grad_set
=
None
):
""" Compare the input gradients between CPU and GPU for the given forward
""" Compare the input gradients between CPU and GPU for the given forward
operator.
operator.
...
@@ -179,15 +183,20 @@ class GradientChecker(unittest.TestCase):
...
@@ -179,15 +183,20 @@ class GradientChecker(unittest.TestCase):
:type forward_op: Operator
:type forward_op: Operator
:param input_value: input values.
:param input_value: input values.
:type input_value: dict{string:numpy.array}
:type input_value: dict{string:numpy.array}
:param no_grad_set: the set of variables names without gradients.
:type no_grad_set: a set of string
:raises: AssertionError, there is different gradient value.
:raises: AssertionError, there is different gradient value.
"""
"""
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
set
())
if
no_grad_set
is
None
:
no_grad_set
=
set
()
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
no_grad_set
)
# return if not compile with GPU or not implementing GPU kernel
# return if not compile with GPU or not implementing GPU kernel
if
not
(
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
()):
if
not
(
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
()):
return
return
outputs
=
backward_op
.
outputs
()
outputs
=
backward_op
.
outputs
()
out_names
=
[
item
for
k
in
outputs
for
item
in
outputs
[
k
]]
out_names
=
[
item
for
k
in
outputs
for
item
in
outputs
[
k
]]
out_names
=
filter
(
lambda
x
:
x
!=
empty_var_name
(),
out_names
)
cpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
cpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
out_names
,
core
.
CPUPlace
())
out_names
,
core
.
CPUPlace
())
gpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
gpu_grads
=
self
.
__get_gradient
(
forward_op
,
backward_op
,
input_value
,
...
...
python/paddle/v2/framework/tests/test_lookup_table.py
浏览文件 @
f188e22b
...
@@ -21,6 +21,8 @@ class TestSigmoidGradOp(GradientChecker):
...
@@ -21,6 +21,8 @@ class TestSigmoidGradOp(GradientChecker):
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
'int32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
'int32'
)
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
# comapre gradients
self
.
compare_grad
(
op
,
inputs
,
set
([
'Ids'
]))
# check gradients
# check gradients
self
.
check_grad
(
op
,
inputs
,
set
(
'W'
),
'Out'
)
self
.
check_grad
(
op
,
inputs
,
set
(
'W'
),
'Out'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录