Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
425279a5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
425279a5
编写于
9月 30, 2019
作者:
D
danleifeng
提交者:
gongweibao
9月 30, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve elementwise operators performance in same dimensions. (#19763)
Improve elementwise operators performance in same dimensions
上级
292aae43
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
794 addition
and
135 deletion
+794
-135
paddle/fluid/operators/elementwise/elementwise_add_op.cc
paddle/fluid/operators/elementwise/elementwise_add_op.cc
+28
-0
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+72
-1
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+24
-41
paddle/fluid/operators/elementwise/elementwise_div_op.cc
paddle/fluid/operators/elementwise/elementwise_div_op.cc
+28
-0
paddle/fluid/operators/elementwise/elementwise_div_op.cu
paddle/fluid/operators/elementwise/elementwise_div_op.cu
+78
-0
paddle/fluid/operators/elementwise/elementwise_div_op.h
paddle/fluid/operators/elementwise/elementwise_div_op.h
+56
-10
paddle/fluid/operators/elementwise/elementwise_mul_op.cc
paddle/fluid/operators/elementwise/elementwise_mul_op.cc
+28
-0
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+53
-35
paddle/fluid/operators/elementwise/elementwise_mul_op.h
paddle/fluid/operators/elementwise/elementwise_mul_op.h
+41
-34
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
.../fluid/operators/elementwise/elementwise_op_function.cu.h
+159
-0
paddle/fluid/operators/elementwise/elementwise_sub_op.cc
paddle/fluid/operators/elementwise/elementwise_sub_op.cc
+27
-0
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+75
-0
paddle/fluid/operators/elementwise/elementwise_sub_op.h
paddle/fluid/operators/elementwise/elementwise_sub_op.h
+57
-11
paddle/fluid/operators/math/blas.h
paddle/fluid/operators/math/blas.h
+16
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+48
-0
paddle/fluid/platform/dynload/mklml.h
paddle/fluid/platform/dynload/mklml.h
+4
-3
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.cc
浏览文件 @
425279a5
...
...
@@ -20,6 +20,34 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseAdd
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
blas
.
VADD
(
x
->
numel
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
}
};
template
<
typename
T
>
struct
SameDimsElemwiseAdd
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<!
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
eigen_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_y
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
y
);
auto
eigen_z
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
z
);
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>()
.
eigen_device
();
eigen_z
.
device
(
place
)
=
eigen_x
+
eigen_y
;
}
};
class
ElementwiseAddOpMaker
:
public
ElementwiseOpMaker
{
protected:
std
::
string
GetName
()
const
override
{
return
"Add"
;
}
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
425279a5
...
...
@@ -11,13 +11,84 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/float16.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
AddRangeFunctor
<
T
>
functor
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
x
->
numel
());
for_range
(
functor
);
}
};
template
<
>
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
(
(
size
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
reinterpret_cast
<
const
half
*>
(
y
->
data
<
platform
::
float16
>
());
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseAddCUDAKernel
<<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
}
};
template
<
typename
T
>
static
__global__
void
SimpleElemwiseAddGradCUDAKernel
(
const
T
*
dout
,
int64_t
size
,
T
*
dx
,
T
*
dy
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
col
<
size
)
{
dx
[
col
]
=
dout
[
col
];
dy
[
col
]
=
dout
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
SimpleElemwiseAddGradCUDAKernel
<
T
><<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
elementwise_add
,
ops
::
ElementwiseAddKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseAddKernel
<
plat
::
CUDADeviceContext
,
double
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
425279a5
...
...
@@ -11,22 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
AddFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_add
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
...
...
@@ -36,31 +29,12 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
AddFunctor
<
T
>
(),
z
);
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
&&
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_add_same_dims
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
blas
.
VADD
(
x
->
numel
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
!
std
::
is_floating_point
<
T
>::
value
||
!
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_add_same_dims
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
eigen_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_y
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
y
);
auto
eigen_z
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
z
);
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
eigen_z
.
device
(
place
)
=
eigen_x
+
eigen_y
;
}
template
<
typename
DeviceContext
,
typename
T
,
class
Enable
=
void
>
struct
SameDimsElemwiseAdd
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseAddKernel
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -69,12 +43,11 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dims_equal
=
x
->
dims
()
==
y
->
dims
();
if
(
dims_equal
)
{
elementwise_add_same_dims
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
SameDimsElemwiseAdd
<
DeviceContext
,
T
>
same_dims_add
;
same_dims_add
(
ctx
,
x
,
y
,
z
);
}
else
{
default_elementwise_add
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
...
...
@@ -112,7 +85,6 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
if
(
dx
)
{
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
...
...
@@ -126,8 +98,8 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
!
std
::
is_floating_point
<
T
>::
value
||
!
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
!
std
::
is_floating_point
<
T
>::
value
&&
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
...
...
@@ -136,6 +108,18 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
default_elementwise_add_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
#ifdef PADDLE_WITH_CUDA
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
elementwise_add_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseAddGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
...
...
@@ -151,8 +135,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
auto
*
out
=
dout
;
auto
*
x
=
dout
,
*
y
=
dout
;
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
())
&&
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_add_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
{
default_elementwise_add_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
...
...
paddle/fluid/operators/elementwise/elementwise_div_op.cc
浏览文件 @
425279a5
...
...
@@ -20,6 +20,34 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseDiv
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
blas
.
VDIV
(
x
->
numel
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
}
};
template
<
typename
T
>
struct
SameDimsElemwiseDiv
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<!
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
eigen_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_y
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
y
);
auto
eigen_z
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
z
);
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>()
.
eigen_device
();
eigen_z
.
device
(
place
)
=
eigen_x
/
eigen_y
;
}
};
class
ElementwiseDivOpMaker
:
public
ElementwiseOpMaker
{
protected:
std
::
string
GetName
()
const
override
{
return
"Div"
;
}
...
...
paddle/fluid/operators/elementwise/elementwise_div_op.cu
浏览文件 @
425279a5
...
...
@@ -12,9 +12,87 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/float16.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseDiv
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
DivRangeFunctor
<
T
>
functor
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
x
->
numel
());
for_range
(
functor
);
}
};
template
<
>
struct
SameDimsElemwiseDiv
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
(
(
size
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
reinterpret_cast
<
const
half
*>
(
y
->
data
<
platform
::
float16
>
());
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseDivCUDAKernel
<<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
}
};
template
<
typename
T
>
static
__global__
void
SimpleElemwiseDivGradCUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int64_t
size
,
T
*
dx
,
T
*
dy
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
col
<
size
)
{
T
o
=
dout
[
col
];
dx
[
col
]
=
o
/
y
[
col
];
dy
[
col
]
=
-
o
*
out
[
col
]
/
y
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
elementwise_div_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
SimpleElemwiseDivGradCUDAKernel
<
T
><<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
out
->
data
<
T
>
(),
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
elementwise_div
,
...
...
paddle/fluid/operators/elementwise/elementwise_div_op.h
浏览文件 @
425279a5
...
...
@@ -17,16 +17,29 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
DivFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
return
a
/
b
;
}
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_div
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseComputeEx
<
DivFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
DivFunctor
<
T
>
(),
z
);
}
template
<
typename
DeviceContext
,
typename
T
,
class
Enable
=
void
>
struct
SameDimsElemwiseDiv
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
);
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -36,11 +49,15 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseComputeEx
<
DivFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
DivFunctor
<
T
>
(),
z
);
auto
dims_equal
=
x
->
dims
()
==
y
->
dims
();
if
(
dims_equal
)
{
SameDimsElemwiseDiv
<
DeviceContext
,
T
>
same_dims_div
;
same_dims_div
(
ctx
,
x
,
y
,
z
);
}
else
{
default_elementwise_div
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
}
};
...
...
@@ -63,6 +80,31 @@ struct DivDoubleDY {
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_div_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseGradCompute
<
DeviceContext
,
T
,
DivGradDX
<
T
>
,
DivGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
DivGradDX
<
T
>
(),
DivGradDY
<
T
>
());
}
#ifdef PADDLE_WITH_CUDA
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
elementwise_div_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseDivGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
...
...
@@ -76,11 +118,15 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
*
x
=
dout
;
// Fake x, not used
ElemwiseGradCompute
<
DeviceContext
,
T
,
DivGradDX
<
T
>
,
DivGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
DivGradDX
<
T
>
(),
DivGradDY
<
T
>
());
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_div_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
{
ElemwiseGradCompute
<
DeviceContext
,
T
,
DivGradDX
<
T
>
,
DivGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
DivGradDX
<
T
>
(),
DivGradDY
<
T
>
());
}
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cc
浏览文件 @
425279a5
...
...
@@ -20,6 +20,34 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseMul
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
blas
.
VMUL
(
x
->
numel
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
}
};
template
<
typename
T
>
struct
SameDimsElemwiseMul
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<!
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
eigen_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_y
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
y
);
auto
eigen_z
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
z
);
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>()
.
eigen_device
();
eigen_z
.
device
(
place
)
=
eigen_x
*
eigen_y
;
}
};
class
ElementwiseMulOpMaker
:
public
ElementwiseOpMaker
{
protected:
std
::
string
GetName
()
const
override
{
return
"Mul"
;
}
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
浏览文件 @
425279a5
...
...
@@ -13,15 +13,49 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/float16.h"
#define TILE_SIZE 512
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseMul
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
MulRangeFunctor
<
T
>
functor
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
x
->
numel
());
for_range
(
functor
);
}
};
template
<
>
struct
SameDimsElemwiseMul
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
(
(
size
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
reinterpret_cast
<
const
half
*>
(
y
->
data
<
platform
::
float16
>
());
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseMulCUDAKernel
<<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
}
};
template
<
typename
T
>
static
__global__
void
SimpleElemwiseMulGradCUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
...
...
@@ -38,40 +72,24 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
}
}
template
<
typename
T
>
class
ElementwiseMulGradKernel
<
plat
::
CUDADeviceContext
,
T
>
:
public
ElemwiseGradKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
ElemwiseGradKernel
<
T
>::
Compute
(
ctx
);
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
out
=
dout
;
// out is not necessary
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
x
->
dims
()
==
y
->
dims
()
&&
dx
&&
dy
)
{
dim3
block_size
=
dim3
(
TILE_SIZE
,
1
);
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
((
size
+
TILE_SIZE
-
1
)
/
TILE_SIZE
,
1
);
SimpleElemwiseMulGradCUDAKernel
<
T
><<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
out
->
data
<
T
>
(),
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
return
;
}
else
{
ElemwiseGradCompute
<
plat
::
CUDADeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
elementwise_mul_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
SimpleElemwiseMulGradCUDAKernel
<
T
><<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
out
->
data
<
T
>
(),
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.h
浏览文件 @
425279a5
...
...
@@ -14,17 +14,13 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
MulFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
return
a
*
b
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_mul
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
...
...
@@ -33,32 +29,12 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
ElementwiseComputeEx
<
MulFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
MulFunctor
<
T
>
(),
z
);
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
&&
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_mul_same_dims
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
blas
.
VMUL
(
x
->
numel
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
!
std
::
is_floating_point
<
T
>::
value
||
!
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_mul_same_dims
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
eigen_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_y
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
y
);
auto
eigen_z
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
z
);
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
eigen_z
.
device
(
place
)
=
eigen_x
*
eigen_y
;
}
template
<
typename
DeviceContext
,
typename
T
,
class
Enable
=
void
>
struct
SameDimsElemwiseMul
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMulKernel
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -92,7 +68,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
x
.
numel
()
==
y
->
numel
())
{
elementwise_mul_same_dims
<
DeviceContext
,
T
>
(
ctx
,
&
x
,
y
,
z
);
SameDimsElemwiseMul
<
DeviceContext
,
T
>
same_dims_mul
;
same_dims_mul
(
ctx
,
&
x
,
y
,
z
);
}
else
{
default_elementwise_mul
<
DeviceContext
,
T
>
(
ctx
,
&
x
,
y
,
z
);
}
...
...
@@ -109,6 +86,31 @@ struct MulGradDY {
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
x
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_mul_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseGradCompute
<
DeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
}
#ifdef PADDLE_WITH_CUDA
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
elementwise_mul_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMulGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
...
...
@@ -123,8 +125,13 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseGradCompute
<
DeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_mul_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
{
ElemwiseGradCompute
<
DeviceContext
,
T
,
MulGradDX
<
T
>
,
MulGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
MulGradDX
<
T
>
(),
MulGradDY
<
T
>
());
}
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
0 → 100644
浏览文件 @
425279a5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#define PADDLE_CUDA_THREAD_SIZE 512
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif // PADDLE_WITH_CUDA
#ifdef PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif
#if CUDA_VERSION < 9000
#define __h2div h2div
#endif
namespace
paddle
{
namespace
operators
{
#define DEFINE_SIMPLE_BINARY_FUNCTOR(Func, expr) \
template <typename T> \
struct Func##Functor { \
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return a expr b; \
} \
};
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Add
,
+
)
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Sub
,
-
)
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Mul
,
*
)
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Div
,
/
)
#undef DEFINE_SIMPLE_BINARY_FUNCTOR
#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Func, expr) \
template <typename T> \
struct Func##RangeFunctor { \
Func##RangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} \
inline HOSTDEVICE void operator()(size_t id) const { \
z_[id] = x_[id] expr y_[id]; \
} \
const T* x_; \
const T* y_; \
T* z_; \
};
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Add
,
+
)
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Sub
,
-
)
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Mul
,
*
)
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Div
,
/
)
#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
#ifdef PADDLE_CUDA_FP16
inline
DEVICE
half2
half2_add
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hadd2
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
+
b1
;
float
r2
=
a2
+
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
inline
DEVICE
half2
half2_sub
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hsub2
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
-
b1
;
float
r2
=
a2
-
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
inline
DEVICE
half2
half2_mul
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hmul2
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
*
b1
;
float
r2
=
a2
*
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
inline
DEVICE
half2
half2_div
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__h2div
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
/
b1
;
float
r2
=
a2
/
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \
template <typename T> \
__global__ void SameDimsElemwise##Func##CUDAKernel(const T* x, const T* y, \
T* z, int64_t size) { \
int col = blockIdx.x * blockDim.x + threadIdx.x; \
while (col < size) { \
z[col] = x[col] expr y[col]; \
col += blockDim.x * gridDim.x; \
} \
} \
template <> \
inline __global__ void SameDimsElemwise##Func##CUDAKernel<half>( \
const half* x, const half* y, half* z, int64_t size) { \
int start = threadIdx.x + blockDim.x * blockIdx.x; \
int stride = blockDim.x * gridDim.x; \
int n2 = size / 2; \
const half2* x2 = reinterpret_cast<const half2*>(x); \
const half2* y2 = reinterpret_cast<const half2*>(y); \
half2* z2 = reinterpret_cast<half2*>(z); \
for (int i = start; i < n2; i += stride) { \
z2[i] = FP16Function(x2[i], y2[i]); \
} \
if (start == 0 && (size % 2)) { \
z[size - 1] = __float2half(__half2float(x[size - 1]) \
expr __half2float(y[size - 1])); \
} \
}
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Add
,
+
,
half2_add
)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Sub
,
-
,
half2_sub
)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Mul
,
*
,
half2_mul
)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Div
,
/
,
half2_div
)
#undef DEFINE_SIMPLE_CUDA_BINARY_KERNEL
#endif // PADDLE_CUDA_FP16
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_sub_op.cc
浏览文件 @
425279a5
...
...
@@ -20,6 +20,33 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseSub
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
blas
.
VSUB
(
x
->
numel
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
}
};
template
<
typename
T
>
struct
SameDimsElemwiseSub
<
platform
::
CPUDeviceContext
,
T
,
typename
std
::
enable_if
<!
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
eigen_x
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
x
);
auto
eigen_y
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
y
);
auto
eigen_z
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
z
);
auto
&
place
=
*
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>()
.
eigen_device
();
eigen_z
.
device
(
place
)
=
eigen_x
-
eigen_y
;
}
};
class
ElementwiseSubOpMaker
:
public
ElementwiseOpMaker
{
protected:
std
::
string
GetName
()
const
override
{
return
"Sub"
;
}
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
425279a5
...
...
@@ -11,10 +11,85 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/float16.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseSub
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
SubRangeFunctor
<
T
>
functor
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
x
->
numel
());
for_range
(
functor
);
}
};
template
<
>
struct
SameDimsElemwiseSub
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
(
(
size
/
2
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
reinterpret_cast
<
const
half
*>
(
y
->
data
<
platform
::
float16
>
());
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseSubCUDAKernel
<<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
}
};
template
<
typename
T
>
static
__global__
void
SimpleElemwiseSubGradCUDAKernel
(
const
T
*
dout
,
int64_t
size
,
T
*
dx
,
T
*
dy
)
{
int
col
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
col
<
size
)
{
dx
[
col
]
=
dout
[
col
];
dy
[
col
]
=
-
dout
[
col
];
col
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
plat
::
CUDADeviceContext
>::
value
>::
type
elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
auto
size
=
x
->
numel
();
dim3
gird_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
gird_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
dout
->
data
<
T
>
(),
size
,
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
elementwise_sub
,
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.h
浏览文件 @
425279a5
...
...
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -14,14 +14,27 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
SubFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
return
a
-
b
;
}
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_sub
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
SubFunctor
<
T
>
(),
z
);
}
template
<
typename
DeviceContext
,
typename
T
,
class
Enable
=
void
>
struct
SameDimsElemwiseSub
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
);
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -31,11 +44,15 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
SubFunctor
<
T
>
(),
z
);
auto
dims_equal
=
x
->
dims
()
==
y
->
dims
();
if
(
dims_equal
)
{
SameDimsElemwiseSub
<
DeviceContext
,
T
>
same_dims_sub
;
same_dims_sub
(
ctx
,
x
,
y
,
z
);
}
else
{
default_elementwise_sub
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
}
};
...
...
@@ -49,6 +66,31 @@ struct SubGradDY {
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
-
dout
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CPUDeviceContext
>::
value
>::
type
elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
}
#ifdef PADDLE_WITH_CUDA
// cuda definition
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
DeviceContext
,
platform
::
CUDADeviceContext
>::
value
>::
type
elementwise_sub_grad
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
);
#endif
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseSubGradKernel
:
public
ElemwiseGradKernel
<
T
>
{
public:
...
...
@@ -63,9 +105,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
// skip out, x, y
auto
*
out
=
dout
;
auto
*
x
=
dout
,
*
y
=
dout
;
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
(
dx
->
dims
()
==
dy
->
dims
()))
{
elementwise_sub_grad
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}
else
{
ElemwiseExplicitGradCompute
<
DeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
ctx
,
*
x
,
*
y
,
*
out
,
*
dout
,
axis
,
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
}
}
};
...
...
paddle/fluid/operators/math/blas.h
浏览文件 @
425279a5
...
...
@@ -159,9 +159,15 @@ class Blas {
template
<
typename
T
>
void
VADD
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
;
template
<
typename
T
>
void
VSUB
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
;
template
<
typename
T
>
void
VMUL
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
;
template
<
typename
T
>
void
VDIV
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
;
template
<
typename
T
>
void
VCOPY
(
int
n
,
const
T
*
x
,
T
*
y
)
const
;
...
...
@@ -275,11 +281,21 @@ class BlasT : private Blas<DeviceContext> {
Base
()
->
template
VADD
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
VSUB
(
ARGS
...
args
)
const
{
Base
()
->
template
VSUB
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
VMUL
(
ARGS
...
args
)
const
{
Base
()
->
template
VMUL
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
VDIV
(
ARGS
...
args
)
const
{
Base
()
->
template
VDIV
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
VCOPY
(
ARGS
...
args
)
const
{
Base
()
->
template
VCOPY
<
T
>(
args
...);
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
425279a5
...
...
@@ -99,11 +99,21 @@ struct CBlas<float> {
platform
::
dynload
::
vsAdd
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VSUB
(
ARGS
...
args
)
{
platform
::
dynload
::
vsSub
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VMUL
(
ARGS
...
args
)
{
platform
::
dynload
::
vsMul
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VDIV
(
ARGS
...
args
)
{
platform
::
dynload
::
vsDiv
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VEXP
(
ARGS
...
args
)
{
platform
::
dynload
::
vsExp
(
args
...);
...
...
@@ -210,11 +220,21 @@ struct CBlas<double> {
platform
::
dynload
::
vdAdd
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VSUB
(
ARGS
...
args
)
{
platform
::
dynload
::
vdSub
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VMUL
(
ARGS
...
args
)
{
platform
::
dynload
::
vdMul
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VDIV
(
ARGS
...
args
)
{
platform
::
dynload
::
vdDiv
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
VEXP
(
ARGS
...
args
)
{
platform
::
dynload
::
vdExp
(
args
...);
...
...
@@ -443,6 +463,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
#endif
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
VSUB
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
{
#ifdef PADDLE_WITH_MKLML
CBlas
<
T
>::
VSUB
(
n
,
x
,
y
,
z
);
#else
// try to find if openblas support vsub
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
-
y
[
i
];
}
#endif
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
VMUL
(
int
n
,
const
T
*
x
,
const
T
*
y
,
...
...
@@ -457,6 +491,20 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
#endif
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
VDIV
(
int
n
,
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
{
#ifdef PADDLE_WITH_MKLML
CBlas
<
T
>::
VDIV
(
n
,
x
,
y
,
z
);
#else
// try to find if openblas support vdiv
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
/
y
[
i
];
}
#endif
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
VEXP
(
int
n
,
const
T
*
x
,
T
*
y
)
const
{
...
...
paddle/fluid/platform/dynload/mklml.h
浏览文件 @
425279a5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -76,8 +73,12 @@ extern void* mklml_dso_handle;
__macro(cblas_dscal); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsSub); \
__macro(vdSub); \
__macro(vsMul); \
__macro(vdMul); \
__macro(vsDiv); \
__macro(vdDiv); \
__macro(vsExp); \
__macro(vdExp); \
__macro(vsSqr); \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录