Commit 1391e255 authored by chenjiaoAngel

fix build error

Parent 903bbc02

This commit renames the sequence-pool backward math helpers, the ARM kernel, and the operator to their _grad counterparts, adds the missing seq_pool_average_grad declaration, rebinds the grad kernel's slots (Out@GRAD as input, X@GRAD as output), and fixes the op_desc/opdesc typos in SequencePoolGradOp::AttachImpl.
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "lite/backends/arm/math/sequence_pool.h"
 #include <algorithm>
 #include <cmath>
 #include <limits>
 #include <vector>
 #include "lite/backends/arm/math/funcs.h"
+#include "lite/backends/arm/math/sequence_pool_grad.h"
 #include "lite/core/op_registry.h"
 #include "lite/core/tensor.h"
 #include "lite/core/type_system.h"
@@ -28,7 +28,7 @@ namespace arm {
 namespace math {

 template <>
-void seq_pool_sum<float>(const float* din,
+void seq_pool_sum_grad<float>(const float* din,
                          const float* din_grad,
                          float* dout,
                          const std::vector<uint64_t> lod,
@@ -55,7 +55,7 @@ void seq_pool_sum<float>(const float* din,
 }

 template <>
-void seq_pool_average<float>(const float* din,
+void seq_pool_average_grad<float>(const float* din,
                              const float* din_grad,
                              float* dout,
                              const std::vector<uint64_t> lod,
@@ -85,7 +85,7 @@ void seq_pool_average<float>(const float* din,
 }

 template <>
-void seq_pool_sqrt<float>(const float* din,
+void seq_pool_sqrt_grad<float>(const float* din,
                           const float* din_grad,
                           float* dout,
                           const std::vector<uint64_t> lod,
@@ -115,7 +115,7 @@ void seq_pool_sqrt<float>(const float* din,
 }

 template <>
-void seq_pool_first<float>(const float* din,
+void seq_pool_first_grad<float>(const float* din,
                            const float* din_grad,
                            float* dout,
                            const std::vector<uint64_t> lod,
...
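The bodies of the renamed seq_pool_*_grad routines are collapsed in the diff above. For orientation only, here is a minimal standalone sketch of what a sum-pool backward pass computes; the function name and the row-major layout assumption are illustrative and this is not the Paddle-Lite implementation. Since d(sum)/d(x_j) = 1, every row of a sequence simply receives that sequence's output gradient.

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch of sum-pool backward: for each sequence
// [lod[i], lod[i+1]), copy that sequence's output gradient into every
// input row, because each input contributes to the sum with weight 1.
void seq_pool_sum_grad_sketch(const float* out_grad,  // [num_seqs, width]
                              float* in_grad,         // [lod.back(), width]
                              const std::vector<uint64_t>& lod,
                              int64_t width) {
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    const float* og = out_grad + i * width;
    for (uint64_t row = lod[i]; row < lod[i + 1]; ++row) {
      float* ig = in_grad + row * width;
      for (int64_t w = 0; w < width; ++w) {
        ig[w] = og[w];
      }
    }
  }
}
```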
@@ -28,6 +28,13 @@ void seq_pool_sum_grad(const T* din,
                        const std::vector<uint64_t> lod,
                        int64_t width);

+template <typename T>
+void seq_pool_average_grad(const T* din,
+                           const T* din_grad,
+                           T* dout,
+                           const std::vector<uint64_t> lod,
+                           int64_t width);
+
 template <typename T>
 void seq_pool_sqrt_grad(const T* din,
                         const T* din_grad,
...
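The newly declared seq_pool_average_grad, like the existing seq_pool_sqrt_grad, differs from the sum case only by a per-sequence scale: average pooling divides by the sequence length and sqrt pooling by its square root, so the backward pass scales the output gradient accordingly. A hedged standalone sketch of that pattern (illustrative names, not the library code):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative sketch of scaled sequence-pool backward. With use_sqrt == false
// it mirrors average pooling (scale = 1/len); with use_sqrt == true it mirrors
// sqrt pooling (scale = 1/sqrt(len)).
void seq_pool_scaled_grad_sketch(const float* out_grad,
                                 float* in_grad,
                                 const std::vector<uint64_t>& lod,
                                 int64_t width,
                                 bool use_sqrt) {
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    const uint64_t len = lod[i + 1] - lod[i];
    if (len == 0) continue;  // empty sequence: nothing to propagate
    const float scale = use_sqrt
                            ? 1.0f / std::sqrt(static_cast<float>(len))
                            : 1.0f / static_cast<float>(len);
    const float* og = out_grad + i * width;
    for (uint64_t row = lod[i]; row < lod[i + 1]; ++row) {
      for (int64_t w = 0; w < width; ++w) {
        in_grad[row * width + w] = og[w] * scale;
      }
    }
  }
}
```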
@@ -57,13 +57,14 @@ void SequencePoolGradCompute::Run() {
 } // namespace lite
 } // namespace paddle

-REGISTER_LITE_KERNEL(sequence_pool,
+REGISTER_LITE_KERNEL(sequence_pool_grad,
                      kARM,
                      kFloat,
                      kNCHW,
-                     paddle::lite::kernels::arm::SequencePoolCompute,
+                     paddle::lite::kernels::arm::SequencePoolGradCompute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
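The hunk context shows that SequencePoolGradCompute::Run() is collapsed above. A grad kernel of this kind typically selects the backward routine from the param's pool_type string (the default is "AVERAGE", per the param change below). The helper here is purely illustrative, covers only the pool types that appear in this diff, and is not the kernel's actual dispatch code.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical dispatch helper: map the sequence_pool pool_type attribute to an
// enum so a grad kernel can branch to the matching backward routine.
enum class SeqPoolKind { kSum, kAverage, kSqrt, kFirst };

SeqPoolKind ParseSeqPoolType(const std::string& pool_type) {
  if (pool_type == "SUM") return SeqPoolKind::kSum;
  if (pool_type == "AVERAGE") return SeqPoolKind::kAverage;
  if (pool_type == "SQRT") return SeqPoolKind::kSqrt;
  if (pool_type == "FIRST") return SeqPoolKind::kFirst;
  throw std::invalid_argument("unsupported pool_type: " + pool_type);
}
```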
@@ -1012,13 +1012,13 @@ struct SequencePoolConcatParam : ParamBase {

 struct SequencePoolGradParam : ParamBase {
   const lite::Tensor* X{};
-  lite::Tensor* Out_Grad{};
   std::string pool_type{"AVERAGE"};
 #ifdef LITE_WITH_X86
   float pad_value{0.0};
   lite::Tensor* MaxIndex{};
 #endif
   // for backward
+  const lite::Tensor* Out_Grad{};
   lite::Tensor* X_Grad{};
 };
...
@@ -19,7 +19,7 @@ namespace paddle {
 namespace lite {
 namespace operators {

-bool SequencePoolOp::CheckShape() const {
+bool SequencePoolGradOp::CheckShape() const {
   CHECK_OR_FALSE(param_.X);
   CHECK_OR_FALSE(param_.X_Grad);
   CHECK_OR_FALSE(param_.Out_Grad);
@@ -30,7 +30,7 @@ bool SequencePoolOp::CheckShape() const {
   return true;
 }

-bool SequencePoolOp::InferShapeImpl() const {
+bool SequencePoolGradOp::InferShapeImpl() const {
   const auto *input = param_.X;
   auto x_dims = input->dims();
   if (param_.X_Grad) {
@@ -40,18 +40,18 @@ bool SequencePoolOp::InferShapeImpl() const {
   return true;
 }

-bool SequencePoolOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+bool SequencePoolGradOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.X = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
   CHECK(param_.X);
-  if (!op_desc.Input("Out@GRAD").empty()) {
-    auto *out_grad_var = scope->FindVar(op_desc.Input("Out@GRAD").front());
+  if (!opdesc.Input("Out@GRAD").empty()) {
+    auto *out_grad_var = scope->FindVar(opdesc.Input("Out@GRAD").front());
     CHECK(out_grad_var);
     param_.Out_Grad = &out_grad_var->Get<Tensor>();
   }
-  if (!op_desc.Output("X@GRAD").empty()) {
-    auto *x_grad_var = scope->FindVar(op_desc.Output("X@GRAD").front());
+  if (!opdesc.Output("X@GRAD").empty()) {
+    auto *x_grad_var = scope->FindVar(opdesc.Output("X@GRAD").front());
     CHECK(x_grad_var);
     param_.X_Grad = x_grad_var->GetMutable<Tensor>();
   }
@@ -63,4 +63,4 @@ bool SequencePoolOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
 } // namespace lite
 } // namespace paddle

-REGISTER_LITE_OP(sequence_pool, paddle::lite::operators::SequencePoolOp);
+REGISTER_LITE_OP(sequence_pool_grad, paddle::lite::operators::SequencePoolGradOp);
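The body of SequencePoolGradOp::InferShapeImpl is collapsed in the diff above; only the reads of param_.X and the param_.X_Grad check are visible. As a hypothetical illustration of what a grad op usually does there, assuming lite::Tensor exposes the familiar dims/Resize/lod/set_lod interface, X@GRAD is given the same dims and LoD as X:

```cpp
#include "lite/core/tensor.h"

// Hypothetical sketch only; the real InferShapeImpl body is not shown in the
// diff. The input gradient is shaped element-for-element like the forward input.
void InferXGradShapeSketch(const paddle::lite::Tensor& x,
                           paddle::lite::Tensor* x_grad) {
  if (x_grad == nullptr) return;
  x_grad->Resize(x.dims());   // X@GRAD has the same dims as X
  x_grad->set_lod(x.lod());   // and shares X's LoD
}
```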