Commit 4f6362c7 authored by H hjchen2

Add sequence_expand, sequence_pool and tanh op

Parent 380013d9
......@@ -89,6 +89,9 @@ const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu";
const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand";
const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = {
......@@ -162,5 +165,7 @@ std::unordered_map<
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}};
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -105,6 +105,8 @@ enum ActivationType {
enum PoolingType {
MAX = 0,
AVG = 1,
SUM = 2,
FIRST = 3,
};
extern const char *G_OP_TYPE_CONV;
......@@ -169,6 +171,9 @@ extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU;
extern const char *G_OP_TYPE_SEQUENCE_EXPAND;
extern const char *G_OP_TYPE_SEQUENCE_POOL;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key;
......
......@@ -264,3 +264,9 @@ LOAD_FUSION_MATCHER(fusion_dequant_add_bn_quant);
LOAD_OP1(fusion_dequant_add_bn_relu_quant, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu_quant);
#endif
#ifdef SEQUENCE_EXPAND_OP
LOAD_OP1(sequence_expand, CPU);
#endif
#ifdef SEQUENCE_POOL_OP
LOAD_OP1(sequence_pool, CPU);
#endif
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#pragma once
#include "framework/operator.h"
......@@ -22,23 +20,27 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ReluKernel
: public framework::OpKernelBase<DeviceType, ReluParam<DeviceType>> {
public:
void Compute(const ReluParam<DeviceType>& param);
bool Init(ReluParam<DeviceType>* param);
};
template <typename DeviceType, typename T>
class Relu6Kernel
: public framework::OpKernelBase<DeviceType, ReluParam<DeviceType>> {
public:
void Compute(const ReluParam<DeviceType>& param);
bool Init(ReluParam<DeviceType>* param);
};
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
};
} // namespace operators
} // namespace paddle_mobile
#ifdef RELU_OP
DECLARE_KERNEL(ReluKernel, ReluParam);
DECLARE_KERNEL(Relu6Kernel, ReluParam);
#endif
#ifdef SIGMOID_OP
DECLARE_KERNEL(SigmoidKernel, SigmoidParam);
#endif
#ifdef TANH_OP
DECLARE_KERNEL(TanhKernel, TanhParam);
#endif
} // namespace operators
} // namespace paddle_mobile
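For reference, a sketch of what one of the DECLARE_KERNEL invocations above expands to (spelled out only for illustration):
template <typename DeviceType, typename T>
class TanhKernel
    : public framework::OpKernelBase<DeviceType, TanhParam<DeviceType>> {
 public:
  bool Init(TanhParam<DeviceType> *param);
  void Compute(const TanhParam<DeviceType> &param);
};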
......@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#include "operators/kernel/relu_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "common/types.h"
#include "operators/math/activation.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
......@@ -25,12 +23,12 @@ namespace paddle_mobile {
namespace operators {
template <typename Dtype, ActivationType Act>
struct ReluCompute {
struct ActivationCompute {
void operator()(const Tensor *input, Tensor *output) {}
};
template <ActivationType Act>
struct ReluCompute<float, Act> {
struct ActivationCompute<float, Act> {
void operator()(const Tensor *input, Tensor *output) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
......@@ -65,6 +63,7 @@ struct ReluCompute<float, Act> {
}
};
#ifdef RELU_OP
template <>
bool ReluKernel<CPU, float>::Init(ReluParam<CPU> *param) {
return true;
......@@ -74,7 +73,7 @@ template <>
void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ReluCompute<float, RELU>()(input, output);
ActivationCompute<float, RELU>()(input, output);
}
template <>
......@@ -86,10 +85,37 @@ template <>
void Relu6Kernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ReluCompute<float, RELU6>()(input, output);
ActivationCompute<float, RELU6>()(input, output);
}
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef SIGMOID_OP
template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam<CPU> *param) {
return true;
}
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ActivationCompute<float, SIGMOID>()(input, output);
}
#endif
#ifdef TANH_OP
template <>
bool TanhKernel<CPU, float>::Init(TanhParam<CPU> *param) {
return true;
}
template <>
void TanhKernel<CPU, float>::Compute(const TanhParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ActivationCompute<float, TANH>()(input, output);
}
#endif
} // namespace operators
} // namespace paddle_mobile
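As a plain scalar reference for the activations dispatched above (a minimal sketch assuming the usual definitions; the real math::Active<Act> routines in operators/math are NEON-vectorized):
#include <algorithm>
#include <cmath>

inline float relu_ref(float x) { return std::max(x, 0.f); }
inline float relu6_ref(float x) { return std::min(std::max(x, 0.f), 6.f); }
inline float sigmoid_ref(float x) { return 1.f / (1.f + std::exp(-x)); }
inline float tanh_ref(float x) { return std::tanh(x); }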
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_EXPAND_OP
#include <vector>
#include "operators/kernel/sequence_kernels.h"
namespace paddle_mobile {
namespace operators {
typedef int (*LoDElementFunctor)(const std::vector<size_t> &x_lod, int index);
int element_with_lod(const std::vector<size_t> &x_lod, int index) {
return x_lod[index];
}
int element_without_lod(const std::vector<size_t> &x_lod, int index) {
return index;
}
template <typename T>
inline void SequenceExpandImpl(const framework::LoDTensor &x,
const std::vector<size_t> &ref_lod,
framework::LoDTensor *output) {
const T *x_data = x.data<T>();
auto &x_lod = x.lod();
LoDElementFunctor lod_element = element_without_lod;
if (x_lod.size() == 1) lod_element = element_with_lod;
T *output_data = output->mutable_data<T>();
int x_item_length = x.numel() / x.dims()[0];
int out_offset = 0;
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
int x_start = lod_element(x_lod[0], i - 1);
int x_end = lod_element(x_lod[0], i);
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
int out_start = out_offset;
if (output->lod().size() == 1) {
out_start = output->lod()[0][out_offset];
}
for (int j = 0; j < repeat_num; j++) {
for (int k = 0; k < x_seq_len; k++) {
memcpy(output_data + (out_start + j * x_seq_len + k) * x_item_length,
x_data + (x_start + k) * x_item_length,
x_item_length * sizeof(T));
}
}
}
out_offset += repeat_num;
}
}
template <typename T>
class SequenceExpandKernel<CPU, T>
: public framework::OpKernelBase<CPU, SequenceExpandParam<CPU>> {
public:
bool Init(SequenceExpandParam<CPU> *param) { return true; }
void Compute(const SequenceExpandParam<CPU> &param) {
const framework::LoDTensor *input_x = param.input_x_;
const framework::LoDTensor *input_y = param.input_y_;
framework::LoDTensor *output = param.output_;
output->mutable_data<T>();
const auto &x_lod = input_x->lod();
const auto &y_lod = input_y->lod();
int ref_level = param.ref_level_;
if (ref_level == -1) ref_level = y_lod.size() - 1;
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*input_x, output);
output->set_lod(input_x->lod());
return;
}
std::vector<size_t> out_lod;
if (x_lod.size() == 1) {
out_lod.push_back(0);
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_start = x_lod[0][i - 1];
int x_end = x_lod[0][i];
int x_seq_len = x_end - x_start;
for (int j = 0; j < repeat_num; ++j) {
out_lod.push_back(out_lod.back() + x_seq_len);
}
}
output->set_lod({out_lod});
}
SequenceExpandImpl<T>(*input_x, y_lod[ref_level], output);
}
};
template class SequenceExpandKernel<CPU, float>;
// template class SequenceExpandKernel<CPU, int64_t>;
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_EXPAND_OP
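A self-contained illustration (plain std::vector, hypothetical values, no framework types) of the expansion rule that SequenceExpandImpl implements: sequence i of X is repeated (ref_lod[i+1] - ref_lod[i]) times.
#include <cstdio>
#include <vector>

int main() {
  // X holds two sequences {1, 2} and {3, 4}; item length is 1.
  std::vector<float> x = {1, 2, 3, 4};
  std::vector<size_t> x_lod = {0, 2, 4};
  // Reference LoD taken from Y: repeat each sequence of X twice.
  std::vector<size_t> ref_lod = {0, 2, 4};
  std::vector<float> out;
  for (size_t i = 1; i < ref_lod.size(); ++i) {
    size_t repeat_num = ref_lod[i] - ref_lod[i - 1];
    for (size_t j = 0; j < repeat_num; ++j) {
      out.insert(out.end(), x.begin() + x_lod[i - 1], x.begin() + x_lod[i]);
    }
  }
  // Prints: 1 2 1 2 3 4 3 4
  for (float v : out) printf("%g ", v);
  printf("\n");
  return 0;
}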
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_POOL_OP
#include <cmath>
#include <limits>
#include <string>
#include <vector>
#include "common/types.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/math/pooling.h"
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif // __ARM_NEON__
namespace paddle_mobile {
namespace operators {
template <PoolingType P = MAX, typename T = float>
void SequencePoolImpl(const framework::LoDTensor &input,
framework::LoDTensor *output) {
const float *input_ptr = input.data<float>();
float *output_ptr = output->mutable_data<float>();
const auto &lod = input.lod()[0];
int64_t width = input.numel() / input.dims()[0];
#pragma omp parallel for
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float *in_ptr = input_ptr + lod[i] * width;
float *out_ptr = output_ptr + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (width == 1) {
float val = std::numeric_limits<float>::lowest();
int remain_h = height;
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
remain_h = remain_h & 0x3;
float32x4_t __max4 = math::vPoolInitq_f32<MAX>();
for (int h = 0; h < loop; ++h) {
float32x4_t r0 = vld1q_f32(in_ptr);
__max4 = vmaxq_f32(__max4, r0);
in_ptr += 4;
}
float32x2_t __max2 =
vpmax_f32(vget_low_f32(__max4), vget_high_f32(__max4));
__max2 = vpmax_f32(__max2, __max2);
val = std::max(val, vget_lane_f32(__max2, 0));
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
val = std::max(val, in_ptr[h]);
}
*out_ptr = val;
} else {
memcpy(out_ptr, in_ptr, width * sizeof(float));
int remain_h = height - 1;
int remain_w_start = 0;
#ifdef __ARM_NEON__
remain_w_start = width & 0xfffc;
#endif  // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
for (int w = 0; w < remain_w_start; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
__out = vmaxq_f32(__out, __in);
vst1q_f32(out_ptr + w, __out);
}
#endif // __ARM_NEON__
for (int w = remain_w_start; w < width; ++w) {
out_ptr[w] = std::max(out_ptr[w], in_ptr[w]);
}
in_ptr += width;
}
}
}
}
template <>
void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
framework::LoDTensor *output) {
const float *input_ptr = input.data<float>();
float *output_ptr = output->mutable_data<float>();
const auto &lod = input.lod()[0];
int64_t width = input.numel() / input.dims()[0];
#pragma omp parallel for
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float *in_ptr = input_ptr + lod[i] * width;
float *out_ptr = output_ptr + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (width == 1) {
float sum = 0.f;
int remain_h = height;
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
remain_h = remain_h & 0x3;
float32x4_t __sum4 = vdupq_n_f32(0.f);
for (int h = 0; h < loop; ++h) {
float32x4_t r0 = vld1q_f32(in_ptr);
__sum4 = vaddq_f32(__sum4, r0);
in_ptr += 4;
}
float32x2_t __sum2 =
vpadd_f32(vget_low_f32(__sum4), vget_high_f32(__sum4));
sum += vget_lane_f32(__sum2, 0) + vget_lane_f32(__sum2, 1);
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
sum += in_ptr[h];
}
*out_ptr = sum;
} else {
memcpy(out_ptr, in_ptr, width * sizeof(float));
int remain_h = height - 1;
int remain_w_start = 0;
#ifdef __ARM_NEON__
remain_w_start = width & 0xfffc;
#endif  // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
for (int w = 0; w < remain_w_start; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
__out = vaddq_f32(__out, __in);
vst1q_f32(out_ptr + w, __out);
}
#endif // __ARM_NEON__
for (int w = remain_w_start; w < width; ++w) {
out_ptr[w] += in_ptr[w];
}
in_ptr += width;
}
}
}
}
template <>
void SequencePoolImpl<FIRST, float>(const framework::LoDTensor &input,
framework::LoDTensor *output) {
const float *input_ptr = input.data<float>();
float *output_ptr = output->mutable_data<float>();
const auto &lod = input.lod()[0];
int64_t width = input.numel() / input.dims()[0];
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float *in_ptr = input_ptr + lod[i] * width;
float *out_ptr = output_ptr + i * width;
memcpy(out_ptr, in_ptr, width * sizeof(float));
}
}
template <typename T>
class SequencePoolKernel<CPU, T>
: public framework::OpKernelBase<CPU, SequencePoolParam<CPU>> {
public:
bool Init(SequencePoolParam<CPU> *param) { return true; }
void Compute(const SequencePoolParam<CPU> &param) {
const framework::LoDTensor *input = param.input_;
framework::LoDTensor *output = param.output_;
output->mutable_data<T>();
if (param.pool_type_ == "MAX") {
SequencePoolImpl<MAX, T>(*input, output);
} else if (param.pool_type_ == "FIRST") {
SequencePoolImpl<FIRST, T>(*input, output);
} else if (param.pool_type_ == "SUM") {
SequencePoolImpl<SUM, T>(*input, output);
}
}
};
template class SequencePoolKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_POOL_OP
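A similar plain-std::vector sketch (hypothetical values) of what the MAX pooling path above produces: one output row per sequence, holding the column-wise maximum over that sequence's rows.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int width = 2;
  // Two sequences: rows [0, 2) and rows [2, 5).
  std::vector<size_t> lod = {0, 2, 5};
  std::vector<float> in = {1, 8, 3, 2,          // sequence 0
                           5, 0, 4, 9, 7, 6};   // sequence 1
  std::vector<float> out((lod.size() - 1) * width);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    for (int w = 0; w < width; ++w) {
      float val = in[lod[i] * width + w];
      for (size_t h = lod[i]; h < lod[i + 1]; ++h) {
        val = std::max(val, in[h * width + w]);
      }
      out[i * width + w] = val;
    }
  }
  // Prints: 3 8 7 9
  for (float v : out) printf("%g ", v);
  printf("\n");
  return 0;
}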
......@@ -186,7 +186,6 @@ inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
}
}
}
#endif // __aarch64__
template <typename Itype, typename Otype>
inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
......@@ -209,6 +208,7 @@ inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
GemmConv<Itype, Otype>(param);
}
}
#endif // __aarch64__
} // namespace operators
} // namespace paddle_mobile
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#pragma once
#include <cmath>
#include "operators/op_param.h"
#ifdef __ARM_NEON
#include <arm_neon.h>
#include "operators/math/math_func_neon.h"
#endif
namespace paddle_mobile {
namespace operators {
using framework::DDim;
void sigmoid(const Tensor *X, Tensor *Y) {
#ifdef __ARM_NEON
const float *input = X->data<float>();
float *output = Y->mutable_data<float>();
const DDim &dDim = X->dims();
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
DLOG << "outsize=" << out_size;
DLOG << "innersize=" << inner_size;
#pragma omp parallel for
for (int i = 0; i < out_size; ++i) {
const float *input_outer_ptr = input + i * inner_size;
float *output_outer_ptr = output + i * inner_size;
int nn = inner_size >> 2;
int remain = inner_size - (nn << 2);
float32x4_t _one = vdupq_n_f32(1.f);
for (; nn > 0; nn--) {
float32x4_t data = vld1q_f32(input_outer_ptr);
data = vnegq_f32(data);
data = exp_ps(data);
data = vaddq_f32(data, _one);
float32x4_t out_data = vrecpeq_f32(data);
out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data);
vst1q_f32(output_outer_ptr, out_data);
input_outer_ptr += 4;
output_outer_ptr += 4;
}
for (; remain > 0; remain--) {
*output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr));
output_outer_ptr++;
input_outer_ptr++;
}
}
#else
#endif
}
template <typename P>
void SigmoidCompute(const SigmoidParam<CPU> &param) {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
sigmoid(in_x, out);
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -12,32 +12,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#include "../sigmoid_kernel.h"
#include "../central-arm-func/sigmoid_arm_func.h"
#ifdef __ARM_NEON
#include "../../math/math_func_neon.h"
#endif
#include <cmath>
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::Tensor;
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
};
template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam<CPU> *param) {
return true;
}
#ifdef SEQUENCE_EXPAND_OP
DECLARE_KERNEL(SequenceExpandKernel, SequenceExpandParam);
#endif // SEQUENCE_EXPAND_OP
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) {
SigmoidCompute<float>(param);
}
#ifdef SEQUENCE_POOL_OP
DECLARE_KERNEL(SequencePoolKernel, SequencePoolParam);
#endif // SEQUENCE_POOL_OP
template class SigmoidKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -56,6 +56,9 @@ inline int32x4_t vRoundq_f32(const float32x4_t &x) {
template <>
inline int32x4_t vRoundq_f32<ROUND_NEAREST_AWAY_ZERO>(const float32x4_t &x) {
#if __aarch64__
return vcvtaq_s32_f32(x);
#else
float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0);
......@@ -64,10 +67,14 @@ inline int32x4_t vRoundq_f32<ROUND_NEAREST_AWAY_ZERO>(const float32x4_t &x) {
temp = vaddq_f32(x, temp);
int32x4_t ret = vcvtq_s32_f32(temp);
return ret;
#endif
}
template <>
inline int32x4_t vRoundq_f32<ROUND_NEAREST_TO_EVEN>(const float32x4_t &x) {
#if __aarch64__
return vcvtnq_s32_f32(x);
#else
float32x4_t point5 = vdupq_n_f32(0.5);
int32x4_t one = vdupq_n_s32(1);
int32x4_t zero = vdupq_n_s32(0);
......@@ -90,6 +97,7 @@ inline int32x4_t vRoundq_f32<ROUND_NEAREST_TO_EVEN>(const float32x4_t &x) {
smask = vsubq_s32(smask, one);
rnd = vaddq_s32(rnd, smask);
return rnd;
#endif
}
#endif // __ARM_NEON__
......
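For the two aarch64 fast paths added above, a scalar sketch of the intended rounding behaviour (assuming the default FE_TONEAREST floating-point environment): vcvtaq_s32_f32 rounds ties away from zero, vcvtnq_s32_f32 rounds ties to even.
#include <cmath>

// ROUND_NEAREST_AWAY_ZERO: 2.5 -> 3, -2.5 -> -3
inline int round_away_from_zero_ref(float x) {
  return static_cast<int>(std::round(x));
}

// ROUND_NEAREST_TO_EVEN: 2.5 -> 2, 3.5 -> 4
inline int round_to_even_ref(float x) {
  return static_cast<int>(std::nearbyint(x));
}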
......@@ -2737,5 +2737,57 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam<Dtype> {
};
#endif
#ifdef SEQUENCE_EXPAND_OP
template <typename Dtype>
class SequenceExpandParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
SequenceExpandParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
input_y_ = InputYFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
ref_level_ = -1;
if (OpParam::HasAttr("ref_level", attrs)) {
ref_level_ = OpParam::GetAttr<int>("ref_level", attrs);
}
}
public:
GType *input_x_;
GType *input_y_;
GType *output_;
int ref_level_;
};
#endif // SEQUENCE_EXPAND_OP
#ifdef SEQUENCE_POOL_OP
template <typename Dtype>
class SequencePoolParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
SequencePoolParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
pool_type_ = "MAX";
if (OpParam::HasAttr("pooltype", attrs)) {
pool_type_ = OpParam::GetAttr<std::string>("pooltype", attrs);
}
}
public:
GType *input_;
GType *output_;
std::string pool_type_;
};
#endif  // SEQUENCE_POOL_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef RELU_OP
#include "operators/relu_op.h"
namespace paddle_mobile {
namespace operators {
......@@ -47,4 +48,4 @@ REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp);
REGISTER_OPERATOR_CL(relu, ops::ReluOp);
#endif
#endif
#endif // RELU_OP
......@@ -19,7 +19,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/relu_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_EXPAND_OP
#include "operators/sequence_ops/sequence_expand_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void SequenceExpandOp<DeviceType, T>::InferShape() const {
const auto *input_x = this->param_.input_x_;
const auto *input_y = this->param_.input_y_;
const auto &x_lod = input_x->lod();
const auto &y_lod = input_y->lod();
int ref_level = this->param_.ref_level_;
if (ref_level == -1) ref_level = y_lod.size() - 1;
auto out_dims = input_x->dims();
int64_t out_first_dim = 0;
if (y_lod[ref_level].size() > 1) {
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int x_seq_len = 1;
if (x_lod.size() == 1) {
x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
}
out_first_dim +=
(y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
}
out_dims[0] = out_first_dim;
}
this->param_.output_->Resize(out_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sequence_expand, ops::SequenceExpandOp);
#endif
#endif // SEQUENCE_EXPAND_OP
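A quick worked example of the shape arithmetic above, with hypothetical values: for x.dims() = [4, 8], x_lod = [[0, 2, 4]] and y_lod[ref_level] = [0, 3, 5], out_first_dim = 3 * 2 + 2 * 2 = 10, so the output is resized to [10, 8].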
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_EXPAND_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SequenceExpandOp : public framework::OperatorWithKernel<
DeviceType, SequenceExpandParam<DeviceType>,
operators::SequenceExpandKernel<DeviceType, T>> {
public:
SequenceExpandOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, SequenceExpandParam<DeviceType>,
operators::SequenceExpandKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// infer the output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_EXPAND_OP
......@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef SEQUENCE_POOL_OP
#ifdef SIGMOID_OP
#include "framework/operator.h"
#include "operators/op_param.h"
#include "operators/sequence_ops/sequence_pool_op.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class SigmoidKernel
: public OpKernelBase<DeviceType, SigmoidParam<DeviceType>> {
public:
void Compute(const SigmoidParam<DeviceType>& param);
bool Init(SigmoidParam<DeviceType>* param);
};
void SequencePoolOp<DeviceType, T>::InferShape() const {
const auto *input = this->param_.input_;
auto out_dims = input->dims();
out_dims[0] = input->lod()[0].size() - 1;
this->param_.output_->Resize(out_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sequence_pool, ops::SequencePoolOp);
#endif
#endif // SEQUENCE_POOL_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_POOL_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SequencePoolOp : public framework::OperatorWithKernel<
DeviceType, SequencePoolParam<DeviceType>,
operators::SequencePoolKernel<DeviceType, T>> {
public:
SequencePoolOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, SequencePoolParam<DeviceType>,
operators::SequencePoolKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
// infer the output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_POOL_OP
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sigmoid_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -28,6 +28,9 @@ void TanhOp<DeviceType, T>::InferShape() const {
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(tanh, ops::TanhOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/tanh_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../../src/operators/kernel/central-arm-func/sigmoid_arm_func.h"
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../test_helper.h"
#include "framework/executor.h"
int main() {
paddle_mobile::framework::Tensor input;
......@@ -25,11 +22,5 @@ int main() {
auto out_ddim = paddle_mobile::framework::make_ddim({1, 4, 60, 60});
output.Resize(out_ddim);
paddle_mobile::operators::sigmoid(&input, &output);
auto *output_ptr = output.data<float>();
for (int j = 0; j < output.numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
}
DLOG << 5;
return 0;
}
......@@ -272,6 +272,8 @@ if(NOT FOUND_MATCH)
set(FUSION_DEQUANT_ADD_BN_RELU_OP ON)
set(FUSION_DEQUANT_ADD_BN_QUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP ON)
set(SEQUENCE_EXPAND_OP ON)
set(SEQUENCE_POOL_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -496,6 +498,12 @@ endif()
if (FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
# add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
endif()
if (SEQUENCE_EXPAND_OP)
add_definitions(-DSEQUENCE_EXPAND_OP)
endif()
if (SEQUENCE_POOL_OP)
add_definitions(-DSEQUENCE_POOL_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
......