Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9f7b027d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9f7b027d
编写于
4月 09, 2019
作者:
Z
Zeng Jinle
提交者:
GitHub
4月 09, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix activation grad op desc maker (#16715)
test=develop
上级
9bd44b94
变更
6
展开全部
显示空白变更内容
内联
并排
Showing
6 changed file
with
240 addition
and
172 deletion
+240
-172
paddle/fluid/framework/details/op_registry.h
paddle/fluid/framework/details/op_registry.h
+6
-0
paddle/fluid/op_use_default_grad_op_maker.spec
paddle/fluid/op_use_default_grad_op_maker.spec
+0
-23
paddle/fluid/operators/activation_cudnn_op.cu.cc
paddle/fluid/operators/activation_cudnn_op.cu.cc
+15
-1
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+68
-93
paddle/fluid/operators/activation_op.cu
paddle/fluid/operators/activation_op.cu
+3
-2
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+148
-53
未找到文件。
paddle/fluid/framework/details/op_registry.h
浏览文件 @
9f7b027d
...
@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
...
@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
}
}
};
};
// A fake OpInfoFiller of void
template
<
>
struct
OpInfoFiller
<
void
,
kUnknown
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{}
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/op_use_default_grad_op_maker.spec
浏览文件 @
9f7b027d
abs
acos
asin
atan
attention_lstm
attention_lstm
brelu
conv_shift
conv_shift
cos
cos_sim
cos_sim
dequantize
dequantize
elu
fc
fc
flatten
flatten
fsp
fsp
...
@@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu
...
@@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu
fusion_seqexpand_concat_fc
fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_seqpool_concat
fusion_squared_mat_sub
fusion_squared_mat_sub
gelu
gru
gru
hard_shrink
hierarchical_sigmoid
hierarchical_sigmoid
leaky_relu
log
logsigmoid
lrn
lrn
lstm_unit
lstm_unit
lstmp
lstmp
...
@@ -38,7 +26,6 @@ modified_huber_loss
...
@@ -38,7 +26,6 @@ modified_huber_loss
nce
nce
pool2d
pool2d
pool3d
pool3d
pow
prelu
prelu
quantize
quantize
rank_loss
rank_loss
...
@@ -50,20 +37,10 @@ reduce_sum
...
@@ -50,20 +37,10 @@ reduce_sum
requantize
requantize
reshape
reshape
rnn_memory_helper
rnn_memory_helper
round
sequence_softmax
sequence_softmax
sin
softplus
softshrink
softsign
spp
spp
square
squeeze
squeeze
stanh
swish
tanh_shrink
tensor_array_to_tensor
tensor_array_to_tensor
thresholded_relu
transpose
transpose
unpool
unpool
unsqueeze
unsqueeze
paddle/fluid/operators/activation_cudnn_op.cu.cc
浏览文件 @
9f7b027d
...
@@ -12,6 +12,9 @@
...
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_desc.h"
...
@@ -82,6 +85,8 @@ template <typename T>
...
@@ -82,6 +85,8 @@ template <typename T>
struct
CudnnReluGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
struct
CudnnReluGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnReluGradFunctor
(
const
CUDADeviceContext
&
ctx
)
explicit
CudnnReluGradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
CUDNN_ACTIVATION_RELU
)
{}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
CUDNN_ACTIVATION_RELU
)
{}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
...
@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit
CudnnRelu6GradFunctor
(
const
CUDADeviceContext
&
ctx
)
explicit
CudnnRelu6GradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
6.0
,
CUDNN_ACTIVATION_CLIPPED_RELU
)
{
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
6.0
,
CUDNN_ACTIVATION_CLIPPED_RELU
)
{
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -105,6 +112,8 @@ template <typename T>
...
@@ -105,6 +112,8 @@ template <typename T>
struct
CudnnSigmoidGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
struct
CudnnSigmoidGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnSigmoidGradFunctor
(
const
CUDADeviceContext
&
ctx
)
explicit
CudnnSigmoidGradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
CUDNN_ACTIVATION_SIGMOID
)
{}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
CUDNN_ACTIVATION_SIGMOID
)
{}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -116,6 +125,8 @@ template <typename T>
...
@@ -116,6 +125,8 @@ template <typename T>
struct
CudnnTanhGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
struct
CudnnTanhGradFunctor
:
public
CudnnActivationGradFunctor
<
T
>
{
explicit
CudnnTanhGradFunctor
(
const
CUDADeviceContext
&
ctx
)
explicit
CudnnTanhGradFunctor
(
const
CUDADeviceContext
&
ctx
)
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
CUDNN_ACTIVATION_TANH
)
{}
:
CudnnActivationGradFunctor
<
T
>
(
ctx
,
0.0
,
CUDNN_ACTIVATION_TANH
)
{}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepOut
;
}
};
};
template
<
typename
Functor
>
template
<
typename
Functor
>
...
@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
...
@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
public:
public:
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
using
T
=
typename
Functor
::
ELEMENT_TYPE
;
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
static_assert
(
Functor
::
FwdDeps
()
==
kDepOut
,
"Forward deps must be Out."
);
const
framework
::
Tensor
*
X
,
*
Out
,
*
dOut
;
const
framework
::
Tensor
*
X
,
*
Out
,
*
dOut
;
X
=
Out
=
dOut
=
nullptr
;
X
=
Out
=
dOut
=
nullptr
;
framework
::
Tensor
*
dX
=
nullptr
;
framework
::
Tensor
*
dX
=
nullptr
;
ExtractActivationGradTensor
(
context
,
&
X
,
&
Out
,
&
dOut
,
&
dX
);
ExtractActivationGradTensor
<
Functor
::
FwdDeps
()
>
(
context
,
&
X
,
&
Out
,
&
dOut
,
&
dX
);
dX
->
mutable_data
<
T
>
(
context
.
GetPlace
());
dX
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
CUDADeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
CUDADeviceContext
>();
Functor
functor
(
dev_ctx
);
Functor
functor
(
dev_ctx
);
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
9f7b027d
...
@@ -15,7 +15,9 @@ limitations under the License. */
...
@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -27,6 +29,25 @@ namespace operators {
...
@@ -27,6 +29,25 @@ namespace operators {
using
paddle
::
framework
::
Tensor
;
using
paddle
::
framework
::
Tensor
;
template
<
typename
GradFunctor
>
static
constexpr
bool
CanInplaceAct
()
{
return
GradFunctor
::
FwdDeps
()
==
kDepOut
||
GradFunctor
::
FwdDeps
()
==
kNoDeps
;
}
std
::
unique_ptr
<
std
::
unordered_set
<
std
::
string
>>
GetInplaceOpSet
()
{
std
::
unique_ptr
<
std
::
unordered_set
<
std
::
string
>>
ret
(
new
std
::
unordered_set
<
std
::
string
>
());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
bwd_functor) \
if (CanInplaceAct<bwd_functor<float>>()) { \
ret->insert(#op_type); \
}
FOR_EACH_ACTIVATION_OP
(
INSERT_INTO_INPLACE_OP_SET
);
#undef INSERT_INTO_INPLACE_OP_SET
return
ret
;
}
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
...
@@ -50,27 +71,33 @@ using paddle::framework::Tensor;
...
@@ -50,27 +71,33 @@ using paddle::framework::Tensor;
} \
} \
}
}
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
template
<
ActBwdOpFwdDeps
kDepValue
>
class OP_NAME##GradMaker \
class
ActivationGradOpDescMaker
:
public
framework
::
SingleGradOpDescMaker
{
: public ::paddle::framework::SingleGradOpDescMaker { \
public:
public: \
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected:
protected: \
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
std
::
unique_ptr
<
framework
::
OpDesc
>
op
(
new
framework
::
OpDesc
());
auto* op = new ::paddle::framework::OpDesc(); \
op
->
SetType
(
ForwardOpType
()
+
"_grad"
);
op->SetType(#KERNEL_TYPE "_grad"); \
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op->SetInput("Out", Output("Out")); \
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op->SetInput(::paddle::framework::GradVarName("Out"), \
op
->
SetAttrMap
(
Attrs
());
OutputGrad("Out")); \
\
if
(
static_cast
<
int
>
(
kDepValue
)
&
op->SetAttrMap(Attrs()); \
static_cast
<
int
>
(
ActBwdOpFwdDeps
::
kDepX
))
{
\
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \
}
return std::unique_ptr<::paddle::framework::OpDesc>(op); \
} \
if
(
static_cast
<
int
>
(
kDepValue
)
&
static_cast
<
int
>
(
ActBwdOpFwdDeps
::
kDepOut
))
{
op
->
SetInput
(
"Out"
,
Output
(
"Out"
));
}
}
return
op
;
}
};
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
,
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
OperatorWithKernel
&
oper
,
const
framework
::
OperatorWithKernel
&
oper
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
...
@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
...
@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
ShareDim
(
"Out"
,
framework
::
GradVarName
(
"X"
));
auto
out_grad_name
=
framework
::
GradVarName
(
"Out"
);
ctx
->
ShareLoD
(
"Out"
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareDim
(
out_grad_name
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
out_grad_name
,
framework
::
GradVarName
(
"X"
));
}
}
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
GetKernelType
(
ctx
,
*
this
,
"Out"
);
return
GetKernelType
(
ctx
,
*
this
,
framework
::
GradVarName
(
"Out"
)
);
}
}
};
};
...
@@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
...
@@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER
(
Square
,
SquareDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Square
,
SquareDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Softplus
,
SoftplusDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Softplus
,
SoftplusDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Softsign
,
SoftsignDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Softsign
,
SoftsignDoc
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Sigmoid
,
sigmoid
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Relu
,
relu
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Gelu
,
gelu
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Exp
,
exp
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Tanh
,
tanh
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Ceil
,
ceil
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Floor
,
floor
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Sqrt
,
sqrt
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
SoftRelu
,
soft_relu
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Relu6
,
relu6
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
Reciprocal
,
reciprocal
);
REGISTER_ACTIVATION_OP_GRAD_MAKER
(
HardSigmoid
,
hard_sigmoid
);
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
__macro(Sigmoid, sigmoid); \
REGISTER_OPERATOR( \
__macro(Relu, relu); \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
__macro(Exp, exp); \
ops::ActivationOpInferVarType, \
__macro(Tanh, tanh); \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \
__macro(Ceil, ceil); \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
__macro(Floor, floor); \
::paddle::framework::SingleOpInplaceInToOut, \
__macro(Sqrt, sqrt); \
void>::type); \
__macro(SoftRelu, soft_relu); \
REGISTER_OPERATOR( \
__macro(Relu6, relu6); \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
__macro(Reciprocal, reciprocal); \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
__macro(HardSigmoid, hard_sigmoid);
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)
#define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(LogSigmoid, logsigmoid); \
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
__macro(SoftShrink, softshrink); \
grad_functor) \
__macro(Abs, abs); \
__macro(Cos, cos); \
__macro(Acos, acos); \
__macro(Sin, sin); \
__macro(Asin, asin); \
__macro(Atan, atan); \
__macro(Round, round); \
__macro(Log, log); \
__macro(Square, square); \
__macro(Gelu, gelu); \
__macro(BRelu, brelu); \
__macro(Pow, pow); \
__macro(STanh, stanh); \
__macro(Softplus, softplus); \
__macro(Softsign, softsign); \
__macro(LeakyRelu, leaky_relu); \
__macro(TanhShrink, tanh_shrink); \
__macro(ELU, elu); \
__macro(HardShrink, hard_shrink); \
__macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker, \
::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
ops::functor<float>>, \
...
@@ -643,6 +619,5 @@ namespace ops = paddle::operators;
...
@@ -643,6 +619,5 @@ namespace ops = paddle::operators;
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);
ops::grad_functor<double>>);
FOR_EACH_OP_FUNCTOR
(
REGISTER_ACTIVATION_OP
);
FOR_EACH_ACTIVATION_OP
(
REGISTER_ACTIVATION_OP
);
FOR_EACH_INPLACE_OP_FUNCTOR
(
REGISTER_INPLACE_ACTIVATION_OP
);
FOR_EACH_ACTIVATION_OP
(
REGISTER_ACTIVATION_CPU_KERNEL
);
FOR_EACH_KERNEL_FUNCTOR
(
REGISTER_ACTIVATION_CPU_KERNEL
);
paddle/fluid/operators/activation_op.cu
浏览文件 @
9f7b027d
...
@@ -15,7 +15,8 @@ limitations under the License. */
...
@@ -15,7 +15,8 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
...
@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
...
@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
ops::grad_functor<plat::float16>>);
FOR_EACH_
KERNEL_FUNCTOR
(
REGISTER_ACTIVATION_CUDA_KERNEL
);
FOR_EACH_
ACTIVATION_OP
(
REGISTER_ACTIVATION_CUDA_KERNEL
);
paddle/fluid/operators/activation_op.h
浏览文件 @
9f7b027d
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录