Commit b6e709e2 authored by hjchen2

Fix bugs and add sequence softmax op

Parent 4f6362c7
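For context, the new `sequence_softmax` op applies a numerically stable softmax independently over each LoD segment of its input. For a segment x_1, ..., x_n it computes

    softmax(x_i) = exp(x_i - max_j x_j) / Σ_k exp(x_k - max_j x_j)

Below is a minimal scalar sketch of these semantics, mirroring the reference implementation in the new test file; the function name and container types are illustrative only, not part of this diff:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference semantics of sequence_softmax for a LoD tensor of shape {N, 1}.
void SequenceSoftmaxRef(const std::vector<float> &x,
                        const std::vector<size_t> &lod,  // e.g. {0, 3, 100}
                        std::vector<float> *y) {
  y->resize(x.size());
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    size_t begin = lod[i], end = lod[i + 1];
    if (begin == end) continue;  // skip empty segments
    // subtract the segment max before exponentiating, for numerical stability
    float max = *std::max_element(x.begin() + begin, x.begin() + end);
    float sum = 0.f;
    for (size_t j = begin; j < end; ++j) {
      (*y)[j] = std::exp(x[j] - max);
      sum += (*y)[j];
    }
    for (size_t j = begin; j < end; ++j) {
      (*y)[j] /= sum;
    }
  }
}
```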
......@@ -91,6 +91,7 @@ const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu";
const char *G_OP_TYPE_SEQUENCE_EXPAND = "sequence_expand";
const char *G_OP_TYPE_SEQUENCE_POOL = "sequence_pool";
const char *G_OP_TYPE_SEQUENCE_SOFTMAX = "sequence_softmax";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
......@@ -167,5 +168,6 @@ std::unordered_map<
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}}};
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -173,6 +173,7 @@ extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU;
extern const char *G_OP_TYPE_SEQUENCE_EXPAND;
extern const char *G_OP_TYPE_SEQUENCE_POOL;
extern const char *G_OP_TYPE_SEQUENCE_SOFTMAX;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef SEQUENCE_POOL_OP
#include <cmath>
#include <limits>
#include <string>
#include <vector>
#include "common/types.h"
......@@ -41,7 +42,7 @@ void SequencePoolImpl(const framework::LoDTensor &input,
float *out_ptr = output_ptr + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (width == 1) {
float val = 0.f;
float max = -std::numeric_limits<float>::max();
int remain_h = height;
#ifdef __ARM_NEON__
int loop = remain_h >> 2;
......@@ -53,19 +54,19 @@ void SequencePoolImpl(const framework::LoDTensor &input,
in_ptr += 4;
}
float32x2_t __max2 =
vpadd_f32(vget_low_f32(__max4), vget_high_f32(__max4));
__max2 = vpadd_f32(__max2, __max2);
val = std::max(val, vget_lane_f32(__max2, 0));
vpmax_f32(vget_low_f32(__max4), vget_high_f32(__max4));
__max2 = vpmax_f32(__max2, __max2);
max = std::max(max, vget_lane_f32(__max2, 0));
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
val = std::max(val, in_ptr[h]);
max = std::max(max, in_ptr[h]);
}
*out_ptr = val;
*out_ptr = max;
} else {
memcpy(out_ptr, in_ptr, width * sizeof(float));
in_ptr += width;
int remain_h = height - 1;
#ifdef __ARM_NEON__
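// vectorize 4 columns at a time; remain_w_start rounds width down to a
// multiple of 4, and columns past it take the scalar path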
int loop_w = width >> 2;
int remain_w_start = width & 0xfffc;
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
......@@ -121,6 +122,7 @@ void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
*out_ptr = sum;
} else {
memcpy(out_ptr, in_ptr, width * sizeof(float));
in_ptr += width;
int remain_h = height - 1;
#ifdef __ARM_NEON__
int loop_w = width >> 2;
......@@ -128,7 +130,7 @@ void SequencePoolImpl<SUM, float>(const framework::LoDTensor &input,
#endif // __ARM_NEON__
for (int h = 0; h < remain_h; ++h) {
#ifdef __ARM_NEON__
for (int w = 0; w < width; w += 4) {
for (int w = 0; w < width - 3; w += 4) {
float32x4_t __in = vld1q_f32(in_ptr + w);
float32x4_t __out = vld1q_f32(out_ptr + w);
__out = vaddq_f32(__out, __in);
......@@ -169,6 +171,7 @@ class SequencePoolKernel<CPU, T>
const framework::LoDTensor *input = param.input_;
framework::LoDTensor *output = param.output_;
output->mutable_data<T>();
const std::string pooling_type = param.pool_type_;
if (param.pool_type_ == "MAX") {
SequencePoolImpl<MAX, T>(*input, output);
......@@ -176,6 +179,10 @@ class SequencePoolKernel<CPU, T>
SequencePoolImpl<FIRST, T>(*input, output);
} else if (param.pool_type_ == "SUM") {
SequencePoolImpl<SUM, T>(*input, output);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
"pooling type `%s` has not been implemented.",
param.pool_type_.c_str());
}
}
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_SOFTMAX_OP
#include "framework/lod_tensor.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/math/softmax.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
class SequenceSoftmaxKernel<CPU, T>
: public framework::OpKernelBase<CPU, SoftmaxParam<CPU>> {
public:
bool Init(SoftmaxParam<CPU> *param) { return true; }
void Compute(const SoftmaxParam<CPU> &param) {
const framework::LoDTensor *input = param.InputX();
framework::LoDTensor *output = param.Out();
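// run a softmax over every LoD segment of the input (see math::SequenceSoftmaxFuntor)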
math::SequenceSoftmaxFuntor<CPU, T> sequence_softmax;
sequence_softmax(input, output);
}
};
template class SequenceSoftmaxKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_SOFTMAX_OP
......@@ -37,5 +37,9 @@ DECLARE_KERNEL(SequenceExpandKernel, SequenceExpandParam);
DECLARE_KERNEL(SequencePoolKernel, SequencePoolParam);
#endif // SEQUENCE_POOL_OP
#ifdef SEQUENCE_SOFTMAX_OP
DECLARE_KERNEL(SequenceSoftmaxKernel, SoftmaxParam);
#endif // SEQUENCE_SOFTMAX_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -60,6 +60,68 @@ float find_max(const float *input, const int num_classes) {
return max;
}
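// numerically stable softmax over a single row of num_classes floats,
// writing the result to y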
void SoftmaxBasic(const float *input, int num_classes, float *y) {
float *output = y;
// find max
float max = find_max(input, num_classes);
// exp(x - max)
int remain = num_classes;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
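// each NEON iteration handles 8 floats; the last (num_classes & 0x7)
// elements fall through to the scalar loop below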
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
x0 = vsubq_f32(x0, __max);
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = expf(input[i] - max);
}
// sum(exp(x - max))
float sum = 0.f;
output = y;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
}
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
sum += output[i];
}
// exp(x - max) / sum
float inv_sum = 1.f / sum;
output = y;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __inv_sum = vdupq_n_f32(inv_sum);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif
for (int i = 0; i < remain; ++i) {
output[i] *= inv_sum;
}
}
template <>
void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
framework::Tensor *Y) {
......@@ -76,65 +138,25 @@ void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
size_t offset = (batch * channels + channel) * num_classes;
const float *input = x + offset;
float *output = y + offset;
// find max
float max = find_max(input, num_classes);
// exp(x - max)
int remain = num_classes;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
x0 = vsubq_f32(x0, __max);
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = expf(input[i] - max);
}
SoftmaxBasic(input, num_classes, output);
}
}
}
// sum(exp(x - max))
float sum = 0.f;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
}
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
sum += output[i];
}
template <>
void SequenceSoftmaxFuntor<CPU, float>::operator()(
const framework::LoDTensor *X, framework::LoDTensor *Y) {
const float *x = X->data<float>();
const auto &lod = X->lod().back();
float *y = Y->mutable_data<float>();
// exp(x - max) / sum
float inv_sum = 1.f / sum;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __inv_sum = vdupq_n_f32(inv_sum);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif
for (int i = 0; i < remain; ++i) {
output[i] *= inv_sum;
}
}
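// each LoD segment [lod[i], lod[i+1]) is softmax-normalized independently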
#pragma omp parallel for
for (int batch = 0; batch < lod.size() - 1; ++batch) {
int num_classes = lod[batch + 1] - lod[batch];
size_t offset = lod[batch];
const float *input = x + offset;
float *output = y + offset;
SoftmaxBasic(input, num_classes, output);
}
}
......
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#if defined(SOFTMAX_OP) || defined(SEQUENCE_SOFTMAX_OP)
#pragma once
#include "framework/lod_tensor.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......@@ -28,7 +29,14 @@ class SoftmaxFuntor {
void operator()(const framework::Tensor *X, framework::Tensor *Y);
};
template <typename Device, typename T>
class SequenceSoftmaxFuntor {
public:
void operator()(const framework::LoDTensor *X, framework::LoDTensor *Y);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -978,12 +978,12 @@ class SoftmaxParam : public OpParam {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
private:
RType *input_x_;
RType *out_;
GType *input_x_;
GType *out_;
#ifdef PADDLE_MOBILE_FPGA
......@@ -2778,7 +2778,7 @@ class SequencePoolParam : public OpParam {
output_ = OutFrom<GType>(outputs, scope);
pool_type_ = "MAX";
if (OpParam::HasAttr("pooltype", attrs)) {
pool_type_ = OpParam::GetAttr<std::string>("pooltype", attrs);
pool_type_ = OpParam::GetStringAttr("pooltype", attrs);
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_SOFTMAX_OP
#include "operators/sequence_ops/sequence_softmax_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void SequenceSoftmaxOp<DeviceType, T>::InferShape() const {
const auto *input_x = this->param_.InputX();
this->param_.Out()->Resize(input_x->dims());
this->param_.Out()->set_lod(input_x->lod());
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sequence_softmax, ops::SequenceSoftmaxOp);
#endif
#endif // SEQUENCE_SOFTMAX_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SEQUENCE_SOFTMAX_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/sequence_kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SequenceSoftmaxOp : public framework::OperatorWithKernel<
DeviceType, SoftmaxParam<DeviceType>,
operators::SequenceSoftmaxKernel<DeviceType, T>> {
public:
SequenceSoftmaxOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, SoftmaxParam<DeviceType>,
operators::SequenceSoftmaxKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // SEQUENCE_SOFTMAX_OP
......@@ -385,4 +385,13 @@ if (NOT FOUND_MATCH)
# gen test
ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h)
target_link_libraries(test-ocr paddle-mobile)
ADD_EXECUTABLE(test-sequence-expand operators/test_sequence_expand_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-expand paddle-mobile)
ADD_EXECUTABLE(test-sequence-pool operators/test_sequence_pool_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-pool paddle-mobile)
ADD_EXECUTABLE(test-sequence-softmax operators/test_sequence_softmax_op.cpp test_helper.h test_include.h)
target_link_libraries(test-sequence-softmax paddle-mobile)
endif ()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/sequence_ops/sequence_expand_op.h"
namespace paddle_mobile {
int TestSequenceExpandOp(const framework::LoDTensor &input_x,
const framework::LoDTensor &input_y, int ref_level,
framework::LoDTensor *output) {
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input_x"});
inputs["Y"] = std::vector<std::string>({"input_y"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_x_var = scope.get()->Var("input_x");
auto *x = input_x_var->template GetMutable<framework::LoDTensor>();
x->Resize(input_x.dims());
x->ShareDataWith(input_x);
x->set_lod(input_x.lod());
auto input_y_var = scope.get()->Var("input_y");
auto *y = input_y_var->template GetMutable<framework::LoDTensor>();
y->Resize(framework::make_ddim({0}));
y->mutable_data<float>();
y->set_lod(input_y.lod());
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["ref_level"].Set<int>(0);
auto *op = new operators::SequenceExpandOp<CPU, float>(
"sequence_expand", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto *out = output_var->template Get<framework::LoDTensor>();
output->Resize(out->dims());
output->ShareDataWith(*out);
output->set_lod(out->lod());
delete op;
return 0;
}
} // namespace paddle_mobile
// namespace framework = paddle_mobile::framework;
int main(int argc, char *argv[]) {
framework::LoDTensor input_x, input_y, output;
// case 1
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
input_y.set_lod({{0, 2, 4}, {0, 3, 6, 7, 8}});
TestSequenceExpandOp(input_x, input_y, 0, &output);
std::vector<float> expect_data{1, 2, 1, 2, 3, 4, 3, 4};
std::vector<int> expect_lod{0, 2, 4, 6, 8};
for (int i = 0; i < 5; ++i) {
if (output.lod()[0][i] != expect_lod[i]) {
std::cerr << "output_lod[" << i << "]: " << output.lod()[0][i]
<< " != expect_lod[" << i << "]: " << expect_lod[i]
<< std::endl;
return 1;
}
}
for (int i = 0; i < 8; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/sequence_ops/sequence_pool_op.h"
namespace paddle_mobile {
int TestSequencePoolOp(const framework::LoDTensor &input_x,
const std::string pool_type,
framework::LoDTensor *output) {
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input_x"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_x_var = scope.get()->Var("input_x");
auto *x = input_x_var->template GetMutable<framework::LoDTensor>();
x->Resize(input_x.dims());
x->ShareDataWith(input_x);
x->set_lod(input_x.lod());
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["pooltype"].SetString(pool_type);
auto *op = new operators::SequencePoolOp<CPU, float>("sequence_pool", inputs,
outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto *out = output_var->template Get<framework::LoDTensor>();
output->Resize(out->dims());
output->ShareDataWith(*out);
delete op;
return 0;
}
} // namespace paddle_mobile
// namespace framework = paddle_mobile::framework;
int main(int argc, char *argv[]) {
framework::LoDTensor input_x, output;
// case 1
std::cerr << "running max case 1" << std::endl;
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{2, 4};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 2
std::cerr << "running max case 2" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
input_x.Resize(framework::make_ddim({data.size(), 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 3, 10}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{3, 10};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
std::cerr << "running max case 3" << std::endl;
// case 3
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};
input_x.Resize(framework::make_ddim({4, 2}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{3, 4, 7, 8};
for (int i = 0; i < 4; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 4
std::cerr << "running max case 4" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
input_x.Resize(framework::make_ddim({4, 5}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "MAX", &output);
std::vector<float> expect_data{6, 7, 8, 9, 10, 16, 17, 18, 19, 20};
for (int i = 0; i < 10; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 1
std::cerr << "running sum case 1" << std::endl;
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{3, 7};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 2
std::cerr << "running sum case 2" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
input_x.Resize(framework::make_ddim({data.size(), 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 3, 10}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{6, 49};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 3
std::cerr << "running sum case 3" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};
input_x.Resize(framework::make_ddim({4, 2}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{4, 6, 12, 14};
for (int i = 0; i < 4; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 4
std::cerr << "running sum case 4" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
input_x.Resize(framework::make_ddim({4, 5}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "SUM", &output);
std::vector<float> expect_data{7, 9, 11, 13, 15, 27, 29, 31, 33, 35};
for (int i = 0; i < 10; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 1
std::cerr << "running first case 1" << std::endl;
{
std::vector<float> data{1, 2, 3, 4};
input_x.Resize(framework::make_ddim({4, 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < 4; ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 3};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 2
std::cerr << "running first case 2" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
input_x.Resize(framework::make_ddim({data.size(), 1}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 3, 10}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 4};
for (int i = 0; i < 2; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 3
std::cerr << "running first case 3" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8};
input_x.Resize(framework::make_ddim({4, 2}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 2, 5, 6};
for (int i = 0; i < 4; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
// case 4
std::cerr << "running first case 4" << std::endl;
{
std::vector<float> data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
input_x.Resize(framework::make_ddim({4, 5}));
float *in_data = input_x.mutable_data<float>();
for (int i = 0; i < data.size(); ++i) in_data[i] = data[i];
input_x.set_lod({{0, 2, 4}});
TestSequencePoolOp(input_x, "FIRST", &output);
std::vector<float> expect_data{1, 2, 3, 4, 5, 11, 12, 13, 14, 15};
for (int i = 0; i < 10; ++i) {
if (output.data<float>()[i] != expect_data[i]) {
std::cerr << "output[" << i << "]: " << output.data<float>()[i]
<< " != expect[" << i << "]: " << expect_data[i] << std::endl;
return 1;
}
}
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include "../test_include.h"
#include "operators/sequence_ops/sequence_softmax_op.h"
namespace paddle_mobile {
void SequenceSoftmax(const framework::LoDTensor *X, framework::LoDTensor *Y) {
const float *x = X->data<float>();
const auto &lod = X->lod().back();
float *y = Y->mutable_data<float>();
for (int batch = 0; batch < lod.size() - 1; ++batch) {
int num_classes = lod[batch + 1] - lod[batch];
size_t offset = lod[batch];
const float *input = x + offset;
float *output = y + offset;
float max = -std::numeric_limits<float>::max();
for (int j = 0; j < num_classes; ++j) {
max = (input[j] > max) ? input[j] : max;
}
float sum = 0.f;
for (int j = 0; j < num_classes; ++j) {
float tmp = expf(input[j] - max);
sum += tmp;
output[j] = tmp;
}
for (int j = 0; j < num_classes; ++j) {
output[j] /= sum;
}
}
Y->set_lod(X->lod());
}
int TestSequenceSoftmaxOp(const std::vector<int> &input_shape,
const std::vector<size_t> &input_lod) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
input->set_lod({input_lod});
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::SequenceSoftmaxOp<CPU, float>(
"sequence_softmax", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::LoDTensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
SequenceSoftmax(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestSequenceSoftmaxOp({2, 1}, {0, 2});
TestSequenceSoftmaxOp({100, 1}, {0, 3, 100});
TestSequenceSoftmaxOp({100, 1}, {0, 50, 100});
return 0;
}
......@@ -62,7 +62,6 @@ int TestSoftmaxOp(const std::vector<int> input_shape) {
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto output = output_var->template Get<framework::LoDTensor>();
framework::AttributeMap attrs;
auto *op = new operators::SoftmaxOp<CPU, float>("softmax", inputs, outputs,
......@@ -71,6 +70,8 @@ int TestSoftmaxOp(const std::vector<int> input_shape) {
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Softmax(input, &output_cmp);
......
......@@ -274,6 +274,7 @@ if(NOT FOUND_MATCH)
set(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP ON)
set(SEQUENCE_EXPAND_OP ON)
set(SEQUENCE_POOL_OP ON)
set(SEQUENCE_SOFTMAX_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -504,6 +505,9 @@ endif()
if (SEQUENCE_POOL_OP)
add_definitions(-DSEQUENCE_POOL_OP)
endif()
if (SEQUENCE_SOFTMAX_OP)
add_definitions(-DSEQUENCE_SOFTMAX_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
......