提交 d736fc0e 编写于 作者: Q qijun

add activation macro

上级 3c49e7b1
......@@ -139,9 +139,9 @@ class OperatorBase {
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new cls(*this)); \
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
}
// Macro for define a default constructor for Operator.
......
......@@ -12,19 +12,33 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/activation_op.h"
#define FILL_ACTIVATION_OP \
public: \
using framework::OperatorWithKernel::OperatorWithKernel; \
\
protected: \
void InferShape(const framework::InferShapeContext &ctx) const override { \
ctx.Output<framework::Tensor>("Y")->Resize( \
ctx.Input<framework::Tensor>("X")->dims()); \
}
#define FILL_ACTIVATION_GRAD_OP \
public: \
using framework::OperatorWithKernel::OperatorWithKernel; \
\
protected: \
void InferShape(const framework::InferShapeContext &ctx) const override { \
ctx.Output<framework::Tensor>(framework::GradVarName("X")) \
->Resize(ctx.Input<framework::Tensor>("Y")->dims()); \
}
namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
}
FILL_ACTIVATION_OP
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -32,23 +46,52 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function");
AddInput("X", "Input of Sigmoid operator");
AddOutput("Y", "Output of Sigmoid operator");
AddComment("Sigmoid activation operator");
}
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
FILL_ACTIVATION_GRAD_OP
};
class ExpOp : public framework::OperatorWithKernel {
FILL_ACTIVATION_OP
};
class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
AddOutput("Y", "Output of Exp operator");
AddComment("Exp activation operator");
}
};
class ExpOpGrad : public framework::OperatorWithKernel {
FILL_ACTIVATION_GRAD_OP
};
class ReluOp : public framework::OperatorWithKernel {
FILL_ACTIVATION_OP
};
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Y", "Output of Relu operator");
AddComment("Relu activation operator");
}
};
class ReluOpGrad : public framework::OperatorWithKernel {
FILL_ACTIVATION_GRAD_OP
};
} // namespace operators
} // namespace paddle
......@@ -59,3 +102,14 @@ REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(exp, ops::ExpOp, ops::ExpOpMaker, exp_grad, ops::ExpOpGrad);
REGISTER_OP_CPU_KERNEL(exp, ops::ExpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(exp_grad,
ops::ExpGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(relu, ops::ReluOp, ops::ReluOpMaker, relu_grad, ops::ReluOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
ops::ReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(relu_grad,
ops::ReluGradKernel<paddle::platform::CPUPlace, float>);
......@@ -13,7 +13,7 @@
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/activation_op.h"
namespace ops = paddle::operators;
......@@ -21,3 +21,12 @@ REGISTER_OP_GPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(exp, ops::ExpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(exp_grad,
ops::ExpGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(relu,
ops::ReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(relu_grad,
ops::ReluGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/activation_functor.h"
#define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
#define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \
template <typename Place, typename T> \
class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
public: \
void Compute(const framework::ExecutionContext& context) const override { \
auto* X = context.Input<framework::Tensor>("X"); \
auto* Y = context.Output<framework::Tensor>("Y"); \
Y->mutable_data<T>(context.GetPlace()); \
math::ACTIVATION_NAME<Place, T> functor; \
auto* device_context = context.device_context(); \
functor(*device_context, *X, Y); \
} \
};
#define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \
template <typename Place, typename T> \
class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \
: public framework::OpKernel { \
public: \
void Compute(const framework::ExecutionContext& context) const override { \
auto* X = context.Input<framework::Tensor>("X"); \
auto* Y = context.Input<framework::Tensor>("Y"); \
auto* dY = \
context.Input<framework::Tensor>(framework::GradVarName("Y")); \
auto* dX = \
context.Output<framework::Tensor>(framework::GradVarName("X")); \
dX->mutable_data<T>(context.GetPlace()); \
math::ACTIVATION_GRAD_NAME<Place, T> functor; \
auto* device_context = context.device_context(); \
functor(*device_context, *X, *Y, *dY, dX); \
} \
};
namespace paddle {
namespace operators {
DEFINE_ACTIVATION_KERNEL(Sigmoid);
DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
DEFINE_ACTIVATION_KERNEL(Exp);
DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
DEFINE_ACTIVATION_KERNEL(Relu);
DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
} // namespace operators
} // namespace paddle
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
struct sigmoid {
void operator()(const platform::DeviceContext& deice_context,
const framework::Tensor& input, framework::Tensor* output) {
auto x = framework::EigenVector<T>::Flatten(*output);
auto y = framework::EigenVector<T>::Flatten(input);
auto* place = device_context.get_eigen_device<Place>();
y.device(*place) = 1. / (1. + (-x).exp());
}
};
}
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
struct Sigmoid {
void operator()(const platform::DeviceContext& device_context,
const framework::Tensor& X, framework::Tensor* Y) {
auto x = framework::EigenVector<T>::Flatten(X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto* place = device_context.template get_eigen_device<Place>();
y.device(*place) = 1. / (1. + (-x).exp());
}
};
template <typename Place, typename T>
struct SigmoidGrad {
void operator()(const platform::DeviceContext& device_context,
const framework::Tensor& X, const framework::Tensor& Y,
const framework::Tensor& dY, framework::Tensor* dX) {
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto y = framework::EigenVector<T>::Flatten(Y);
auto dy = framework::EigenVector<T>::Flatten(dY);
auto* place = device_context.template get_eigen_device<Place>();
dx.device(*place) = dy * y * (1. - y);
}
};
template <typename Place, typename T>
struct Exp {
void operator()(const platform::DeviceContext& device_context,
const framework::Tensor& input, framework::Tensor* output) {
auto x = framework::EigenVector<T>::Flatten(input);
auto y = framework::EigenVector<T>::Flatten(*output);
auto* place = device_context.template get_eigen_device<Place>();
y.device(*place) = x.exp();
}
};
template <typename Place, typename T>
struct ExpGrad {
void operator()(const platform::DeviceContext& device_context,
const framework::Tensor& X, const framework::Tensor& Y,
const framework::Tensor& dY, framework::Tensor* dX) {
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto dy = framework::EigenVector<T>::Flatten(dY);
auto* place = device_context.template get_eigen_device<Place>();
dx.device(*place) = dy.exp();
}
};
template <typename Place, typename T>
struct Relu {
void operator()(const platform::DeviceContext& device_context,
const framework::Tensor& input, framework::Tensor* output) {
auto x = framework::EigenVector<T>::Flatten(input);
auto y = framework::EigenVector<T>::Flatten(*output);
auto* place = device_context.template get_eigen_device<Place>();
y.device(*place) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename Place, typename T>
struct ReluGrad {
void operator()(const platform::DeviceContext& device_context,
const framework::Tensor& X, const framework::Tensor& Y,
const framework::Tensor& dY, framework::Tensor* dX) {
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto dy = framework::EigenVector<T>::Flatten(dY);
auto x = framework::EigenVector<T>::Flatten(X);
auto* place = device_context.template get_eigen_device<Place>();
dx.device(*place) = dy * (x > static_cast<T>(0)).template cast<T>();
}
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
// The clipping is used in Paddle's raw implenmention
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = 1. / (1. + (-X).exp());
}
};
template <typename Place, typename T>
class SigmoidGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto Y_t = context.Input<Tensor>("Y");
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
dX_t->mutable_data<T>(context.GetPlace());
auto dX = EigenVector<T>::Flatten(*dX_t);
auto Y = EigenVector<T>::Flatten(*Y_t);
auto dY = EigenVector<T>::Flatten(*dY_t);
dX.device(context.GetEigenDevice<Place>()) = dY * Y * (1. - Y);
}
};
} // namespace operators
} // namespace paddle
......@@ -36,7 +36,6 @@ USE_OP(onehot_cross_entropy);
USE_OP(sgd);
USE_OP(mul);
USE_OP(mean);
USE_OP(sigmoid);
USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
......@@ -55,6 +54,9 @@ USE_OP(top_k);
USE_OP(squared_l2_distance);
USE_OP(sum);
USE_OP(reshape);
USE_OP(sigmoid);
USE_OP(exp);
USE_OP(relu);
namespace paddle {
namespace framework {
......
import unittest
import numpy as np
from op_test import OpTest
class TestExp(OpTest):
def setUp(self):
self.op_type = "exp"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.exp(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.007)
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestExp(OpTest):
def setUp(self):
self.op_type = "exp"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.007)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册