提交 56ed27dd 编写于 作者: Z zhangyang0701 提交者: GitHub

Merge branch 'develop' into develop

......@@ -89,6 +89,10 @@ const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu";
const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand";
const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool";
const char *G_OP_TYPE_SEQUENCE_SOFTMAX = "sequence_softmax";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key = {
......@@ -162,5 +166,8 @@ std::unordered_map<
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}};
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -105,6 +105,8 @@ enum ActivationType {
enum PoolingType {
MAX = 0,
AVG = 1,
SUM = 2,
FIRST = 3,
};
extern const char *G_OP_TYPE_CONV;
......@@ -169,6 +171,10 @@ extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU;
extern const char *G_OP_TYPE_SEQUENCE_EXPAND;
extern const char *G_OP_TYPE_SEQUENCE_POOL;
extern const char *G_OP_TYPE_SEQUENCE_SOFTMAX;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key;
......
......@@ -90,28 +90,28 @@ Executor<Device, T>::Executor(const Program<Device> &program, int batch_size,
}
}
template <typename Device>
template <typename T>
static void LoadMemInternal(void **data, LoDTensor *tensor,
bool quant_uint8 = false) {
char **data_buf = reinterpret_cast<char **>(data);
int64_t size = tensor->numel();
Device *tensor_data = tensor->mutable_data<Device>();
T *tensor_data = tensor->mutable_data<T>();
if (quant_uint8) {
// should be moved into operator init function
float min_value;
float max_value;
memory::Copy(&min_value, data_buf, sizeof(float));
memory::Copy(&max_value, data_buf + sizeof(float), sizeof(float));
data_buf += 2 * sizeof(float);
memory::Copy(&min_value, *data_buf, sizeof(float));
memory::Copy(&max_value, *data_buf + sizeof(float), sizeof(float));
*data_buf += 2 * sizeof(float);
const float factor = (max_value - min_value) / 255.0;
const uint8_t *uint8_data = reinterpret_cast<uint8_t *>(data_buf);
const uint8_t *uint8_data = reinterpret_cast<uint8_t *>(*data_buf);
for (int k = 0; k < size; ++k) {
tensor_data[k] = uint8_data[k] * factor + min_value;
}
data_buf += size * sizeof(uint8_t);
*data_buf += size * sizeof(uint8_t);
} else {
memory::Copy(tensor_data, *data_buf, size * sizeof(Device));
*data_buf += size * sizeof(Device);
memory::Copy(tensor_data, *data_buf, size * sizeof(T));
*data_buf += size * sizeof(T);
}
}
......
......@@ -264,3 +264,9 @@ LOAD_FUSION_MATCHER(fusion_dequant_add_bn_quant);
LOAD_OP1(fusion_dequant_add_bn_relu_quant, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu_quant);
#endif
#ifdef SEQUENCE_EXPAND_OP
LOAD_OP1(sequence_expand, CPU);
#endif
#ifdef SEQUENCE_POOL_OP
LOAD_OP1(sequence_pool, CPU);
#endif
......@@ -11,77 +11,36 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#pragma once
#include <cmath>
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
#ifdef __ARM_NEON
#include <arm_neon.h>
#include "operators/math/math_func_neon.h"
#endif
namespace paddle_mobile {
namespace operators {
using framework::DDim;
void sigmoid(const Tensor *X, Tensor *Y) {
#ifdef __ARM_NEON
const float *input = X->data<float>();
float *output = Y->mutable_data<float>();
const DDim &dDim = X->dims();
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
};
#ifdef RELU_OP
DECLARE_KERNEL(ReluKernel, ReluParam);
DECLARE_KERNEL(Relu6Kernel, ReluParam);
#endif
DLOG << "outsize=" << out_size;
DLOG << "innersize=" << inner_size;
#pragma omp parallel for
for (int i = 0; i < out_size; ++i) {
const float *input_outer_ptr = input + i * inner_size;
float *output_outer_ptr = output + i * inner_size;
int nn = inner_size >> 2;
int remain = inner_size - (nn << 2);
float32x4_t _one = vdupq_n_f32(1.f);
for (; nn > 0; nn--) {
float32x4_t data = vld1q_f32(input_outer_ptr);
data = vnegq_f32(data);
data = exp_ps(data);
data = vaddq_f32(data, _one);
float32x4_t out_data = vrecpeq_f32(data);
out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data);
vst1q_f32(output_outer_ptr, out_data);
#ifdef SIGMOID_OP
DECLARE_KERNEL(SigmoidKernel, SigmoidParam);
#endif
input_outer_ptr += 4;
output_outer_ptr += 4;
}
for (; remain > 0; remain--) {
*output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr));
output_outer_ptr++;
input_outer_ptr++;
}
}
#else
#ifdef TANH_OP
DECLARE_KERNEL(TanhKernel, TanhParam);
#endif
}
template <typename P>
void SigmoidCompute(const SigmoidParam<CPU> &param) {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
sigmoid(in_x, out);
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#include "operators/kernel/relu_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "common/types.h"
#include "operators/math/activation.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
......@@ -25,12 +23,12 @@ namespace paddle_mobile {
namespace operators {
template <typename Dtype, ActivationType Act>
struct ReluCompute {
struct ActivationCompute {
void operator()(const Tensor *input, Tensor *output) {}
};
template <ActivationType Act>
struct ReluCompute<float, Act> {
struct ActivationCompute<float, Act> {
void operator()(const Tensor *input, Tensor *output) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
......@@ -65,6 +63,7 @@ struct ReluCompute<float, Act> {
}
};
#ifdef RELU_OP
template <>
bool ReluKernel<CPU, float>::Init(ReluParam<CPU> *param) {
return true;
......@@ -74,7 +73,7 @@ template <>
void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ReluCompute<float, RELU>()(input, output);
ActivationCompute<float, RELU>()(input, output);
}
template <>
......@@ -86,10 +85,37 @@ template <>
void Relu6Kernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ReluCompute<float, RELU6>()(input, output);
ActivationCompute<float, RELU6>()(input, output);
}
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef SIGMOID_OP
template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam<CPU> *param) {
return true;
}
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ActivationCompute<float, SIGMOID>()(input, output);
}
#endif
#ifdef TANH_OP
template <>
void TanhKernel<CPU, float>::Init(TanhParam<CPU> *param) {
return true;
}
template <>
void TanhKernel<CPU, float>::Compute(const TanhParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ActivationCompute<float, TANH>()(input, output);
}
#endif
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_EXPAND_OP
#include <vector>
#include "operators/kernel/sequence_kernels.h"
namespace paddle_mobile {
namespace operators {
typedef int (*LoDElementFunctor)(const std::vector<size_t> &x_lod, int index);
int element_with_lod(const std::vector<size_t> &x_lod, int index) {
return x_lod[index];
}
int element_without_lod(const std::vector<size_t> &x_lod, int index) {
return index;
}
template <typename T>
inline void SequenceExpandImpl(const framework::LoDTensor &x,
const std::vector<size_t> &ref_lod,
framework::LoDTensor *output) {
const T *x_data = x.data<T>();
auto &x_lod = x.lod();
LoDElementFunctor lod_element = element_without_lod;
if (x_lod.size() == 1) lod_element = element_with_lod;
T *output_data = output->mutable_data<T>();
int x_item_length = x.numel() / x.dims()[0];
int out_offset = 0;
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
int x_start = lod_element(x_lod[0], i - 1);
int x_end = lod_element(x_lod[0], i);
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
int out_start = out_offset;
if (output->lod().size() == 1) {
out_start = output->lod()[0][out_offset];
}
for (int j = 0; j < repeat_num; j++) {
for (int k = 0; k < x_seq_len; k++) {
memcpy(output_data + (out_start + j * x_seq_len + k) * x_item_length,
x_data + (x_start + k) * x_item_length,
x_item_length * sizeof(T));
}
}
}
out_offset += repeat_num;
}
}
template <typename T>
class SequenceExpandKernel<CPU, T>
: public framework::OpKernelBase<CPU, SequenceExpandParam<CPU>> {
public:
bool Init(SequenceExpandParam<CPU> *param) { return true; }
void Compute(const SequenceExpandParam<CPU> &param) {
const framework::LoDTensor *input_x = param.input_x_;
const framework::LoDTensor *input_y = param.input_y_;
framework::LoDTensor *output = param.output_;
output->mutable_data<T>();
const auto &x_lod = input_x->lod();
const auto &y_lod = input_y->lod();
int ref_level = param.ref_level_;
if (ref_level == -1) ref_level = y_lod.size() - 1;
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*input_x, output);
output->set_lod(input_x->lod());
return;
}
std::vector<size_t> out_lod;
if (x_lod.size() == 1) {
out_lod.push_back(0);
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_start = x_lod[0][i - 1];
int x_end = x_lod[0][i];
int x_seq_len = x_end - x_start;
for (int j = 0; j < repeat_num; ++j) {
out_lod.push_back(out_lod.back() + x_seq_len);
}
}
}
output->set_lod({out_lod});
SequenceExpandImpl<T>(*input_x, y_lod[ref_level], output);
}
};
template class SequenceExpandKernel<CPU, float>;
// template class SequenceExpandKernel<CPU, int64_t>;
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_EXPAND_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_POOL_OP
#include <cmath>
#include <limits>
#include <string>
#include <vector>
#include "common/types.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/math/pooling.h"
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif // __ARM_NEON__
namespace paddle_mobile {
namespace operators {
template <PoolingType P = MAX, typename T = float>
void SequencePoolImpl(const framework::LoDTensor &input,
framework::LoDTensor *output) {
const float *input_ptr = input.data<float>();
float *output_ptr = output->mutable_data<float>();
const auto &lod = input.lod()[0];
int64_t width = input.numel() / input.dims()[0];
#pragma omp parallel for
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float *in_ptr = input_ptr + lod[i] * width;
float *out_ptr = output_ptr + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (width == 1) {
float max = -std::numeric_limits<float>::max();
int remain_h = height;
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
remain_h = remain_h & 0x3;
float32x4_t __max4 = math::vPoolInitq_f32<MAX>();
for (int h = 0; h < loop; ++h) {
float32x4_t r0 = vld1q_f32(in_ptr);
__max4 = vmaxq_f32(__max4, r0);
in_ptr += 4;
}
float32x2_t __max2 =
vpmax_f32(vget_low_f32(__max4), vget_high_f32(__max4));
__max2 = vpmax_f32(__max2, __max2);
max = std::max(max, vget_lane_f32(__max2, 0));
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
max = std::max(max, in_ptr[h]);
}
*out_ptr = max;
} else {
memcpy(out_ptr, in_ptr, width * sizeof(float));
in_ptr += width;
int remain_h = height - 1;
#ifdef __ARM_NEON__
int remain_w_start = width & 0xfffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
for (int w = 0; w < width; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
__out = vmaxq_f32(__out, __in);
vst1q_f32(out_ptr + w, __out);
}
#endif // __ARM_NEON__
for (int w = remain_w_start; w < width; ++w) {
out_ptr[w] = std::max(out_ptr[w], in_ptr[w]);
}
in_ptr += width;
}
}
}
}
template <>
void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
framework::LoDTensor *output) {
const float *input_ptr = input.data<float>();
float *output_ptr = output->mutable_data<float>();
const auto &lod = input.lod()[0];
int64_t width = input.numel() / input.dims()[0];
#pragma omp parallel for
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float *in_ptr = input_ptr + lod[i] * width;
float *out_ptr = output_ptr + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (width == 1) {
float sum = 0.f;
int remain_h = height;
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
remain_h = remain_h & 0x3;
float32x4_t __sum4 = vdupq_n_f32(0.f);
for (int h = 0; h < loop; ++h) {
float32x4_t r0 = vld1q_f32(in_ptr);
__sum4 = vaddq_f32(__sum4, r0);
in_ptr += 4;
}
float32x2_t __sum2 =
vpadd_f32(vget_low_f32(__sum4), vget_high_f32(__sum4));
sum += vget_lane_f32(__sum2, 0) + vget_lane_f32(__sum2, 1);
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
sum += in_ptr[h];
}
*out_ptr = sum;
} else {
memcpy(out_ptr, in_ptr, width * sizeof(float));
in_ptr += width;
int remain_h = height - 1;
#ifdef __ARM_NEON__
int loop_w = width >> 2;
int remain_w_start = width & 0xfffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
for (int w = 0; w < width - 3; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
__out = vaddq_f32(__out, __in);
vst1q_f32(out_ptr + w, __out);
}
#endif // __ARM_NEON__
for (int w = remain_w_start; w < width; ++w) {
out_ptr[w] += in_ptr[w];
}
in_ptr += width;
}
}
}
}
template <>
void SequencePoolImpl<FIRST, float>(const framework::LoDTensor &input,
framework::LoDTensor *output) {
const float *input_ptr = input.data<float>();
float *output_ptr = output->mutable_data<float>();
const auto &lod = input.lod()[0];
int64_t width = input.numel() / input.dims()[0];
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float *in_ptr = input_ptr + lod[i] * width;
float *out_ptr = output_ptr + i * width;
memcpy(out_ptr, in_ptr, width * sizeof(float));
}
}
template <typename T>
class SequencePoolKernel<CPU, T>
: public framework::OpKernelBase<CPU, SequencePoolParam<CPU>> {
public:
bool Init(SequencePoolParam<CPU> *param) { return true; }
void Compute(const SequencePoolParam<CPU> &param) {
const framework::LoDTensor *input = param.input_;
framework::LoDTensor *output = param.output_;
output->mutable_data<T>();
const std::string pooling_type = param.pool_type_;
if (param.pool_type_ == "MAX") {
SequencePoolImpl<MAX, T>(*input, output);
} else if (param.pool_type_ == "FIRST") {
SequencePoolImpl<FIRST, T>(*input, output);
} else if (param.pool_type_ == "SUM") {
SequencePoolImpl<SUM, T>(*input, output);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"pooling type `%s` has not been implemented.",
param.pool_type_.c_str());
}
}
};
template class SequencePoolKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_POOL_OP
......@@ -12,32 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#include "../sigmoid_kernel.h"
#include "../central-arm-func/sigmoid_arm_func.h"
#ifdef __ARM_NEON
#include "../../math/math_func_neon.h"
#endif
#include <cmath>
#ifdef SEQUENCE_SOFTMAX_OP
#include "framework/lod_tensor.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/math/softmax.h"
namespace paddle_mobile {
namespace operators {
using framework::DDim;
using framework::Tensor;
template <typename T>
class SequenceSoftmaxKernel<CPU, T>
: public framework::OpKernelBase<CPU, SoftmaxParam<CPU>> {
public:
bool Init(SoftmaxParam<CPU> *param) { return true; }
template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam<CPU> *param) {
return true;
}
void Compute(const SoftmaxParam<CPU> &param) {
const framework::LoDTensor *input = param.InputX();
framework::LoDTensor *output = param.Out();
math::SequenceSoftmaxFuntor<CPU, T> sequence_softmax;
sequence_softmax(input, output);
}
};
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) {
SigmoidCompute<float>(param);
}
template class SequenceSoftmaxKernel<CPU, float>;
template class SigmoidKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // SEQUENCE_SOFTMAX_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
};
#ifdef SEQUENCE_EXPAND_OP
DECLARE_KERNEL(SequenceExpandKernel, SequenceExpandParam);
#endif // SEQUENCE_EXPAND_OP
#ifdef SEQUENCE_POOL_OP
DECLARE_KERNEL(SequencePoolKernel, SequencePoolParam);
#endif // SEQUENCE_POOL_OP
#ifdef SEQUENCE_SOFTMAX_OP
DECLARE_KERNEL(SequenceSoftmaxKernel, SoftmaxParam);
#endif // SEQUENCE_SOFTMAX_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -60,6 +60,58 @@ float find_max(const float *input, const int num_classes) {
return max;
}
void SoftmaxBasic(const float *input, int num_classes, float *y) {
float *output = y;
// find max
float max = find_max(input, num_classes);
// exp(x - max) and sum(exp(x - max))
int remain = num_classes;
float sum = 0.f;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
x0 = vsubq_f32(x0, __max);
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
float out = expf(input[i] - max);
sum += out;
output[i] = out;
}
// exp(x - max) / sum
float inv_sum = 1.f / sum;
output = y;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __inv_sum = vdupq_n_f32(inv_sum);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif
for (int i = 0; i < remain; ++i) {
output[i] *= inv_sum;
}
}
template <>
void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
framework::Tensor *Y) {
......@@ -76,65 +128,25 @@ void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
size_t offset = (batch * channels + channel) * num_classes;
const float *input = x + offset;
float *output = y + offset;
// find max
float max = find_max(input, num_classes);
// exp(x - max)
int remain = num_classes;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
x0 = vsubq_f32(x0, __max);
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = expf(input[i] - max);
}
SoftmaxBasic(input, num_classes, output);
}
}
}
// sum(exp(x - max))
float sum = 0.f;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
}
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
sum += output[i];
}
template <>
void SequenceSoftmaxFuntor<CPU, float>::operator()(
const framework::LoDTensor *X, framework::LoDTensor *Y) {
const float *x = X->data<float>();
const auto &lod = X->lod().back();
float *y = Y->mutable_data<float>();
// exp(x - max) / sum
float inv_sum = 1.f / sum;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __inv_sum = vdupq_n_f32(inv_sum);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif
for (int i = 0; i < remain; ++i) {
output[i] *= inv_sum;
}
}
#pragma omp parallel for
for (int batch = 0; batch < lod.size() - 1; ++batch) {
int num_classes = lod[batch + 1] - lod[batch];
size_t offset = lod[batch];
const float *input = x + offset;
float *output = y + offset;
SoftmaxBasic(input, num_classes, output);
}
}
......
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#if defined(SOFTMAX_OP) || defined(SEQUENCE_SOFTMAX_OP)
#pragma once
#include "framework/lod_tensor.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......@@ -28,7 +29,14 @@ class SoftmaxFuntor {
void operator()(const framework::Tensor *X, framework::Tensor *Y);
};
template <typename Device, typename T>
class SequenceSoftmaxFuntor {
public:
void operator()(const framework::LoDTensor *X, framework::LoDTensor *Y);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -978,12 +978,12 @@ class SoftmaxParam : public OpParam {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
private:
RType *input_x_;
RType *out_;
GType *input_x_;
GType *out_;
#ifdef PADDLE_MOBILE_FPGA
......@@ -2737,5 +2737,57 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam<Dtype> {
};
#endif
#ifdef SEQUENCE_EXPAND_OP
template <typename Dtype>
class SequenceExpandParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
SequenceExpandParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
input_y_ = InputYFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
ref_level_ = -1;
if (OpParam::HasAttr("ref_level", attrs)) {
ref_level_ = OpParam::GetAttr<int>("ref_level", attrs);
}
}
public:
GType *input_x_;
GType *input_y_;
GType *output_;
int ref_level_;
};
#endif // SEQUENCE_EXPAND_OP
#ifdef SEQUENCE_POOL_OP
template <typename Dtype>
class SequencePoolParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
SequencePoolParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
pool_type_ = "MAX";
if (OpParam::HasAttr("pooltype", attrs)) {
pool_type_ = OpParam::GetStringAttr("pooltype", attrs);
}
}
public:
GType *input_;
GType *output_;
std::string pool_type_;
};
#endif // SEQUENCE_EXPAND_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef RELU_OP
#include "operators/relu_op.h"
namespace paddle_mobile {
namespace operators {
......@@ -47,4 +48,4 @@ REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp);
REGISTER_OPERATOR_CL(relu, ops::ReluOp);
#endif
#endif
#endif // RELU_OP
......@@ -19,7 +19,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/relu_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_EXPAND_OP
#include "operators/sequence_ops/sequence_expand_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void SequenceExpandOp<DeviceType, T>::InferShape() const {
const auto *input_x = this->param_.input_x_;
const auto *input_y = this->param_.input_y_;
const auto &x_lod = input_x->lod();
const auto &y_lod = input_y->lod();
int ref_level = this->param_.ref_level_;
if (ref_level == -1) ref_level = y_lod.size() - 1;
auto out_dims = input_x->dims();
int64_t out_first_dim = 0;
if (y_lod[ref_level].size() > 1) {
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int x_seq_len = 1;
if (x_lod.size() == 1) {
x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
}
out_first_dim +=
(y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
}
out_dims[0] = out_first_dim;
}
this->param_.output_->Resize(out_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sequence_expand, ops::SequenceExpandOp);
#endif
#endif // SEQUENCE_EXPAND_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_EXPAND_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SequenceExpandOp : public framework::OperatorWithKernel<
DeviceType, SequenceExpandParam<DeviceType>,
operators::SequenceExpandKernel<DeviceType, T>> {
public:
SequenceExpandOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, SequenceExpandParam<DeviceType>,
operators::SequenceExpandKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_EXPAND_OP
......@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef SEQUENCE_POOL_OP
#ifdef SIGMOID_OP
#include "framework/operator.h"
#include "operators/op_param.h"
#include "operators/sequence_ops/sequence_pool_op.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class SigmoidKernel
: public OpKernelBase<DeviceType, SigmoidParam<DeviceType>> {
public:
void Compute(const SigmoidParam<DeviceType>& param);
bool Init(SigmoidParam<DeviceType>* param);
};
void SequencePoolOp<DeviceType, T>::InferShape() const {
const auto *input = this->param_.input_;
auto out_dims = input->dims();
out_dims[0] = input->lod()[0].size() - 1;
this->param_.output_->Resize(out_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sequence_pool, ops::SequencePoolOp);
#endif
#endif // SEQUENCE_POOL_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_POOL_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SequencePoolOp : public framework::OperatorWithKernel<
DeviceType, SequencePoolParam<DeviceType>,
operators::SequencePoolKernel<DeviceType, T>> {
public:
SequencePoolOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, SequencePoolParam<DeviceType>,
operators::SequencePoolKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_POOL_OP
......@@ -12,33 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#ifdef SEQUENCE_SOFTMAX_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
#include "operators/sequence_ops/sequence_softmax_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ReluKernel
: public framework::OpKernelBase<DeviceType, ReluParam<DeviceType>> {
public:
void Compute(const ReluParam<DeviceType>& param);
bool Init(ReluParam<DeviceType>* param);
};
void SequenceSoftmaxOp<DeviceType, T>::InferShape() const {
const auto *input_x = this->param_.InputX();
const auto &x_lod = input_x->lod();
template <typename DeviceType, typename T>
class Relu6Kernel
: public framework::OpKernelBase<DeviceType, ReluParam<DeviceType>> {
public:
void Compute(const ReluParam<DeviceType>& param);
bool Init(ReluParam<DeviceType>* param);
};
this->param_.Out()->Resize(input_x->dims());
this->param_.Out()->set_lod(input_x->lod());
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sequence_softmax, ops::SequenceSoftmaxOp);
#endif
#endif // SEQUENCE_SOFTMAX_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_SOFTMAX_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SequenceSoftmaxOp : public framework::OperatorWithKernel<
DeviceType, SoftmaxParam<DeviceType>,
operators::SequenceSoftmaxKernel<DeviceType, T>> {
public:
SequenceSoftmaxOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, SoftmaxParam<DeviceType>,
operators::SequenceSoftmaxKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_SOFTMAX_OP
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sigmoid_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -28,6 +28,9 @@ void TanhOp<DeviceType, T>::InferShape() const {
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(tanh, ops::TanhOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/tanh_kernel.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -385,4 +385,13 @@ if (NOT FOUND_MATCH)
# gen test
ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h)
target_link_libraries(test-ocr paddle-mobile)
ADD_EXECUTABLE(test-sequence-expand operators/test_sequence_expand_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-expand paddle-mobile)
ADD_EXECUTABLE(test-sequence-pool operators/test_sequence_pool_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-pool paddle-mobile)
ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-softmax paddle-mobile)
endif ()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/sequence_ops/sequence_expand_op.h"
namespace paddle_mobile {
int TestSequenceExpandOp(const framework::LoDTensor &input_x,
const framework::LoDTensor &input_y, int ref_level,
framework::LoDTensor *output) {
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input_x"});
inputs["Y"] = std::vector<std::string>({"input_y"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_x_var = scope.get()->Var("input_x");
auto *x = input_x_var->template GetMutable<framework::LoDTensor>();
x->Resize(input_x.dims());
x->ShareDataWith(input_x);
x->set_lod(input_x.lod());
auto input_y_var = scope.get()->Var("input_y");
auto *y = input_y_var->template GetMutable<framework::LoDTensor>();
y->Resize(framework::make_ddim({0}));
y->mutable_data<float>();
y->set_lod(input_y.lod());
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["ref_level"].Set<int>(0);
auto *op = new operators::SequenceExpandOp<CPU, float>(
"sequence_expand", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto *out = output_var->template Get<framework::LoDTensor>();
output->Resize(out->dims());
output->ShareDataWith(*out);
output->set_lod(out->lod());
delete op;
return 0;
}
} // namespace paddle_mobile
// namespace framework = paddle_mobile::framework;
int main(int argc, char *argv[]) {
framework::LoDTensor input_x, input_y, output;
// case 1
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
input_y.set_lod({{0, 2, 4}, {0, 3, 6, 7, 8}});
TestSequenceExpandOp(input_x, input_y, 0, &output);
std::vector<float> expect_data{1, 2, 1, 2, 3, 4, 3, 4};
std::vector<int> expect_lod{0, 2, 4, 6, 8};
for (int i = 0; i < 5; ++i) {
if (output.lod()[0][i] != expect_lod[i]) {
std::cerr << "output_lod[" << i << "]: " << output.lod()[0][i]
<< " != expect_lod[" << i << "]: " << expect_lod[i]
<< std::endl;
return 1;
}
}
for (int i = 0; i < 8; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/sequence_ops/sequence_pool_op.h"
namespace paddle_mobile {
int TestSequencePoolOp(const framework::LoDTensor &input_x,
const std::string pool_type,
framework::LoDTensor *output) {
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input_x"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_x_var = scope.get()->Var("input_x");
auto *x = input_x_var->template GetMutable<framework::LoDTensor>();
x->Resize(input_x.dims());
x->ShareDataWith(input_x);
x->set_lod(input_x.lod());
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["pooltype"].SetString(pool_type);
auto *op = new operators::SequencePoolOp<CPU, float>("sequence_pool", inputs,
outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto *out = output_var->template Get<framework::LoDTensor>();
output->Resize(out->dims());
output->ShareDataWith(*out);
delete op;
return 0;
}
} // namespace paddle_mobile
// namespace framework = paddle_mobile::framework;
int main(int argc, char *argv[]) {
framework::LoDTensor input_x, output;
// case 1
std::cerr << "running max case 1" << std::endl;
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{2, 4};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 2
std::cerr << "running max case 2" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
input_x.Resize(framework::make_ddim({data.size(), 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 3, 10}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{3, 10};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
std::cerr << "running max case 3" << std::endl;
// case 3
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};
input_x.Resize(framework::make_ddim({4, 2}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{3, 4, 7, 8};
for (int i = 0; i < 4; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 4
std::cerr << "running max case 4" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
input_x.Resize(framework::make_ddim({4, 5}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{6, 7, 8, 9, 10, 16, 17, 18, 19, 20};
for (int i = 0; i < 10; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 1
std::cerr << "running sum case 1" << std::endl;
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{3, 7};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 2
std::cerr << "running sum case 2" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
input_x.Resize(framework::make_ddim({data.size(), 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 3, 10}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{6, 49};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 3
std::cerr << "running sum case 3" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};
input_x.Resize(framework::make_ddim({4, 2}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{4, 6, 12, 14};
for (int i = 0; i < 4; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 4
std::cerr << "running sum case 4" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
input_x.Resize(framework::make_ddim({4, 5}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{7, 9, 11, 13, 15, 27, 29, 31, 33, 35};
for (int i = 0; i < 10; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 1
std::cerr << "running first case 1" << std::endl;
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 3};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 2
std::cerr << "running first case 2" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
input_x.Resize(framework::make_ddim({data.size(), 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 3, 10}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 4};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 3
std::cerr << "running first case 3" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};
input_x.Resize(framework::make_ddim({4, 2}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 2, 5, 6};
for (int i = 0; i < 4; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 4
std::cerr << "running first case 4" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
input_x.Resize(framework::make_ddim({4, 5}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 2, 3, 4, 5, 11, 12, 13, 14, 15};
for (int i = 0; i < 10; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include "../test_include.h"
#include "operators/sequence_ops/sequence_softmax_op.h"
namespace paddle_mobile {
void SequenceSoftmax(const framework::LoDTensor *X, framework::LoDTensor *Y) {
const float *x = X->data<float>();
const auto &lod = X->lod().back();
float *y = Y->mutable_data<float>();
for (int batch = 0; batch < lod.size() - 1; ++batch) {
int num_classes = lod[batch + 1] - lod[batch];
size_t offset = lod[batch];
const float *input = x + offset;
float *output = y + offset;
float max = -std::numeric_limits<float>::max();
for (int j = 0; j < num_classes; ++j) {
max = (input[j] > max) ? input[j] : max;
}
float sum = 0.f;
for (int j = 0; j < num_classes; ++j) {
float tmp = std::expf(input[j] - max);
sum += tmp;
output[j] = tmp;
}
for (int j = 0; j < num_classes; ++j) {
output[j] /= sum;
}
}
Y->set_lod(X->lod());
}
int TestSequenceSoftmaxOp(const std::vector<int> &input_shape,
const std::vector<size_t> &input_lod) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
input->set_lod({input_lod});
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::SequenceSoftmaxOp<CPU, float>(
"sequence_softmax", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::LoDTensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
SequenceSoftmax(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestSequenceSoftmaxOp({2, 1}, {0, 2});
TestSequenceSoftmaxOp({100, 1}, {0, 3, 100});
TestSequenceSoftmaxOp({100, 1}, {0, 50, 100});
return 0;
}
......@@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../../src/operators/kernel/central-arm-func/sigmoid_arm_func.h"
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../test_helper.h"
#include "framework/executor.h"
int main() {
paddle_mobile::framework::Tensor input;
......@@ -25,11 +22,5 @@ int main() {
auto out_ddim = paddle_mobile::framework::make_ddim({1, 4, 60, 60});
output.Resize(out_ddim);
paddle_mobile::operators::sigmoid(&input, &output);
auto *output_ptr = output.data<float>();
for (int j = 0; j < output.numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
}
DLOG << 5;
return 0;
}
......@@ -62,7 +62,6 @@ int TestSoftmaxOp(const std::vector<int> input_shape) {
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto output = output_var->template Get<framework::LoDTensor>();
framework::AttributeMap attrs;
auto *op = new operators::SoftmaxOp<CPU, float>("softmax", inputs, outputs,
......@@ -71,6 +70,8 @@ int TestSoftmaxOp(const std::vector<int> input_shape) {
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Softmax(input, &output_cmp);
......
......@@ -272,6 +272,9 @@ if(NOT FOUND_MATCH)
set(FUSION_DEQUANT_ADD_BN_RELU_OP ON)
set(FUSION_DEQUANT_ADD_BN_QUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP ON)
set(SEQUENCE_EXPAND_OP ON)
set(SEQUENCE_POOL_OP ON)
set(SEQUENCE_SOFTMAX_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -496,6 +499,15 @@ endif()
if (FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
# add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
endif()
if (SEQUENCE_EXPAND_OP)
add_definitions(-DSEQUENCE_EXPAND_OP)
endif()
if (SEQUENCE_POOL_OP)
add_definitions(-DSEQUENCE_POOL_OP)
endif()
if (SEQUENCE_SOFTMAX_OP)
add_definitions(-DSEQUENCE_SOFTMAX_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册