提交 01c39a9f 编写于 作者: Z zp7 提交者: Yanzhan Yang

add relu6 threshold param (#1724)

上级 b6dc8773
......@@ -24,7 +24,7 @@ namespace operators {
#ifdef RELU_OP
DECLARE_OPERATOR(Relu, ReluParam, ReluKernel);
// Relu6 carries an extra "threshold" attribute, so it uses its own param type.
DECLARE_OPERATOR(Relu6, Relu6Param, Relu6Kernel);
#endif
#ifdef SIGMOID_OP
......
......@@ -22,7 +22,7 @@ namespace operators {
#ifdef RELU_OP
DECLARE_KERNEL(Relu, ReluParam);
// Relu6 uses Relu6Param to expose the configurable clamp threshold.
DECLARE_KERNEL(Relu6, Relu6Param);
#endif
#ifdef SIGMOID_OP
......
......@@ -38,15 +38,16 @@ void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
}
template <>
// Nothing to precompute for relu6 on CPU; the threshold is read from the
// param at Compute time.
bool Relu6Kernel<CPU, float>::Init(Relu6Param<CPU> *param) {
  return true;
}
template <>
// relu6 forward on CPU: out = min(max(in, 0), threshold), elementwise.
// The threshold comes from the op's "threshold" attribute via Relu6Param.
void Relu6Kernel<CPU, float>::Compute(const Relu6Param<CPU> &param) {
  const LoDTensor *input = param.InputX();
  LoDTensor *output = param.Out();
  float threshold = param.getThreshold();
  ActivationCompute<float, RELU6>()(input, output, threshold);
  // Propagate LoD so sequence information survives the activation.
  output->set_lod(input->lod());
}
#endif
......
......@@ -15,7 +15,8 @@ limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// relu6 activation over an image2d: out = min(max(in, 0), threshold).
// NOTE(review): diff residue — the old and new parameter lists and the old
// hard-coded 6.0f clamp line both appear below; the `threshold` variant is
// the committed version.
__kernel void relu6(__read_only image2d_t input,
__write_only image2d_t output){
__write_only image2d_t output,
__private const float threshold){
const int x = get_global_id(0);
const int y = get_global_id(1);
......@@ -26,6 +27,6 @@ __kernel void relu6(__read_only image2d_t input,
// `sampler` is declared in lines omitted from this diff hunk — not visible here.
half4 in = read_imageh(input, sampler, (int2)(x, y));
in = max((half4)(0.0f, 0.0f, 0.0f, 0.0f), in);
in = min((half4)(6.0f, 6.0f, 6.0f, 6.0f), in);
// Clamp upper bound with the host-supplied threshold (float narrowed to half).
in = min((half4)(threshold, threshold, threshold, threshold), in);
write_imageh(output, (int2)(x, y), in);
}
......@@ -19,21 +19,23 @@ namespace paddle_mobile {
namespace operators {
template <>
// Compile/register the "relu6" OpenCL kernel once at init time.
bool Relu6Kernel<GPU_CL, float>::Init(Relu6Param<GPU_CL>* param) {
  this->cl_helper_.AddKernel("relu6", "relu6.cl");
  return true;
}
template <>
// NOTE(review): diff residue — the next two lines are the pre- and post-commit
// signatures of the same function; the Relu6Param overload is the current one.
void Relu6Kernel<GPU_CL, float>::Compute(const ReluParam<GPU_CL>& param) {
// Launches the "relu6" OpenCL kernel: clamps the input image into
// [0, threshold] and writes the output image.
void Relu6Kernel<GPU_CL, float>::Compute(const Relu6Param<GPU_CL>& param) {
auto kernel = this->cl_helper_.KernelAt(0);
const auto* input = param.InputX();
auto* output = param.Out();
// Threshold comes from the op's "threshold" attribute (see Relu6Param).
float threshold = param.getThreshold();
// NOTE(review): default_work_size appears unused below — work_size is built
// from the input image dimensions instead. TODO confirm and remove if dead.
auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
auto inputImage = input->GetCLImage();
auto outputImage = output->GetCLImage();
clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
// BUG(review): `threshold` is a plain float, so the size argument must be
// sizeof(float) (i.e. sizeof(cl_float)), not sizeof(cl_mem). Per the OpenCL
// spec a mismatched size yields CL_INVALID_ARG_SIZE, leaving the kernel arg
// unset. Return codes of clSetKernelArg are also unchecked here.
clSetKernelArg(kernel, 2, sizeof(cl_mem), &threshold);
const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()};
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL,
......
......@@ -116,6 +116,14 @@ inline float32x4_t vActiveq_f32<LEAKY_RELU>(const float32x4_t &x,
const float32x4_t &alpha) {
return vmaxq_f32(x, vmulq_f32(x, alpha));
}
template <>
// Vectorized relu6: clamp each lane of `x` into [0, threshold], where the
// threshold is broadcast from lane 0 of `alpha`.
inline float32x4_t vActiveq_f32<RELU6>(const float32x4_t &x,
                                       const float32x4_t &alpha) {
  const float32x4_t lower = vdupq_n_f32(0.f);
  const float32x4_t upper = vdupq_n_f32(vgetq_lane_f32(alpha, 0));
  // Rectify first, then cap — same order as the scalar Active<RELU6>.
  return vminq_f32(vmaxq_f32(x, lower), upper);
}
#endif
template <ActivationType Act = IDENTITY>
......@@ -164,6 +172,11 @@ inline float Active<LEAKY_RELU>(const float &x, const float &alpha) {
return std::max(x, alpha * x);
}
template <>
// Scalar relu6 with a configurable cap: clamp x into [0, alpha].
inline float Active<RELU6>(const float &x, const float &alpha) {
  const float rectified = std::max(x, 0.f);
  return std::min(rectified, alpha);
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -1675,6 +1675,20 @@ class ReluParam : public ReluParamBase<Dtype> {
using ReluParamBase<Dtype>::ReluParamBase;
};
template <typename Dtype>
// Parameter holder for the relu6 op: same inputs/outputs as plain relu
// (via ReluParamBase), plus the "threshold" attribute that caps the
// activation: out = min(max(x, 0), threshold).
class Relu6Param : public ReluParamBase<Dtype> {
 public:
  Relu6Param(const VariableNameMap &inputs, const VariableNameMap &outputs,
             const AttributeMap &attrs, Scope *scope)
      : ReluParamBase<Dtype>(inputs, outputs, attrs, scope),
        threshold(OpParam::GetAttr<float>("threshold", attrs)) {}

  // Upper clamp bound read from the op attributes at construction time.
  float getThreshold() const { return threshold; }

 private:
  float threshold;
};
#ifdef PADDLE_MOBILE_CL
template <>
class ReluParam<GPU_CL> : public ReluParamBase<GPU_CL> {
......
......@@ -44,6 +44,7 @@ int TestRelu6Op(const std::vector<int> input_shape) {
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
// Exercise the newly configurable threshold using the classic relu6 cap of 6.
attrs["threshold"].Set<float>(6.f);
auto *op = new operators::Relu6Op<CPU, float>("relu6", inputs, outputs, attrs,
scope.get());
op->InferShape();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册