Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2f47f35b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2f47f35b
编写于
8月 21, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix gpu build error
上级
7c274dc0
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
28 addition
and
33 deletion
+28
-33
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+2
-2
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+5
-5
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+8
-7
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+2
-5
paddle/operators/uniform_random_op.cu
paddle/operators/uniform_random_op.cu
+3
-6
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+5
-5
paddle/platform/device_context.h
paddle/platform/device_context.h
+3
-3
未找到文件。
paddle/operators/math/CMakeLists.txt
浏览文件 @
2f47f35b
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu DEPS cblas device_context
)
nv_library
(
math_function SRCS math_function.cc math_function.cu DEPS cblas device_context
eigen3
)
else
()
cc_library
(
math_function SRCS math_function.cc DEPS cblas device_context
)
cc_library
(
math_function SRCS math_function.cc DEPS cblas device_context
eigen3
)
endif
()
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
paddle/operators/math/math_function.cc
浏览文件 @
2f47f35b
...
...
@@ -110,12 +110,12 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
}
template
<
>
void
Set
<
typename
CPUPlace
,
typename
float
>
(
const
int
n
,
const
float
alpha
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
void
Set
<
platform
::
CPUPlace
,
float
>
(
const
int
n
,
const
float
alpha
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
auto
*
cpu_context
=
reinterpret_cast
<
platform
::
CPUDeviceContext
*>
(
context
);
framework
::
EigenVector
::
Type
<
T
>
out
(
output
,
n
);
out
.
device
(
*
(
cpu_context
->
eigen_device
()))
=
t
.
constant
(
T
(
alpha
));
framework
::
EigenVector
<
float
>::
Type
out
(
output
,
n
);
out
.
device
(
*
(
cpu_context
->
eigen_device
()))
=
out
.
constant
(
float
(
alpha
));
}
template
<
>
...
...
paddle/operators/math/math_function.cu
浏览文件 @
2f47f35b
...
...
@@ -127,12 +127,12 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
}
template
<
>
void
Set
<
typename
GPUPlace
,
typename
float
>
(
const
int
n
,
const
float
alpha
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
void
Set
<
platform
::
GPUPlace
,
float
>
(
const
int
n
,
const
float
alpha
,
float
*
output
,
platform
::
DeviceContext
*
context
)
{
auto
*
cuda_context
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
);
framework
::
EigenVector
::
Type
<
T
>
out
(
output
,
n
);
out
.
device
(
*
(
cuda_context
->
eigen_device
()))
=
t
.
constant
(
T
(
alpha
));
framework
::
EigenVector
<
float
>::
Type
out
(
output
,
n
);
out
.
device
(
*
(
cuda_context
->
eigen_device
()))
=
out
.
constant
(
float
(
alpha
));
}
template
<
typename
T
>
...
...
@@ -159,12 +159,13 @@ void RandUniform<platform::GPUPlace, float>(const int n, const float min,
template
<
typename
T
>
int
HandleOddLengthRandGaussian
(
const
int
n
,
const
T
mean
,
const
T
std
,
T
*
output
,
CUDADeviceContext
*
context
)
{
T
*
output
,
platform
::
CUDADeviceContext
*
context
)
{
if
(
n
%
2
==
1
)
{
std
::
default_random_engine
generator
;
std
::
normal_distribution
<
T
>
distribution
(
mean
,
std
);
const
T
random_value
=
distribution
(
generator
);
Set
<
T
,
platform
::
GPUPlace
>
(
1
,
random_value
,
output
+
(
n
-
1
),
context
);
Set
<
platform
::
GPUPlace
,
T
>
(
1
,
random_value
,
output
+
(
n
-
1
),
context
);
return
n
-
1
;
}
return
n
;
...
...
paddle/operators/math/math_function.h
浏览文件 @
2f47f35b
...
...
@@ -52,9 +52,9 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <cmath>
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/eigen.h"
#include "paddle/platform/enforce.h"
namespace
paddle
{
...
...
@@ -80,10 +80,7 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
template
<
typename
Place
,
typename
T
>
void
Set
(
const
int
n
,
const
T
alpha
,
T
*
output
,
platform
::
DeviceContext
*
context
)
{
framework
::
EigenVector
::
Type
<
T
>
out
(
output
,
n
);
out
.
device
(
*
(
context
->
eigen_device
()))
=
t
.
constant
(
T
(
alpha
));
}
platform
::
DeviceContext
*
context
);
template
<
typename
Place
,
typename
T
>
void
RandUniform
(
const
int
n
,
const
T
min
,
const
T
max
,
T
*
output
,
...
...
paddle/operators/uniform_random_op.cu
浏览文件 @
2f47f35b
...
...
@@ -14,9 +14,6 @@
#include "paddle/operators/uniform_random_op.h"
namespace
paddle
{
namespace
operators
{
REGISTER_OP_GPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
GPUUniformRandomKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
UniformRandomKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/platform/device_context.cc
浏览文件 @
2f47f35b
...
...
@@ -25,9 +25,9 @@ CPUDeviceContext::CPUDeviceContext() {
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
CPUDeviceContext
::
CPUDeviceContext
(
CPUPlace
place
,
int
rand_
seed
)
{
CPUDeviceContext
::
CPUDeviceContext
(
CPUPlace
place
,
int
seed
)
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
rand_seed_
=
rand_
seed
;
rand_seed_
=
seed
;
}
std
::
minstd_rand
&
CPUDeviceContext
::
rand_engine
()
{
...
...
@@ -105,7 +105,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
}
CUDADeviceContext
::
CUDADeviceContext
(
GPUPlace
place
,
uint64_t
seed
)
:
place_
(
place
),
seed_
(
seed
)
{
:
place_
(
place
),
rand_
seed_
(
seed
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
...
...
@@ -162,8 +162,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
curandCreateGenerator
(
&
curand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
rand_
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
...
...
paddle/platform/device_context.h
浏览文件 @
2f47f35b
...
...
@@ -40,7 +40,7 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
public:
CPUDeviceContext
();
explicit
CPUDeviceContext
(
CPUPlace
place
,
int
rand_
seed
=
0
);
explicit
CPUDeviceContext
(
CPUPlace
place
,
int
seed
=
0
);
virtual
~
CPUDeviceContext
()
{}
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
...
...
@@ -60,7 +60,7 @@ class EigenCudaStreamDevice;
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
GPUPlace
place
,
uint64_t
rand_
seed
=
0
);
explicit
CUDADeviceContext
(
GPUPlace
place
,
uint64_t
seed
=
0
);
virtual
~
CUDADeviceContext
();
/*! \brief Wait for all operations completion in the stream. */
...
...
@@ -93,12 +93,12 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
uint64_t
rand_seed_
;
std
::
unique_ptr
<
thrust
::
minstd_rand
>
rand_engine_
;
// clang-format off
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
curandGenerator_t
curand_generator_
{
nullptr
};
// clang-format on
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录