Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c6b39a00
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c6b39a00
编写于
12月 06, 2018
作者:
H
Houjiang Chen
提交者:
GitHub
12月 06, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14714 from NHZlX/add_prelu_gpu
add prelu cuda kernel for inference.
上级
8a111ac6
e7abe6b6
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
284 addition
and
87 deletion
+284
-87
paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc
paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc
+1
-2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
+18
-82
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-1
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-0
paddle/fluid/operators/math/prelu.cu
paddle/fluid/operators/math/prelu.cu
+148
-0
paddle/fluid/operators/math/prelu.h
paddle/fluid/operators/math/prelu.h
+49
-0
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/prelu_op.cc
+1
-1
paddle/fluid/operators/prelu_op.cu
paddle/fluid/operators/prelu_op.cu
+64
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc
浏览文件 @
c6b39a00
...
...
@@ -90,5 +90,4 @@ TEST(prelu_op, test_scalar) {
}
// namespace inference
}
// namespace paddle
// USE_OP(prelu);
USE_CPU_ONLY_OP
(
prelu
);
USE_OP
(
prelu
);
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
c6b39a00
nv_library
(
tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine
)
DEPS enforce tensorrt_engine
prelu
)
paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
浏览文件 @
c6b39a00
...
...
@@ -14,92 +14,16 @@
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
#include "paddle/fluid/operators/math/prelu.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
static
const
int
CUDA_NUM_THREADS
=
1024
;
static
const
int
CUDA_MAX_NUM_BLOCKS
=
65535
;
inline
static
int
GET_NUM_BLOCKS
(
const
int
N
)
{
return
(
N
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
;
}
__global__
void
PReluChannelWiseKernel
(
const
float
*
input
,
const
float
*
alpha
,
float
*
output
,
int
channel
,
size_t
spatial_size
)
{
size_t
offset
=
blockIdx
.
x
*
spatial_size
;
const
float
*
in
=
input
+
offset
;
float
*
out
=
output
+
offset
;
float
scale
=
alpha
[
blockIdx
.
x
%
channel
];
for
(
size_t
i
=
threadIdx
.
x
;
i
<
spatial_size
;
i
+=
blockDim
.
x
)
{
float
x
=
in
[
i
];
out
[
i
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
}
}
__global__
void
PReluElementWiseKernel
(
const
float
*
input
,
const
float
*
alpha
,
float
*
output
,
size_t
spatial_size
)
{
size_t
offset
=
blockIdx
.
x
*
spatial_size
;
const
float
*
in
=
input
+
offset
;
const
float
*
scale
=
alpha
+
offset
;
float
*
out
=
output
+
offset
;
for
(
size_t
i
=
threadIdx
.
x
;
i
<
spatial_size
;
i
+=
blockDim
.
x
)
{
float
x
=
in
[
i
];
out
[
i
]
=
(
x
>
0
)
?
x
:
scale
[
i
]
*
x
;
}
}
__global__
void
PReluScalarKernel
(
const
float
*
input
,
const
float
*
alpha
,
float
*
output
,
size_t
spatial_size
)
{
size_t
offset
=
blockIdx
.
x
*
spatial_size
;
const
float
*
in
=
input
+
offset
;
float
scale
=
*
alpha
;
float
*
out
=
output
+
offset
;
for
(
size_t
i
=
threadIdx
.
x
;
i
<
spatial_size
;
i
+=
blockDim
.
x
)
{
float
x
=
in
[
i
];
out
[
i
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
}
}
static
inline
void
PReluChannelWise
(
cudaStream_t
stream
,
const
float
*
input
,
const
float
*
alpha
,
float
*
output
,
int
batch_size
,
const
nvinfer1
::
Dims
&
dims
)
{
size_t
unroll
=
batch_size
*
dims
.
d
[
0
];
size_t
spatial_size
=
dims
.
d
[
1
]
*
dims
.
d
[
2
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluChannelWiseKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
dims
.
d
[
0
],
spatial_size
);
}
static
inline
void
PReluElementWise
(
cudaStream_t
stream
,
const
float
*
input
,
const
float
*
alpha
,
float
*
output
,
int
batch_size
,
const
nvinfer1
::
Dims
&
dims
)
{
size_t
unroll
=
batch_size
*
dims
.
d
[
0
];
size_t
spatial_size
=
dims
.
d
[
1
]
*
dims
.
d
[
2
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluElementWiseKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
spatial_size
);
}
static
inline
void
PReluScalar
(
cudaStream_t
stream
,
const
float
*
input
,
const
float
*
alpha
,
float
*
output
,
int
batch_size
,
const
nvinfer1
::
Dims
&
dims
)
{
size_t
unroll
=
batch_size
*
dims
.
d
[
0
];
size_t
spatial_size
=
dims
.
d
[
1
]
*
dims
.
d
[
2
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluScalarKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
spatial_size
);
}
nvinfer1
::
Dims
PReluPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
{
...
...
@@ -110,19 +34,31 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
return
output_dims
;
}
int
PReluPlugin
::
enqueue
(
int
batch
S
ize
,
const
void
*
const
*
inputs
,
int
PReluPlugin
::
enqueue
(
int
batch
_s
ize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
// input dims is CHW.
const
auto
&
input_dims
=
this
->
getInputDims
(
0
);
const
float
*
input
=
reinterpret_cast
<
const
float
*>
(
inputs
[
0
]);
const
float
*
alpha
=
reinterpret_cast
<
const
float
*>
(
alpha_
.
get
().
values
);
float
*
output
=
reinterpret_cast
<
float
**>
(
outputs
)[
0
];
std
::
vector
<
int
>
input_shape
;
input_shape
.
push_back
(
batch_size
);
for
(
int
i
=
0
;
i
<
input_dims
.
nbDims
;
i
++
)
{
input_shape
.
push_back
(
input_dims
.
d
[
i
]);
}
if
(
mode_
==
"channel"
)
{
PReluChannelWise
(
stream
,
input
,
alpha
,
output
,
batchSize
,
input_dims
);
operators
::
math
::
PreluChannelWiseDirectCUDAFunctor
<
float
>
prelu_channel_wise
;
prelu_channel_wise
(
stream
,
input
,
alpha
,
output
,
input_shape
);
}
else
if
(
mode_
==
"element"
)
{
PReluElementWise
(
stream
,
input
,
alpha
,
output
,
batchSize
,
input_dims
);
operators
::
math
::
PreluElementWiseDirectCUDAFunctor
<
float
>
prelu_element_wise
;
prelu_element_wise
(
stream
,
input
,
alpha
,
output
,
input_shape
);
}
else
{
PReluScalar
(
stream
,
input
,
alpha
,
output
,
batchSize
,
input_dims
);
operators
::
math
::
PreluScalarDirectCUDAFunctor
<
float
>
prelu_scalar
;
prelu_scalar
(
stream
,
input
,
alpha
,
output
,
input_shape
);
}
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
c6b39a00
...
...
@@ -70,7 +70,7 @@ endif()
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
if
(
WITH_GPU
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv
prelu
)
endif
()
# FIXME(typhoonzero): operator deps may not needed.
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
c6b39a00
...
...
@@ -59,6 +59,7 @@ math_library(matrix_bit_code)
math_library
(
unpooling
)
math_library
(
vol2col
)
math_library
(
prelu
)
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function
)
cc_test
(
selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor
)
...
...
paddle/fluid/operators/math/prelu.cu
0 → 100644
浏览文件 @
c6b39a00
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/prelu.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
static
const
int
CUDA_NUM_THREADS
=
1024
;
static
const
int
CUDA_MAX_NUM_BLOCKS
=
65535
;
inline
static
int
GET_NUM_BLOCKS
(
const
int
N
)
{
return
(
N
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
;
}
template
<
typename
T
>
__global__
void
PReluChannelWiseKernel
(
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
int
channel
,
size_t
spatial_size
)
{
size_t
offset
=
blockIdx
.
x
*
spatial_size
;
const
T
*
in
=
input
+
offset
;
T
*
out
=
output
+
offset
;
T
scale
=
alpha
[
blockIdx
.
x
%
channel
];
for
(
size_t
i
=
threadIdx
.
x
;
i
<
spatial_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in
[
i
];
out
[
i
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
}
}
template
<
typename
T
>
__global__
void
PReluElementWiseKernel
(
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
size_t
spatial_size
)
{
size_t
offset
=
blockIdx
.
x
*
spatial_size
;
const
T
*
in
=
input
+
offset
;
const
T
*
scale
=
alpha
+
offset
;
T
*
out
=
output
+
offset
;
for
(
size_t
i
=
threadIdx
.
x
;
i
<
spatial_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in
[
i
];
out
[
i
]
=
(
x
>
0
)
?
x
:
scale
[
i
]
*
x
;
}
}
template
<
typename
T
>
__global__
void
PReluScalarKernel
(
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
size_t
spatial_size
)
{
size_t
offset
=
blockIdx
.
x
*
spatial_size
;
const
T
*
in
=
input
+
offset
;
T
scale
=
*
alpha
;
T
*
out
=
output
+
offset
;
for
(
size_t
i
=
threadIdx
.
x
;
i
<
spatial_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in
[
i
];
out
[
i
]
=
(
x
>
0
)
?
x
:
scale
*
x
;
}
}
template
<
typename
T
>
static
inline
void
PReluChannelWise
(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
)
{
size_t
unroll
=
input_shape
[
0
]
*
input_shape
[
1
];
size_t
spatial_size
=
input_shape
[
2
]
*
input_shape
[
3
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluChannelWiseKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
input_shape
[
1
],
spatial_size
);
}
template
<
typename
T
>
static
inline
void
PReluElementWise
(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
)
{
size_t
unroll
=
input_shape
[
0
]
*
input_shape
[
1
];
size_t
spatial_size
=
input_shape
[
2
]
*
input_shape
[
3
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluElementWiseKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
spatial_size
);
}
template
<
typename
T
>
static
inline
void
PReluScalar
(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
)
{
size_t
unroll
=
input_shape
[
0
]
*
input_shape
[
1
];
size_t
spatial_size
=
input_shape
[
2
]
*
input_shape
[
3
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluScalarKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
spatial_size
);
}
template
<
typename
T
>
void
PreluChannelWiseDirectCUDAFunctor
<
T
>::
operator
()(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
)
{
size_t
unroll
=
input_shape
[
0
]
*
input_shape
[
1
];
size_t
spatial_size
=
input_shape
[
2
]
*
input_shape
[
3
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluChannelWiseKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
input_shape
[
1
],
spatial_size
);
}
template
<
typename
T
>
void
PreluElementWiseDirectCUDAFunctor
<
T
>::
operator
()(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
)
{
size_t
unroll
=
input_shape
[
0
]
*
input_shape
[
1
];
size_t
spatial_size
=
input_shape
[
2
]
*
input_shape
[
3
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluElementWiseKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
spatial_size
);
}
template
<
typename
T
>
void
PreluScalarDirectCUDAFunctor
<
T
>::
operator
()(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
)
{
size_t
unroll
=
input_shape
[
0
]
*
input_shape
[
1
];
size_t
spatial_size
=
input_shape
[
2
]
*
input_shape
[
3
];
CHECK_LT
(
unroll
,
CUDA_MAX_NUM_BLOCKS
);
PReluScalarKernel
<<<
unroll
,
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
input
,
alpha
,
output
,
spatial_size
);
}
template
class
PreluChannelWiseDirectCUDAFunctor
<
float
>;
template
class
PreluChannelWiseDirectCUDAFunctor
<
double
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
float
>;
template
class
PreluElementWiseDirectCUDAFunctor
<
double
>;
template
class
PreluScalarDirectCUDAFunctor
<
float
>;
template
class
PreluScalarDirectCUDAFunctor
<
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/prelu.h
0 → 100644
浏览文件 @
c6b39a00
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
#ifdef PADDLE_WITH_CUDA
template
<
typename
T
>
class
PreluChannelWiseDirectCUDAFunctor
{
public:
void
operator
()(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
);
};
template
<
typename
T
>
class
PreluElementWiseDirectCUDAFunctor
{
public:
void
operator
()(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
);
};
template
<
typename
T
>
class
PreluScalarDirectCUDAFunctor
{
public:
void
operator
()(
cudaStream_t
stream
,
const
T
*
input
,
const
T
*
alpha
,
T
*
output
,
std
::
vector
<
int
>
input_shape
);
};
#endif
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/prelu_op.cc
浏览文件 @
c6b39a00
...
...
@@ -58,7 +58,7 @@ class PReluOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
ctx
.
device_context
());
}
};
...
...
paddle/fluid/operators/prelu_op.cu
0 → 100644
浏览文件 @
c6b39a00
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
CUDAPReluKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
alpha
=
context
.
Input
<
Tensor
>
(
"Alpha"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
const
T
*
x_ptr
=
x
->
data
<
T
>
();
T
*
o_ptr
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
alpha_ptr
=
alpha
->
data
<
T
>
();
auto
&
mode
=
context
.
Attr
<
std
::
string
>
(
"mode"
);
int
numel
=
x
->
numel
();
auto
dim
=
x
->
dims
();
std
::
vector
<
int
>
input_shape
=
framework
::
vectorize2int
(
dim
);
if
(
mode
==
"channel"
)
{
math
::
PreluChannelWiseDirectCUDAFunctor
<
T
>
prelu_channel_wise
;
prelu_channel_wise
(
context
.
cuda_device_context
().
stream
(),
x_ptr
,
alpha_ptr
,
o_ptr
,
input_shape
);
}
else
if
(
mode
==
"element"
)
{
math
::
PreluElementWiseDirectCUDAFunctor
<
T
>
prelu_element_wise
;
prelu_element_wise
(
context
.
cuda_device_context
().
stream
(),
x_ptr
,
alpha_ptr
,
o_ptr
,
input_shape
);
}
else
{
math
::
PreluScalarDirectCUDAFunctor
<
T
>
prelu_scalar
;
prelu_scalar
(
context
.
cuda_device_context
().
stream
(),
x_ptr
,
alpha_ptr
,
o_ptr
,
input_shape
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
prelu
,
ops
::
CUDAPReluKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
CUDAPReluKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录