提交 626227eb 编写于 作者: D dzhwinter

"fix ci"

上级 b92b408e
...@@ -32,14 +32,16 @@ namespace operators { ...@@ -32,14 +32,16 @@ namespace operators {
} \ } \
} }
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME) \ #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
class OP_NAME##GradMaker : public framework::SingleGradOpDescMaker { \ class OP_NAME##GradMaker : public framework::SingleGradOpDescMaker { \
public: \ public: \
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected: \ protected: \
std::unique_ptr<framework::OpDesc> Apply() const override { \ std::unique_ptr<framework::OpDesc> Apply() const override { \
auto *op = new framework::OpDesc(); \ auto *op = new framework::OpDesc(); \
op->SetType(#OP_NAME "_grad"); \ op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Input("Out")); \ op->SetInput("Out", Output("Out")); \
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); \ op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); \
\ \
op->SetAttrMap(Attrs()); \ op->SetAttrMap(Attrs()); \
...@@ -452,56 +454,64 @@ REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); ...@@ -452,56 +454,64 @@ REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
// variable // variable
// is used in gradient operator. // is used in gradient operator.
// The operator name written in lowercase intentionally. // The operator name written in lowercase intentionally.
REGISTER_ACTIVATION_OP_GRAD_MAKER(sigmoid); REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(exp); REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(relu); REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(tanh); REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(sqrt); REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(ceil); REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(floor); REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(reciprocal); REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(relu6); REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(soft_relu); REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(hard_sigmoid); REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
#define REGISTER_INPLACE_ACTIVATION_OP(act_type, op_name) \
REGISTER_OPERATOR(act_type, ops::ActivationOp, ops::op_name##OpMaker, \
ops::op_name##GradMaker); \
REGISTER_OPERATOR(act_type##grad, ops::ActivationOpGrad)
#define REGISTER_ACTIVATION_OP(act_type, op_name) \ #define REGISTER_ACTIVATION_OP(act_type, op_name) \
REGISTER_OP(act_type, ops::ActivationOp, ops::op_name##OpMaker, \ REGISTER_OP(act_type, ops::ActivationOp, ops::op_name##OpMaker, \
act_type##_grad, ops::ActivationOpGrad); act_type##_grad, ops::ActivationOpGrad);
#define FOR_EACH_OP_FUNCTOR(__macro) \ #define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
__macro(sigmoid, Sigmoid); \ __macro(sigmoid, Sigmoid); \
__macro(logsigmoid, LogSigmoid); \ __macro(relu, Relu); \
__macro(exp, Exp); \ __macro(exp, Exp); \
__macro(relu, Relu); \ __macro(tanh, Tanh); \
__macro(tanh, Tanh); \ __macro(ceil, Ceil); \
__macro(softshrink, SoftShrink); \ __macro(floor, Floor); \
__macro(sqrt, Sqrt); \ __macro(sqrt, Sqrt); \
__macro(abs, Abs); \ __macro(soft_relu, SoftRelu); \
__macro(ceil, Ceil); \ __macro(relu6, Relu6); \
__macro(floor, Floor); \ __macro(reciprocal, Reciprocal); \
__macro(cos, Cos); \ __macro(hard_sigmoid, HardSigmoid);
__macro(sin, Sin); \
__macro(round, Round); \ #define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(reciprocal, Reciprocal); \ __macro(logsigmoid, LogSigmoid); \
__macro(log, Log); \ __macro(softshrink, SoftShrink); \
__macro(square, Square); \ __macro(abs, Abs); \
__macro(brelu, BRelu); \ __macro(cos, Cos); \
__macro(soft_relu, SoftRelu); \ __macro(sin, Sin); \
__macro(pow, Pow); \ __macro(round, Round); \
__macro(stanh, STanh); \ __macro(log, Log); \
__macro(softplus, Softplus); \ __macro(square, Square); \
__macro(softsign, Softsign); \ __macro(brelu, BRelu); \
__macro(relu6, Relu6); \ __macro(pow, Pow); \
__macro(leaky_relu, LeakyRelu); \ __macro(stanh, STanh); \
__macro(tanh_shrink, TanhShrink); \ __macro(softplus, Softplus); \
__macro(elu, ELU); \ __macro(softsign, Softsign); \
__macro(hard_shrink, HardShrink); \ __macro(leaky_relu, LeakyRelu); \
__macro(hard_sigmoid, HardSigmoid); \ __macro(tanh_shrink, TanhShrink); \
__macro(swish, Swish); \ __macro(elu, ELU); \
__macro(hard_shrink, HardShrink); \
__macro(swish, Swish); \
__macro(thresholded_relu, ThresholdedRelu); __macro(thresholded_relu, ThresholdedRelu);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
...@@ -518,4 +528,5 @@ namespace ops = paddle::operators; ...@@ -518,4 +528,5 @@ namespace ops = paddle::operators;
ops::grad_functor<double>>); ops::grad_functor<double>>);
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP); FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
...@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册