未验证 提交 f15d621a 编写于 作者: H HappyAngel 提交者: GitHub

[arm] add sequence_pool_grad op (#3986)


* fix sequence pool grad compute error, test=develop

* fix sequence pool grad run, test=develop

* fix format test=develop
上级 875d4563
......@@ -118,6 +118,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
beam_search.cc
reduce_max.cc
sequence_pool.cc
sequence_pool_grad.cc
sequence_expand.cc
slice.cc
reduce_mean.cc
......
......@@ -56,6 +56,7 @@
#include "lite/backends/arm/math/scale.h"
#include "lite/backends/arm/math/sequence_expand.h"
#include "lite/backends/arm/math/sequence_pool.h"
#include "lite/backends/arm/math/sequence_pool_grad.h"
#include "lite/backends/arm/math/sequence_softmax.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/backends/arm/math/sgemv.h"
......
......@@ -36,70 +36,23 @@ void seq_pool_sum<float>(const float* din,
const float* din_ptr = din + lod[i] * width;
float* dout_ptr = dout + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (width == 1) {
float sum = 0.f;
for (int h = 0; h < height; ++h) {
sum += din_ptr[h];
}
*dout_ptr = sum;
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
height = height - 1;
int cnt_w = width >> 2;
int remain_w = width & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vaddq_f32(din0, dout_val);
float32x4_t tmp = vaddq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vaddq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vaddq_f32(tmp, dout_val);
}
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vaddq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
float tmp = din_ptr1[0] + din_ptr2[0];
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr += din_ptr3[0];
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr += tmp;
if (height > 0) {
if (width == 1) {
float sum = 0.f;
for (int h = 0; h < height; ++h) {
sum += din_ptr[h];
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr += din_ptr0[0];
din_ptr0 += width;
*dout_ptr = sum;
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
height = height - 1;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; ++w) {
dout_ptr[w] += din_ptr[w];
}
din_ptr += width;
}
dout_ptr++;
}
}
}
......@@ -177,78 +130,35 @@ void seq_pool_sqrt<float>(const float* din,
template <>
void seq_pool_max<float>(const float* din,
float* dout,
int64_t* index,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float* din_ptr = din + lod[i] * width;
float* dout_ptr = dout + i * width;
int64_t* index_ptr = index + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (height > 0) {
if (width == 1) {
float max = -std::numeric_limits<float>::max();
int64_t max_index = -1;
for (int h = 0; h < height; ++h) {
max = std::max(max, din_ptr[h]);
max_index = max >= din_ptr[h] ? h : max_index;
}
*dout_ptr = max;
*index_ptr = max_index;
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
memset(index_ptr, 0, width * sizeof(int64_t));
din_ptr += width;
height = height - 1;
int cnt_w = width >> 2;
int remain_w = width & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vmaxq_f32(din0, dout_val);
float32x4_t tmp = vmaxq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vmaxq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vmaxq_f32(tmp, dout_val);
}
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vmaxq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
float tmp = std::max(din_ptr1[0], din_ptr2[0]);
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr = std::max(*dout_ptr, tmp);
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
din_ptr0 += width;
int remain_h = height - 1;
for (int h = 0; h < remain_h; h++) {
for (int w = 0; w < width; w++) {
dout_ptr[w] = std::max(dout_ptr[w], din_ptr[w]);
index_ptr[w] = dout_ptr[w] > din_ptr[w] ? index_ptr[w] : h;
}
dout_ptr++;
din_ptr += width;
}
}
}
......@@ -258,26 +168,33 @@ void seq_pool_max<float>(const float* din,
template <>
void seq_pool_min<float>(const float* din,
float* dout,
int64_t* index,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
const float* din_ptr = din + lod[i] * width;
float* dout_ptr = dout + i * width;
int64_t* index_ptr = index + i * width;
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
if (height > 0) {
if (width == 1) {
float min = std::numeric_limits<float>::max();
int64_t min_index = -1;
for (int h = 0; h < height; ++h) {
min = std::min(min, din_ptr[h]);
min_index = min >= din_ptr[h] ? h : min_index;
}
*dout_ptr = min;
*index_ptr = min_index;
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
memset(index_ptr, 0, width * sizeof(int64_t));
din_ptr += width;
int remain_h = height - 1;
for (int h = 0; h < remain_h; h++) {
for (int w = 0; w < width; w++) {
dout_ptr[w] = std::min(dout_ptr[w], din_ptr[w]);
index_ptr[w] = dout_ptr[w] < din_ptr[w] ? index_ptr[w] : h;
}
din_ptr += width;
}
......
......@@ -42,12 +42,14 @@ void seq_pool_sqrt(const T* din,
template <typename T>
void seq_pool_max(const T* din,
T* dout,
int64_t* index,
const std::vector<uint64_t> lod,
int64_t width);
template <typename T>
void seq_pool_min(const T* din,
T* dout,
int64_t* index,
const std::vector<uint64_t> lod,
int64_t width);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/sequence_pool_grad.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void seq_pool_sum_grad<float>(const float* din,
const float* dout_grad,
float* din_grad,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; i++) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* dout_grad_ptr = dout_grad + i * width;
float* din_grad_ptr = din_grad + lod[i] * width;
if (height > 0) {
if (width == 1) {
for (int h = 0; h < height; ++h) {
din_grad_ptr[h] = dout_grad_ptr[h];
}
} else {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = dout_grad_ptr[w];
}
din_grad_ptr += width;
}
}
}
}
}
template <>
void seq_pool_average_grad<float>(const float* din,
const float* dout_grad,
float* din_grad,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* dout_grad_ptr = dout_grad + i * width;
float* din_grad_ptr = din_grad + lod[i] * width;
float alpha = 1.0 / height;
if (height > 0) {
if (width == 1) {
float sum = 0.f;
for (int h = 0; h < height; ++h) {
din_grad_ptr[h] = alpha * dout_grad_ptr[h];
}
} else {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = alpha * dout_grad_ptr[w];
}
din_grad_ptr += width;
}
}
}
}
}
template <>
void seq_pool_sqrt_grad<float>(const float* din,
const float* dout_grad,
float* din_grad,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* dout_grad_ptr = dout_grad + i * width;
float* din_grad_ptr = din_grad + lod[i] * width;
float alpha = 1.0 / sqrtf(height);
if (height > 0) {
if (width == 1) {
float sum = 0.f;
for (int h = 0; h < height; ++h) {
din_grad_ptr[h] = alpha * dout_grad_ptr[h];
}
} else {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = alpha * dout_grad_ptr[w];
}
din_grad_ptr += width;
}
}
}
}
}
template <>
void seq_pool_max_grad<float>(const float* din,
const float* dout_grad,
const int64_t* index_grad,
float* din_grad,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = lod[i + 1] - lod[i];
const float* dout_grad_ptr = dout_grad + i * width;
const int64_t* index_grad_ptr = index_grad + i * width;
float* din_grad_ptr = din_grad + lod[i] * width;
if (height > 0) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
if (h == index_grad_ptr[w]) {
din_grad_ptr[w] = dout_grad_ptr[w];
} else {
din_grad_ptr[w] = 0.f;
}
}
din_grad_ptr += width;
}
}
}
}
template <>
void seq_pool_first_grad<float>(const float* din,
const float* dout_grad,
float* din_grad,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = lod[i + 1] - lod[i];
const float* dout_grad_ptr = dout_grad + i * width;
float* din_grad_ptr = din_grad + lod[i] * width;
if (height > 0) {
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = dout_grad_ptr[w];
}
din_grad_ptr += width;
for (int h = 1; h < height; h++) {
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = 0.f;
}
din_grad_ptr += width;
}
}
}
}
template <>
void seq_pool_last_grad<float>(const float* din,
const float* dout_grad,
float* din_grad,
const std::vector<uint64_t> lod,
int64_t width) {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = lod[i + 1] - lod[i];
const float* dout_grad_ptr = dout_grad + i * width;
float* din_grad_ptr = din_grad + lod[i] * width;
if (height > 0) {
for (int h = 0; h < height - 1; h++) {
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = 0.f;
}
din_grad_ptr += width;
}
// last
for (int w = 0; w < width; w++) {
din_grad_ptr[w] = dout_grad_ptr[w];
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void seq_pool_sum_grad(const T* din,
const T* dout_grad,
T* din_grad,
const std::vector<uint64_t> lod,
int64_t width);
template <typename T>
void seq_pool_average_grad(const T* din,
const T* dout_grad,
T* din_grad,
const std::vector<uint64_t> lod,
int64_t width);
template <typename T>
void seq_pool_sqrt_grad(const T* din,
const T* dout_grad,
T* din_grad,
const std::vector<uint64_t> lod,
int64_t width);
template <typename T>
void seq_pool_max_grad(const T* din,
const T* dout_grad,
const int64_t* index_grad,
T* din_grad,
const std::vector<uint64_t> lod,
int64_t width);
template <typename T>
void seq_pool_first_grad(const T* din,
const T* dout_grad,
T* din_grad,
const std::vector<uint64_t> lod,
int64_t width);
template <typename T>
void seq_pool_last_grad(const T* din,
const T* dout_grad,
T* din_grad,
const std::vector<uint64_t> lod,
int64_t width);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -104,6 +104,7 @@ add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite
add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_grad_compute_arm ARM train SRCS sequence_pool_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
......
......@@ -32,6 +32,7 @@ void SequencePoolCompute::Run() {
auto& output = param.Out;
const auto* din = param.X->data<float>();
float* dout = output->mutable_data<float>();
int64_t* max_index = param.MaxIndex->mutable_data<int64_t>();
const auto pool_type = param.pool_type;
const auto lod = param.X->lod()[0];
......@@ -44,9 +45,9 @@ void SequencePoolCompute::Run() {
} else if (pool_type == "SQRT") {
lite::arm::math::seq_pool_sqrt(din, dout, lod, width);
} else if (pool_type == "MAX") {
lite::arm::math::seq_pool_max(din, dout, lod, width);
lite::arm::math::seq_pool_max(din, dout, max_index, lod, width);
} else if (pool_type == "MIN") {
lite::arm::math::seq_pool_min(din, dout, lod, width);
lite::arm::math::seq_pool_min(din, dout, max_index, lod, width);
} else if (pool_type == "FIRST") {
lite::arm::math::seq_pool_first(din, dout, lod, width);
} else if (pool_type == "LAST") {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/kernels/arm/sequence_pool_grad_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SequencePoolGradCompute::PrepareForRun() {}
void SequencePoolGradCompute::Run() {
auto& param = Param<operators::SequencePoolGradParam>();
auto& output_grad = param.Out_Grad;
auto& x_grad = param.X_Grad;
const auto* din_ptr = param.X->data<float>();
const auto* dout_grad_ptr = output_grad->data<float>();
const auto* index_grad_ptr = param.MaxIndex_Grad->data<int64_t>();
float* x_grad_ptr = x_grad->mutable_data<float>();
const auto pool_type = param.pool_type;
const auto lod = param.X->lod()[0];
int64_t width = param.X->numel() / param.X->dims()[0];
if (pool_type == "SUM") {
lite::arm::math::seq_pool_sum_grad(
din_ptr, dout_grad_ptr, x_grad_ptr, lod, width);
} else if (pool_type == "AVERAGE") {
lite::arm::math::seq_pool_average_grad(
din_ptr, dout_grad_ptr, x_grad_ptr, lod, width);
} else if (pool_type == "SQRT") {
lite::arm::math::seq_pool_sqrt_grad(
din_ptr, dout_grad_ptr, x_grad_ptr, lod, width);
} else if (pool_type == "MAX" || pool_type == "MIN") {
lite::arm::math::seq_pool_max_grad(
din_ptr, dout_grad_ptr, index_grad_ptr, x_grad_ptr, lod, width);
} else if (pool_type == "FIRST") {
lite::arm::math::seq_pool_first_grad(
din_ptr, dout_grad_ptr, x_grad_ptr, lod, width);
} else if (pool_type == "LAST") {
lite::arm::math::seq_pool_last_grad(
din_ptr, dout_grad_ptr, x_grad_ptr, lod, width);
} else {
LOG(ERROR) << " UNKNOWN sequence pool type";
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_pool_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::SequencePoolGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SequencePoolGradCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void PrepareForRun() override;
void Run() override;
virtual ~SequencePoolGradCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -165,6 +165,7 @@ add_operator(activation_grad_ops train SRCS activation_grad_ops.cc DEPS ${op_DEP
add_operator(elementwise_grad_op train SRCS elementwise_grad_ops.cc DEPS ${op_DEPS})
add_operator(mul_grad_op train SRCS mul_grad_op.cc DEPS ${op_DEPS})
add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool_grad train SRCS sequence_pool_grad_op.cc DEPS ${op_DEPS})
# Only for XPU
add_operator(__xpu__resnet50_op extra SRCS __xpu__resnet50_op.cc DEPS ${op_DEPS})
......
......@@ -997,10 +997,10 @@ struct BeamSearchParam : ParamBase {
struct SequencePoolParam : ParamBase {
const lite::Tensor* X{};
lite::Tensor* Out{};
lite::Tensor* MaxIndex{};
std::string pool_type{"AVERAGE"};
#ifdef LITE_WITH_X86
float pad_value{0.0};
lite::Tensor* MaxIndex{};
#endif
};
......@@ -1019,6 +1019,18 @@ struct SequencePoolConcatParam : ParamBase {
std::vector<std::string> pool_type{};
};
struct SequencePoolGradParam : ParamBase {
const lite::Tensor* X{};
std::string pool_type{"AVERAGE"};
#ifdef LITE_WITH_X86
float pad_value{0.0};
#endif
// for backward
const lite::Tensor* Out_Grad{};
const lite::Tensor* MaxIndex_Grad{};
lite::Tensor* X_Grad{};
};
struct SearchGroupPaddingParam : ParamBase {
lite::Tensor* x{};
lite::Tensor* out_emb_padding{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_pool_grad_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequencePoolGradOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.X_Grad);
CHECK_OR_FALSE(param_.Out_Grad);
auto lod = param_.X->lod();
CHECK_EQ_OR_FALSE(lod.size(), 1UL);
auto dims = param_.X->dims();
CHECK_GE_OR_FALSE(dims[0], (static_cast<int64_t>(lod[0].size()) - 1));
return true;
}
bool SequencePoolGradOp::InferShapeImpl() const {
const auto *input = param_.X;
auto x_dims = input->dims();
if (param_.X_Grad) {
param_.X_Grad->Resize(x_dims);
param_.X_Grad->set_lod(param_.X->lod());
}
return true;
}
bool SequencePoolGradOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
CHECK(param_.X);
auto *out_grad_var = scope->FindVar(opdesc.Input("Out@GRAD").front());
CHECK(out_grad_var);
param_.Out_Grad = &out_grad_var->Get<Tensor>();
auto *x_grad_var = scope->FindVar(opdesc.Output("X@GRAD").front());
CHECK(x_grad_var);
param_.X_Grad = x_grad_var->GetMutable<Tensor>();
param_.pool_type = opdesc.GetAttr<std::string>("pooltype");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_pool_grad,
paddle::lite::operators::SequencePoolGradOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SequencePoolGradOp : public OpLite {
public:
SequencePoolGradOp() {}
explicit SequencePoolGradOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_pool_grad"; }
private:
mutable SequencePoolGradParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -34,6 +34,7 @@ bool SequencePoolOp::InferShapeImpl() const {
auto out_dims = input->dims();
out_dims[0] = input->lod()[0].size() - 1;
param_.Out->Resize(out_dims);
param_.MaxIndex->Resize(out_dims);
return true;
}
......
......@@ -40,21 +40,22 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_fill_constant_batch_size_like_compute SRCS fill_constant_batch_size_like_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -72,6 +73,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/sequence_pool_grad_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include "lite/core/op_registry.h"
#include "lite/kernels/arm/sequence_pool_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
using param_t = operators::SequencePoolParam;
using grad_param_t = operators::SequencePoolGradParam;
using kernel_t = SequencePoolCompute;
using grad_kernel_t = SequencePoolGradCompute;
void sequence_pool_grad_common(grad_param_t* param,
float* out_grad,
int64_t* index_grad,
float* x_grad,
std::string pool_type) {
const auto lod = param->X->lod()[0];
int64_t width = param->X->numel() / param->X->dims()[0];
if (pool_type == "SUM") {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; i++) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
float* out_grad_ptr = out_grad + i * width;
float* x_grad_ptr = x_grad + lod[i] * width;
if (height > 0) {
if (width == 1) {
for (int h = 0; h < height; ++h) {
x_grad_ptr[h] = out_grad_ptr[h];
}
} else {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
x_grad_ptr[w] = out_grad_ptr[w];
}
x_grad_ptr += width;
}
}
}
}
} else if (pool_type == "AVERAGE") {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; i++) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* out_grad_ptr = out_grad + i * width;
float* x_grad_ptr = x_grad + lod[i] * width;
float alpha = 1.0 / height;
if (height > 0) {
if (width == 1) {
for (int h = 0; h < height; ++h) {
x_grad_ptr[h] = alpha * out_grad_ptr[h];
}
} else {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
x_grad_ptr[w] = alpha * out_grad_ptr[w];
}
x_grad_ptr += width;
}
}
}
}
} else if (pool_type == "SQRT") {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; i++) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* out_grad_ptr = out_grad + i * width;
float* x_grad_ptr = x_grad + lod[i] * width;
float alpha = 1.0 / sqrtf(height);
if (height > 0) {
if (width == 1) {
for (int h = 0; h < height; ++h) {
x_grad_ptr[h] = alpha * out_grad_ptr[h];
}
} else {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
x_grad_ptr[w] = alpha * out_grad_ptr[w];
}
x_grad_ptr += width;
}
}
}
}
} else if (pool_type == "MAX" || pool_type == "MIN") {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; i++) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* out_grad_ptr = out_grad + i * width;
const int64_t* index_grad_ptr = index_grad + i * width;
float* x_grad_ptr = x_grad + lod[i] * width;
float alpha = 1.0 / sqrtf(height);
if (height > 0) {
for (int w = 0; w < width; w++) {
for (int h = 0; h < height; h++) {
if (h == index_grad_ptr[w]) {
x_grad_ptr[h * width + w] = out_grad_ptr[w];
} else {
x_grad_ptr[h * width + w] = 0.f;
}
}
}
}
}
} else if (pool_type == "FIRST") {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* out_grad_ptr = out_grad + i * width;
float* x_grad_ptr = x_grad + lod[i] * width;
if (height > 0) {
for (int w = 0; w < width; w++) {
for (int h = 0; h < height; h++) {
if (h == 0) {
x_grad_ptr[h * width + w] = out_grad_ptr[w];
} else {
x_grad_ptr[h * width + w] = 0.f;
}
}
}
}
}
} else if (pool_type == "LAST") {
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t height = static_cast<int64_t>(lod[i + 1] - lod[i]);
const float* out_grad_ptr = out_grad + i * width;
float* x_grad_ptr = x_grad + lod[i] * width;
if (height > 0) {
for (int w = 0; w < width; w++) {
for (int h = 0; h < height; h++) {
if (h == height - 1) {
x_grad_ptr[h * width + w] = out_grad_ptr[w];
} else {
x_grad_ptr[h * width + w] = 0.f;
}
}
}
}
}
} else {
LOG(FATAL) << " UNKNOWN sequence pool type";
}
}
void generate_lod(int seq_num,
int max_len,
std::vector<uint64_t>& seq_offset) { // NOLINT
seq_offset.clear();
int sum = 0;
seq_offset.push_back(sum);
for (int i = 0; i < seq_num; i++) {
sum += std::rand() % max_len + 1;
seq_offset.push_back(uint64_t(sum));
}
}
class SequencePoolGradTester {
public:
explicit SequencePoolGradTester(DDim dims,
std::vector<std::vector<uint64_t>> lod,
std::string pool_type)
: dims_(dims), lod_(lod), pool_type_(pool_type) {}
void prepare_kernel() {
std::unique_ptr<KernelContext> ctx1(new KernelContext);
ctx1->As<ARMContext>();
kernel_.SetContext(std::move(ctx1));
std::unique_ptr<KernelContext> ctx2(new KernelContext);
ctx2->As<ARMContext>();
delta_kernel_.SetContext(std::move(ctx2));
std::unique_ptr<KernelContext> ctx3(new KernelContext);
ctx3->As<ARMContext>();
grad_kernel_.SetContext(std::move(ctx3));
}
void run_forward(param_t* param,
kernel_t* kernel,
const std::vector<float>& in_vec,
int64_t* out_index_vec,
float* out_vec) {
Tensor x;
Tensor output;
Tensor index;
x.Resize(dims_);
output.Resize(out_dims_);
index.Resize(out_dims_);
auto* x_data = x.mutable_data<float>();
for (int i = 0; i < dims_.production(); i++) {
x_data[i] = in_vec[i];
}
x.set_lod(lod_);
param->X = &x;
param->pool_type = pool_type_;
param->Out = &output;
param->MaxIndex = &index;
kernel->SetParam(*param);
kernel->Launch();
auto* output_data = output.data<float>();
auto* output_index = index.data<int64_t>();
for (int i = 0; i < output.numel(); i++) {
out_vec[i] = output_data[i];
out_index_vec[i] = output_index[i];
}
}
void run_backward(grad_param_t* param,
grad_kernel_t* kernel,
const std::vector<float>& in_vec,
const std::vector<float>& out_grad_vec,
const std::vector<int64_t>& out_index_grad_vec,
float* in_grad_vec) {
Tensor x;
Tensor x_grad;
Tensor out_grad;
Tensor out_index_grad;
x.Resize(dims_);
x.set_lod(lod_);
// backword
x_grad.Resize(dims_);
out_grad.Resize(out_dims_);
out_index_grad.Resize(out_dims_);
auto* x_data = x.mutable_data<float>();
auto* out_grad_data = out_grad.mutable_data<float>();
auto* out_index_grad_data = out_index_grad.mutable_data<int64_t>();
for (int i = 0; i < dims_.production(); i++) {
x_data[i] = in_vec[i];
}
for (int i = 0; i < out_dims_.production(); i++) {
out_grad_data[i] = out_grad_vec[i];
out_index_grad_data[i] = out_index_grad_vec[i];
}
param->X = &x;
param->X_Grad = &x_grad;
param->Out_Grad = &out_grad;
param->MaxIndex_Grad = &out_index_grad;
param->pool_type = pool_type_;
kernel->SetParam(*param);
kernel->Launch();
auto* x_grad_data = x_grad.data<float>();
for (int i = 0; i < dims_.production(); i++) {
in_grad_vec[i] = x_grad_data[i];
}
LOG(INFO) << "end";
}
void check_grad(float delta, float max_grad_delta) {
std::vector<int64_t> out_shape;
out_dims_ = dims_;
out_dims_[0] = lod_[0].size() - 1;
std::vector<float> x(dims_.production());
std::vector<float> out(out_dims_.production());
std::vector<int64_t> index(out_dims_.production());
for (int i = 0; i < dims_.production(); i++) {
x[i] = static_cast<float>(i % 3 - 2.0) / 2.0 * 0.333 +
static_cast<float>(i % 19 - 10.0) / 10.0 * 0.333 +
static_cast<float>(i % 39 - 20.0) / 20.0 * 0.333 + 0.001213;
}
LOG(INFO) << "run_forward:";
this->run_forward(&param_, &kernel_, x, index.data(), out.data());
std::vector<float> out_grad(out_dims_.production());
std::vector<float> x_grad(dims_.production());
for (int i = 0; i < out_dims_.production(); i++) {
out_grad[i] = 1.0;
}
LOG(INFO) << "run_backward:";
this->run_backward(
&grad_param_, &grad_kernel_, x, out_grad, index, x_grad.data());
// get numeric gradient
std::vector<float> x_delta(dims_.production());
std::vector<float> out_delta(out_dims_.production());
Tensor tensor_x;
tensor_x.Resize(dims_);
tensor_x.set_lod(lod_);
grad_param_.X = &tensor_x;
LOG(INFO) << "sequence_pool_grad_common";
sequence_pool_grad_common(&grad_param_,
out_grad.data(),
index.data(),
x_delta.data(),
pool_type_);
for (int i = 0; i < dims_.production(); i++) {
EXPECT_NEAR(x_grad[i], x_delta[i], max_grad_delta);
}
}
private:
DDim dims_;
DDim out_dims_;
std::vector<std::vector<uint64_t>> lod_;
std::string pool_type_;
kernel_t kernel_;
kernel_t delta_kernel_;
grad_kernel_t grad_kernel_;
param_t param_;
param_t delta_param_;
grad_param_t grad_param_;
};
void TestSequencePoolGrad(DDim dims,
std::vector<std::vector<uint64_t>> lod,
std::string pool_type) {
LOG(INFO) << "Test SequencePool grad";
std::unique_ptr<SequencePoolGradTester> tester(
new SequencePoolGradTester(dims, lod, pool_type));
tester->prepare_kernel();
float delta = 0.001;
float max_grad_delta = 0.005;
tester->check_grad(delta, max_grad_delta);
}
TEST(sequence_pool_grad_host, compute) {
#ifdef LITE_WITH_ARM
int max_len = 2;
for (auto c : {2, 4}) {
for (auto h : {1, 3, 4}) {
for (auto w : {1, 3, 4}) {
for (auto pool_type :
{"SUM", "AVERAGE", "SQRT", "MAX", "MIN", "FIRST", "LAST"}) {
for (auto seq_num : {1, 3, 5}) {
std::vector<std::vector<uint64_t>> lod;
lod.resize(1);
generate_lod(seq_num, max_len, lod[0]);
int64_t n = int64_t(lod[0].back());
LOG(INFO) << "sequence_pool_grad parameter: "
<< ", n = " << n << ", c = " << c << ", h = " << h
<< ", w = " << w << ", seq_num = " << seq_num
<< ", pool_type = " << pool_type;
TestSequencePoolGrad(
DDim(std::vector<int64_t>({n, c, h, w})), lod, pool_type);
}
}
}
}
}
#endif
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(sequence_pool, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(sequence_pool_grad, kARM, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册