Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2367cca6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2367cca6
编写于
9月 15, 2021
作者:
Y
Yiqun Liu
提交者:
GitHub
9月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify the functor definition of elementwise add, sub, mul, div, floordiv, max, min. (#35684)
上级
39dcfc6c
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
149 addition
and
368 deletion
+149
-368
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+3
-3
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+1
-16
paddle/fluid/operators/elementwise/elementwise_div_op.cu
paddle/fluid/operators/elementwise/elementwise_div_op.cu
+4
-21
paddle/fluid/operators/elementwise/elementwise_div_op.h
paddle/fluid/operators/elementwise/elementwise_div_op.h
+0
-1
paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu
...le/fluid/operators/elementwise/elementwise_floordiv_op.cu
+5
-14
paddle/fluid/operators/elementwise/elementwise_floordiv_op.h
paddle/fluid/operators/elementwise/elementwise_floordiv_op.h
+1
-42
paddle/fluid/operators/elementwise/elementwise_functor.h
paddle/fluid/operators/elementwise/elementwise_functor.h
+117
-0
paddle/fluid/operators/elementwise/elementwise_max_op.cu
paddle/fluid/operators/elementwise/elementwise_max_op.cu
+4
-10
paddle/fluid/operators/elementwise/elementwise_max_op.h
paddle/fluid/operators/elementwise/elementwise_max_op.h
+1
-5
paddle/fluid/operators/elementwise/elementwise_min_op.cu
paddle/fluid/operators/elementwise/elementwise_min_op.cu
+4
-10
paddle/fluid/operators/elementwise/elementwise_min_op.h
paddle/fluid/operators/elementwise/elementwise_min_op.h
+1
-5
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+2
-3
paddle/fluid/operators/elementwise/elementwise_mul_op.h
paddle/fluid/operators/elementwise/elementwise_mul_op.h
+1
-1
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
.../fluid/operators/elementwise/elementwise_op_function.cu.h
+0
-231
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+1
-1
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+3
-2
paddle/fluid/operators/elementwise/elementwise_sub_op.h
paddle/fluid/operators/elementwise/elementwise_sub_op.h
+1
-1
paddle/fluid/operators/layer_norm_op.h
paddle/fluid/operators/layer_norm_op.h
+0
-1
paddle/fluid/operators/svd_helper.h
paddle/fluid/operators/svd_helper.h
+0
-1
未找到文件。
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
2367cca6
...
@@ -147,10 +147,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
...
@@ -147,10 +147,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
}
else
if
(
dx_data
!=
dout_data
&&
dy_data
!=
dout_data
)
{
}
else
if
(
dx_data
!=
dout_data
&&
dy_data
!=
dout_data
)
{
auto
size
=
x
->
numel
();
auto
size
=
x
->
numel
();
int
vec_size
=
max
(
static_cast
<
int
>
(
sizeof
(
float4
)
/
sizeof
(
T
)),
1
);
int
vec_size
=
max
(
static_cast
<
int
>
(
sizeof
(
float4
)
/
sizeof
(
T
)),
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
block_size
=
dim3
(
ELEMENTWISE_BLOCK
_SIZE
,
1
);
dim3
grid_size
=
dim3
grid_size
=
dim3
(((
size
+
vec_size
-
1
)
/
vec_size
+
PADDLE_CUDA_THREAD
_SIZE
-
1
)
/
dim3
(((
size
+
vec_size
-
1
)
/
vec_size
+
ELEMENTWISE_BLOCK
_SIZE
-
1
)
/
PADDLE_CUDA_THREAD
_SIZE
,
ELEMENTWISE_BLOCK
_SIZE
,
1
);
1
);
SimpleElemwiseAddGradCUDAKernel
<
SimpleElemwiseAddGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
T
><<<
grid_size
,
block_size
,
0
,
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
2367cca6
...
@@ -11,30 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,30 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <algorithm>
#include <algorithm>
#include <utility>
#include <utility>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef __NVCC__
#include <cuda.h>
#include <cuda_fp16.h>
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/operators/elementwise/elementwise_div_op.cu
浏览文件 @
2367cca6
...
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
...
@@ -22,24 +23,6 @@ namespace plat = paddle::platform;
...
@@ -22,24 +23,6 @@ namespace plat = paddle::platform;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
,
typename
Enable
=
void
>
struct
CudaDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
/
args
[
1
];
}
};
template
<
typename
T
>
struct
CudaDivFunctor
<
T
,
typename
std
::
enable_if_t
<
std
::
is_integral
<
T
>::
value
>>
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
PADDLE_ENFORCE
(
args
[
1
]
!=
0
,
"Invalid Argument Error: Integer division by zero "
"encountered in divide. Please check the input value."
);
return
args
[
0
]
/
args
[
1
];
}
};
template
<
typename
T
>
template
<
typename
T
>
class
ElementwiseDivKernel
<
platform
::
CUDADeviceContext
,
T
>
class
ElementwiseDivKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
:
public
framework
::
OpKernel
<
T
>
{
...
@@ -52,7 +35,7 @@ class ElementwiseDivKernel<platform::CUDADeviceContext, T>
...
@@ -52,7 +35,7 @@ class ElementwiseDivKernel<platform::CUDADeviceContext, T>
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
Cuda
DivFunctor
<
T
>
());
cuda_ctx
,
ins
,
&
outs
,
axis
,
DivFunctor
<
T
>
());
}
}
};
};
...
@@ -124,10 +107,10 @@ elementwise_div_grad(const framework::ExecutionContext& ctx,
...
@@ -124,10 +107,10 @@ elementwise_div_grad(const framework::ExecutionContext& ctx,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
block_size
=
dim3
(
ELEMENTWISE_BLOCK
_SIZE
,
1
);
auto
size
=
x
->
numel
();
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
grid_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
((
size
+
ELEMENTWISE_BLOCK_SIZE
-
1
)
/
ELEMENTWISE_BLOCK
_SIZE
,
1
);
SimpleElemwiseDivGradCUDAKernel
<
SimpleElemwiseDivGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
...
...
paddle/fluid/operators/elementwise/elementwise_div_op.h
浏览文件 @
2367cca6
...
@@ -18,7 +18,6 @@ limitations under the License. */
...
@@ -18,7 +18,6 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
...
...
paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu
浏览文件 @
2367cca6
...
@@ -11,25 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,25 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
CudaFloorDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
argv
[])
const
{
PADDLE_ENFORCE
(
argv
[
1
]
!=
0
,
"InvalidArgument: divide by zero "
"encountered in floor-divide ops, please check.
\n
"
);
return
static_cast
<
T
>
(
std
::
trunc
(
argv
[
0
]
/
argv
[
1
]));
}
};
template
<
typename
T
>
template
<
typename
T
>
class
ElementwiseFloorDivKernel
<
platform
::
CUDADeviceContext
,
T
>
class
ElementwiseFloorDivKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
:
public
framework
::
OpKernel
<
T
>
{
...
@@ -42,13 +30,16 @@ class ElementwiseFloorDivKernel<platform::CUDADeviceContext, T>
...
@@ -42,13 +30,16 @@ class ElementwiseFloorDivKernel<platform::CUDADeviceContext, T>
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
Cuda
FloorDivFunctor
<
T
>
());
cuda_ctx
,
ins
,
&
outs
,
axis
,
FloorDivFunctor
<
T
>
());
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
elementwise_floordiv
,
elementwise_floordiv
,
ops
::
ElementwiseFloorDivKernel
<
plat
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseFloorDivKernel
<
plat
::
CUDADeviceContext
,
int
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_floordiv_op.h
浏览文件 @
2367cca6
...
@@ -15,54 +15,13 @@ limitations under the License. */
...
@@ -15,54 +15,13 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
FloorDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
if
(
b
==
0
)
{
printf
(
"Error: Divide by zero encounter in floor_divide
\n
"
);
#ifdef __HIPCC__
abort
();
#else
asm
(
"trap;"
);
#endif
}
#else
if
(
b
==
0
)
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Divide by zero encounter in floor_divide"
));
#endif
return
static_cast
<
T
>
(
std
::
trunc
(
a
/
b
));
}
};
template
<
typename
T
>
struct
InverseFloorDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
if
(
a
==
0
)
{
printf
(
"Error: Divide by zero encounter in floor_divide
\n
"
);
#ifdef __HIPCC__
abort
();
#else
asm
(
"trap;"
);
#endif
}
#else
if
(
a
==
0
)
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Divide by zero encounter in floor_divide"
));
#endif
return
static_cast
<
T
>
(
std
::
trunc
(
b
/
a
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
elementwise_floor_div
(
const
framework
::
ExecutionContext
&
ctx
,
void
elementwise_floor_div
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
x
,
...
...
paddle/fluid/operators/elementwise/elementwise_functor.h
0 → 100644
浏览文件 @
2367cca6
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
// Define the binary functors used in elementwise ops.
// Add
template
<
typename
T
>
struct
AddFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
InverseAddFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
+
a
;
}
};
// Subtract
template
<
typename
T
>
struct
SubFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
-
b
;
}
};
template
<
typename
T
>
struct
InverseSubFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
-
a
;
}
};
// Multiply
template
<
typename
T
>
struct
MulFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
*
b
;
}
};
template
<
typename
T
>
struct
InverseMulFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
*
a
;
}
};
// Divide
#define DIV_ERROR_INFO \
"InvalidArgumentError: Integer division by zero encountered in " \
"(floor) divide. Please check the input value."
template
<
typename
T
,
typename
Enable
=
void
>
struct
DivFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
/
b
;
}
};
template
<
typename
T
>
struct
DivFunctor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
// For int32/int64, need to check whether the divison is zero.
PADDLE_ENFORCE
(
b
!=
0
,
DIV_ERROR_INFO
);
return
a
/
b
;
}
};
template
<
typename
T
,
typename
Enable
=
void
>
struct
InverseDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
/
a
;
}
};
// Floor Divide
template
<
typename
T
>
struct
FloorDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
PADDLE_ENFORCE
(
b
!=
0
,
DIV_ERROR_INFO
);
return
static_cast
<
T
>
(
std
::
trunc
(
a
/
b
));
}
};
template
<
typename
T
>
struct
InverseFloorDivFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
PADDLE_ENFORCE
(
a
!=
0
,
DIV_ERROR_INFO
);
return
static_cast
<
T
>
(
std
::
trunc
(
b
/
a
));
}
};
#undef DIV_ERROR_INFO
// Maximum
template
<
typename
T
>
struct
MaxFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
// Minmum
template
<
typename
T
>
struct
MinFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
<
b
?
a
:
b
;
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_max_op.cu
浏览文件 @
2367cca6
...
@@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace
ops
=
paddle
::
operators
;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
CudaMaxFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
(
args
[
0
]
>
args
[
1
]
?
args
[
0
]
:
args
[
1
]);
}
};
template
<
typename
T
>
template
<
typename
T
>
class
ElementwiseMaxKernel
<
platform
::
CUDADeviceContext
,
T
>
class
ElementwiseMaxKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
:
public
framework
::
OpKernel
<
T
>
{
...
@@ -38,13 +30,15 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
...
@@ -38,13 +30,15 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
Cuda
MaxFunctor
<
T
>
());
cuda_ctx
,
ins
,
&
outs
,
axis
,
MaxFunctor
<
T
>
());
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
elementwise_max
,
elementwise_max
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_max_op.h
浏览文件 @
2367cca6
...
@@ -14,17 +14,13 @@ limitations under the License. */
...
@@ -14,17 +14,13 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
MaxFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ElementwiseMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
...
paddle/fluid/operators/elementwise/elementwise_min_op.cu
浏览文件 @
2367cca6
...
@@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace
ops
=
paddle
::
operators
;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
CudaMinFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
(
args
[
0
]
>
args
[
1
]
?
args
[
1
]
:
args
[
0
]);
}
};
template
<
typename
T
>
template
<
typename
T
>
class
ElementwiseMinKernel
<
platform
::
CUDADeviceContext
,
T
>
class
ElementwiseMinKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
:
public
framework
::
OpKernel
<
T
>
{
...
@@ -38,13 +30,15 @@ class ElementwiseMinKernel<platform::CUDADeviceContext, T>
...
@@ -38,13 +30,15 @@ class ElementwiseMinKernel<platform::CUDADeviceContext, T>
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
Cuda
MinFunctor
<
T
>
());
cuda_ctx
,
ins
,
&
outs
,
axis
,
MinFunctor
<
T
>
());
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
elementwise_min
,
elementwise_min
,
ops
::
ElementwiseMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
ElementwiseMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_min_op.h
浏览文件 @
2367cca6
...
@@ -14,17 +14,13 @@ limitations under the License. */
...
@@ -14,17 +14,13 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
MinFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
a
:
b
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseMinKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ElementwiseMinKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
浏览文件 @
2367cca6
...
@@ -14,7 +14,6 @@ limitations under the License. */
...
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
...
@@ -95,10 +94,10 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
...
@@ -95,10 +94,10 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
block_size
=
dim3
(
ELEMENTWISE_BLOCK
_SIZE
,
1
);
auto
size
=
x
->
numel
();
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
grid_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
((
size
+
ELEMENTWISE_BLOCK_SIZE
-
1
)
/
ELEMENTWISE_BLOCK
_SIZE
,
1
);
SimpleElemwiseMulGradCUDAKernel
<
SimpleElemwiseMulGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.h
浏览文件 @
2367cca6
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
...
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h
已删除
100644 → 0
浏览文件 @
39dcfc6c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#ifdef __HIPCC__
#define PADDLE_CUDA_THREAD_SIZE 256
#else
#define PADDLE_CUDA_THREAD_SIZE 512
#endif
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#ifdef PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif
#endif // PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#ifdef PADDLE_CUDA_FP16
#include <hip/hip_fp16.h>
#endif
#endif // PADDLE_WITH_HIP
#define DIV_ERROR_INFO \
"InvalidArgumentError: Integer division by zero encountered in divide. " \
"Please check.\n"
namespace
paddle
{
namespace
operators
{
#define DEFINE_SIMPLE_BINARY_FUNCTOR(Func, expr) \
template <typename T, class Enable = void> \
struct Func##Functor { \
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return a expr b; \
} \
}; \
template <typename T, class Enable = void> \
struct Inverse##Func##Functor { \
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return b expr a; \
} \
};
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Add
,
+
)
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Sub
,
-
)
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Mul
,
*
)
DEFINE_SIMPLE_BINARY_FUNCTOR
(
Div
,
/
)
#undef DEFINE_SIMPLE_BINARY_FUNCTOR
// special div functor for int32/int64. check divison has a zero
template
<
typename
T
>
struct
DivFunctor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
PADDLE_ENFORCE
(
b
!=
0
,
DIV_ERROR_INFO
);
return
a
/
b
;
}
};
#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Func, expr) \
template <typename T, class Enable = void> \
struct Func##RangeFunctor { \
Func##RangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} \
inline HOSTDEVICE void operator()(size_t id) const { \
z_[id] = x_[id] expr y_[id]; \
} \
const T* x_; \
const T* y_; \
T* z_; \
};
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Add
,
+
)
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Sub
,
-
)
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Mul
,
*
)
DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
(
Div
,
/
)
#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR
// special div functor for int32/int64. check divison has a zero
template
<
typename
T
>
struct
DivRangeFunctor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
DivRangeFunctor
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
:
x_
(
x
),
y_
(
y
),
z_
(
z
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
id
)
const
{
PADDLE_ENFORCE
(
y_
[
id
]
!=
0
,
DIV_ERROR_INFO
);
z_
[
id
]
=
x_
[
id
]
/
y_
[
id
];
}
const
T
*
x_
;
const
T
*
y_
;
T
*
z_
;
};
#ifdef PADDLE_CUDA_FP16
inline
DEVICE
half2
half2_add
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
return
__hadd2
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
+
b1
;
float
r2
=
a2
+
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
inline
DEVICE
half2
half2_sub
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
return
__hsub2
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
-
b1
;
float
r2
=
a2
-
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
inline
DEVICE
half2
half2_mul
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
return
__hmul2
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
*
b1
;
float
r2
=
a2
*
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
inline
DEVICE
half2
half2_div
(
const
half2
&
a
,
const
half2
&
b
)
{
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
return
__h2div
(
a
,
b
);
#else
float
a1
=
__low2float
(
a
);
float
a2
=
__high2float
(
a
);
float
b1
=
__low2float
(
b
);
float
b2
=
__high2float
(
b
);
float
r1
=
a1
/
b1
;
float
r2
=
a2
/
b2
;
return
__floats2half2_rn
(
r1
,
r2
);
#endif
}
#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \
inline __global__ void SameDimsElemwise##Func##CUDAKernel( \
const float* __restrict__ x, const float* __restrict__ y, float* z, \
int64_t size) { \
int tid = blockIdx.x * blockDim.x + threadIdx.x; \
int stride = gridDim.x * blockDim.x; \
int loop = size / 4; \
int remainder = size % 4; \
const float4* x_vec = reinterpret_cast<const float4*>(x); \
const float4* y_vec = reinterpret_cast<const float4*>(y); \
float4* z_vec = reinterpret_cast<float4*>(z); \
float4 x_f4, y_f4; \
for (int i = tid; i < loop; i += stride) { \
x_f4 = x_vec[i]; \
y_f4 = y_vec[i]; \
z_vec[i] = make_float4(x_f4.x expr y_f4.x, x_f4.y expr y_f4.y, \
x_f4.z expr y_f4.z, x_f4.w expr y_f4.w); \
} \
if (tid == loop && remainder != 0) { \
while (remainder) { \
int idx = size - remainder; \
remainder--; \
z[idx] = x[idx] expr y[idx]; \
} \
} \
} \
inline __global__ void SameDimsElemwise##Func##CUDAKernel( \
const half* __restrict__ x, const half* __restrict__ y, half* z, \
int64_t size) { \
int tid = blockIdx.x * blockDim.x + threadIdx.x; \
int stride = gridDim.x * blockDim.x; \
int loop = size / 8; \
int remainder = size % 8; \
const float4* x_vec = reinterpret_cast<const float4*>(x); \
const float4* y_vec = reinterpret_cast<const float4*>(y); \
float4* z_vec = reinterpret_cast<float4*>(z); \
float4 x_h8, y_h8, z_h8; \
for (int i = tid; i < loop; i += stride) { \
x_h8 = x_vec[i]; \
y_h8 = y_vec[i]; \
half2* x_h2 = reinterpret_cast<half2*>(&x_h8); \
half2* y_h2 = reinterpret_cast<half2*>(&y_h8); \
half2* z_h2 = reinterpret_cast<half2*>(&z_h8); \
z_h2[0] = FP16Function(x_h2[0], y_h2[0]); \
z_h2[1] = FP16Function(x_h2[1], y_h2[1]); \
z_h2[2] = FP16Function(x_h2[2], y_h2[2]); \
z_h2[3] = FP16Function(x_h2[3], y_h2[3]); \
z_vec[i] = z_h8; \
} \
if (tid == loop && remainder != 0) { \
while (remainder) { \
int idx = size - remainder; \
remainder--; \
z[idx] = __float2half(__half2float(x[idx]) expr __half2float(y[idx])); \
} \
} \
}
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Add
,
+
,
half2_add
)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Sub
,
-
,
half2_sub
)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Mul
,
*
,
half2_mul
)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL
(
Div
,
/
,
half2_div
)
#undef DEFINE_SIMPLE_CUDA_BINARY_KERNEL
#endif // PADDLE_CUDA_FP16
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
2367cca6
...
@@ -25,7 +25,7 @@ limitations under the License. */
...
@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/elementwise/elementwise_
op_function.cu
.h"
#include "paddle/fluid/operators/elementwise/elementwise_
functor
.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/fluid/platform/transform.h"
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
2367cca6
...
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex.h"
...
@@ -59,10 +60,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
...
@@ -59,10 +60,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
framework
::
Tensor
*
dy
)
{
framework
::
Tensor
*
dy
)
{
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
block_size
=
dim3
(
ELEMENTWISE_BLOCK
_SIZE
,
1
);
auto
size
=
x
->
numel
();
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
grid_size
=
dim3
((
size
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD
_SIZE
,
1
);
dim3
((
size
+
ELEMENTWISE_BLOCK_SIZE
-
1
)
/
ELEMENTWISE_BLOCK
_SIZE
,
1
);
SimpleElemwiseSubGradCUDAKernel
<
SimpleElemwiseSubGradCUDAKernel
<
T
><<<
grid_size
,
block_size
,
0
,
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
ctx
.
template
device_context
<
plat
::
CUDADeviceContext
>().
stream
()
>>>
(
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.h
浏览文件 @
2367cca6
...
@@ -11,10 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -11,10 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
...
...
paddle/fluid/operators/layer_norm_op.h
浏览文件 @
2367cca6
...
@@ -19,7 +19,6 @@ limitations under the License. */
...
@@ -19,7 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
...
...
paddle/fluid/operators/svd_helper.h
浏览文件 @
2367cca6
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录