Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fa54fb33
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fa54fb33
编写于
8月 14, 2017
作者:
Q
QI JUN
提交者:
GitHub
8月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3209 from QiJune/port_blas
port gemm to new framework
上级
81f5f861
c41862d2
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
429 addition
and
8 deletion
+429
-8
paddle/framework/operator.h
paddle/framework/operator.h
+4
-0
paddle/framework/tensor.h
paddle/framework/tensor.h
+2
-0
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+2
-1
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+13
-0
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+114
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+127
-0
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+82
-0
paddle/operators/math/math_function_test.cc
paddle/operators/math/math_function_test.cc
+75
-0
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+1
-0
paddle/operators/mul_op.cu
paddle/operators/mul_op.cu
+0
-1
paddle/operators/mul_op.h
paddle/operators/mul_op.h
+3
-0
paddle/platform/dynload/cublas.h
paddle/platform/dynload/cublas.h
+6
-6
未找到文件。
paddle/framework/operator.h
浏览文件 @
fa54fb33
...
...
@@ -264,6 +264,10 @@ class ExecutionContext : public InferShapeContext {
platform
::
Place
GetPlace
()
const
{
return
device_context_
->
GetPlace
();
}
const
platform
::
DeviceContext
*
device_context
()
const
{
return
device_context_
;
}
const
platform
::
DeviceContext
*
device_context_
;
};
...
...
paddle/framework/tensor.h
浏览文件 @
fa54fb33
...
...
@@ -105,6 +105,8 @@ class Tensor {
template
<
typename
T
>
inline
Tensor
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
;
platform
::
Place
place
()
const
{
return
holder_
->
place
();
}
private:
template
<
typename
T
>
inline
void
check_memory_size
()
const
;
...
...
paddle/operators/CMakeLists.txt
浏览文件 @
fa54fb33
...
...
@@ -41,6 +41,7 @@ function(op_library TARGET)
endif
()
endfunction
()
add_subdirectory
(
math
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
cc_library
(
net_op SRCS net_op.cc DEPS op_registry
)
...
...
@@ -50,7 +51,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
op_library
(
mean_op SRCS mean_op.cc mean_op.cu
)
op_library
(
mul_op SRCS mul_op.cc mul_op.cu
)
op_library
(
mul_op SRCS mul_op.cc mul_op.cu
DEPS math_function
)
op_library
(
rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc
)
op_library
(
sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu
)
...
...
paddle/operators/math/CMakeLists.txt
0 → 100644
浏览文件 @
fa54fb33
if
(
WITH_MKLML
)
set
(
BLAS_LIB mklml
)
else
()
set
(
BLAS_LIB cblas
)
endif
()
if
(
WITH_GPU
)
nv_library
(
math_function SRCS math_function.cc math_function.cu DEPS
${
BLAS_LIB
}
device_context
)
else
()
cc_library
(
math_function SRCS math_function.cc DEPS
${
BLAS_LIB
}
device_context
)
endif
()
nv_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
paddle/operators/math/math_function.cc
0 → 100644
浏览文件 @
fa54fb33
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
>
void
gemm
<
platform
::
CPUPlace
,
float
>
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
,
platform
::
DeviceContext
*
context
)
{
int
lda
=
K
;
int
ldb
=
N
;
int
ldc
=
N
;
cblas_sgemm
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUPlace
,
double
>
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
,
platform
::
DeviceContext
*
context
)
{
int
lda
=
K
;
int
ldb
=
N
;
int
ldc
=
N
;
cblas_dgemm
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
matmul
<
platform
::
CPUPlace
,
float
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
,
platform
::
DeviceContext
*
context
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
matrix_a
.
place
())
&&
platform
::
is_cpu_place
(
matrix_b
.
place
())
&&
platform
::
is_cpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CPUPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUPlace
,
float
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
(),
context
);
}
template
<
>
void
matmul
<
platform
::
CPUPlace
,
double
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
,
platform
::
DeviceContext
*
context
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
matrix_a
.
place
())
&&
platform
::
is_cpu_place
(
matrix_b
.
place
())
&&
platform
::
is_cpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CPUPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
CPUPlace
,
double
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
(),
context
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.cu
0 → 100644
浏览文件 @
fa54fb33
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
>
void
gemm
<
platform
::
GPUPlace
,
float
>
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
,
platform
::
DeviceContext
*
context
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
)
->
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
GPUPlace
,
double
>
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
,
platform
::
DeviceContext
*
context
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
context
)
->
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
matmul
<
platform
::
GPUPlace
,
float
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
,
platform
::
DeviceContext
*
context
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
matrix_a
.
place
())
&&
platform
::
is_gpu_place
(
matrix_b
.
place
())
&&
platform
::
is_gpu_place
(
matrix_out
->
place
()),
"Matrix must all be in GPUPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
GPUPlace
,
float
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
(),
context
);
}
template
<
>
void
matmul
<
platform
::
GPUPlace
,
double
>
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
,
platform
::
DeviceContext
*
context
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
matrix_a
.
place
())
&&
platform
::
is_gpu_place
(
matrix_b
.
place
())
&&
platform
::
is_gpu_place
(
matrix_out
->
place
()),
"Matrix must all be in GPUPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
gemm
<
platform
::
GPUPlace
,
double
>
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
(),
context
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function.h
0 → 100644
浏览文件 @
fa54fb33
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern
"C"
{
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern
"C"
{
#include <cblas.h>
int
LAPACKE_sgetrf
(
int
matrix_layout
,
int
m
,
int
n
,
float
*
a
,
int
lda
,
int
*
ipiv
);
int
LAPACKE_dgetrf
(
int
matrix_layout
,
int
m
,
int
n
,
double
*
a
,
int
lda
,
int
*
ipiv
);
int
LAPACKE_sgetri
(
int
matrix_layout
,
int
n
,
float
*
a
,
int
lda
,
const
int
*
ipiv
);
int
LAPACKE_dgetri
(
int
matrix_layout
,
int
n
,
double
*
a
,
int
lda
,
const
int
*
ipiv
);
}
#endif
#include <cmath>
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
// Support continuous memory now
// If transA = N, and transB = N
// Then matrixA: M * K, matrixB: K * N matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template
<
typename
Place
,
typename
T
>
void
gemm
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
,
platform
::
DeviceContext
*
context
);
// matrix multiply with continuous memory
template
<
typename
Place
,
typename
T
>
void
matmul
(
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
T
alpha
,
framework
::
Tensor
*
matrix_out
,
T
beta
,
platform
::
DeviceContext
*
context
);
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/math_function_test.cc
0 → 100644
浏览文件 @
fa54fb33
#include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h"
#ifndef PADDLE_ONLY_CPU
TEST
(
math_function
,
notrans_mul_trans
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
out_gpu
;
paddle
::
framework
::
Tensor
out
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
DeviceContext
*
context
=
new
paddle
::
platform
::
CUDADeviceContext
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
out_gpu
.
mutable_data
<
float
>
({
2
,
2
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
,
context
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
);
float
*
out_ptr
=
out
.
data
<
float
>
();
EXPECT_EQ
(
out_ptr
[
0
],
5
);
EXPECT_EQ
(
out_ptr
[
1
],
14
);
EXPECT_EQ
(
out_ptr
[
2
],
14
);
EXPECT_EQ
(
out_ptr
[
3
],
50
);
}
TEST
(
math_function
,
trans_mul_notrans
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
out_gpu
;
paddle
::
framework
::
Tensor
out
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr
,
6
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
DeviceContext
*
context
=
new
paddle
::
platform
::
CUDADeviceContext
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
out_gpu
.
mutable_data
<
float
>
({
3
,
3
},
*
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
GPUPlace
,
float
>
(
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
,
context
);
out
.
CopyFrom
<
float
>
(
out_gpu
,
*
cpu_place
);
float
*
out_ptr
=
out
.
data
<
float
>
();
EXPECT_EQ
(
out_ptr
[
0
],
9
);
EXPECT_EQ
(
out_ptr
[
1
],
12
);
EXPECT_EQ
(
out_ptr
[
2
],
15
);
EXPECT_EQ
(
out_ptr
[
3
],
12
);
EXPECT_EQ
(
out_ptr
[
4
],
17
);
EXPECT_EQ
(
out_ptr
[
5
],
22
);
EXPECT_EQ
(
out_ptr
[
6
],
15
);
EXPECT_EQ
(
out_ptr
[
7
],
22
);
EXPECT_EQ
(
out_ptr
[
8
],
29
);
}
#endif
paddle/operators/mul_op.cc
浏览文件 @
fa54fb33
...
...
@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/operators/mul_op.cu
浏览文件 @
fa54fb33
...
...
@@ -16,5 +16,4 @@
#include "paddle/operators/mul_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
mul
,
ops
::
MulKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/mul_op.h
浏览文件 @
fa54fb33
...
...
@@ -13,6 +13,9 @@
limitations under the License. */
#pragma once
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
...
...
paddle/platform/dynload/cublas.h
浏览文件 @
fa54fb33
...
...
@@ -62,12 +62,12 @@ extern void *cublas_dso_handle;
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv
);
\
__macro(cublasDgemv
);
\
__macro(cublasSgemm
);
\
__macro(cublasDgemm
);
\
__macro(cublasSgeam
);
\
__macro(cublasDgeam
);
\
__macro(cublasSgemv
_v2);
\
__macro(cublasDgemv
_v2);
\
__macro(cublasSgemm
_v2);
\
__macro(cublasDgemm
_v2);
\
__macro(cublasSgeam
_v2);
\
__macro(cublasDgeam
_v2);
\
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录