未验证 提交 b9273631 编写于 作者: M mapingshuo 提交者: GitHub

add mean op (#3130)

* add mean op, test=develop

* split forward and backward, test=develop
上级 054fd1af
......@@ -61,6 +61,7 @@ lite_option(LITE_WITH_ARM "Enable ARM in lite mode" OFF)
lite_option(LITE_WITH_NPU "Enable NPU in lite mode" OFF)
lite_option(LITE_WITH_XPU "Enable XPU in lite mode" OFF)
lite_option(LITE_WITH_BM "Enable BM in lite mode" OFF)
lite_option(LITE_WITH_TRAIN "Enable training operators and kernels in lite" OFF)
lite_option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON)
lite_option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF)
lite_option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF)
......
......@@ -122,6 +122,9 @@ if (LITE_WITH_ARM)
endif()
endif()
if (LITE_WITH_TRAIN)
add_definitions("-DLITE_WITH_TRAIN")
endif()
if (WITH_ARM_DOTPROD)
add_definitions("-DWITH_ARM_DOTPROD")
......
......@@ -333,16 +333,16 @@ lite::Tensor *Predictor::GetInputByName(const std::string &name) {
}
}
#ifdef LITE_WITH_TRAIN
void Predictor::FeedVars(const std::vector<framework::Tensor> &tensors) {
auto var = scope_->FindVar("feed");
auto &feed_list = *(var->GetMutable<std::vector<lite::Tensor>>());
feed_list.resize(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i)
feed_list[i].ShareDataWith(tensors[i]);
}
#endif
// #ifdef LITE_WITH_TRAIN
// void Predictor::FeedVars(const std::vector<framework::Tensor> &tensors) {
// auto var = scope_->FindVar("feed");
// auto &feed_list = *(var->GetMutable<std::vector<lite::Tensor>>());
// feed_list.resize(tensors.size());
// for (size_t i = 0; i < tensors.size(); ++i)
// feed_list[i].ShareDataWith(tensors[i]);
// }
// #endif
} // namespace lite
} // namespace paddle
......@@ -101,14 +101,14 @@ class LITE_API Predictor {
bool record_info = false);
void SaveOpKernelInfo(const std::string& model_dir);
#ifdef LITE_WITH_TRAIN
void Run(const std::vector<framework::Tensor>& tensors) {
FeedVars(tensors);
program_->Run();
}
void FeedVars(const std::vector<framework::Tensor>& tensors);
#endif
// #ifdef LITE_WITH_TRAIN
// void Run(const std::vector<framework::Tensor>& tensors) {
// FeedVars(tensors);
// program_->Run();
// }
// void FeedVars(const std::vector<framework::Tensor>& tensors);
// #endif
private:
Optimizer optimizer_;
......
......@@ -105,6 +105,11 @@ add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite
add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm)
# 4. training kernels
add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
if(LITE_WITH_TRAIN)
add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
endif()
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/mean_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void MeanCompute::Run() {
auto& param = this->Param<operators::MeanParam>();
const auto* input = param.X;
auto* output = param.Out;
auto x_dim = input->dims();
auto x_data = input->data<float>();
auto out_data = output->mutable_data<float>();
int x_size = x_dim.production();
float sum = 0;
for (int i = 0; i < x_size; i++) {
sum += x_data[i];
}
out_data[0] = sum / x_size;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
mean, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::MeanCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/operators/mean_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class MeanCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MeanParam;
void Run() override;
virtual ~MeanCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/mean_grad_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void MeanGradCompute::Run() {
auto& param = this->Param<operators::MeanGradParam>();
const auto* input = param.X;
const auto* out_grad = param.Out_grad;
auto* input_grad = param.X_grad;
auto out_grad_data = out_grad->data<float>();
auto input_data = input->data<float>();
auto input_grad_data = input_grad->mutable_data<float>();
int input_grad_size = input_grad->dims().production();
// TODO(mapingshuo): use parallel methods to accelerate this for loop
for (int i = 0; i < input_grad_size; i++) {
input_grad_data[i] = out_grad_data[0] / input_grad_size;
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mean_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::MeanGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/operators/mean_grad_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class MeanGradCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MeanGradParam;
void Run() override;
virtual ~MeanGradCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -18,7 +18,6 @@ add_operator(activation_ops basic SRCS activation_ops.cc DEPS ${op_DEPS})
add_operator(elementwise_ops basic SRCS elementwise_ops.cc DEPS ${op_DEPS})
add_operator(box_coder_op_lite basic SRCS box_coder_op.cc DEPS ${op_DEPS})
add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DEPS})
add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS})
add_operator(shuffle_channel_op basic SRCS shuffle_channel_op.cc DEPS ${op_DEPS})
......@@ -139,6 +138,12 @@ add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_o
add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS})
add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS})
# 4. training op
add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
if (LITE_WITH_TRAIN)
add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS})
endif()
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory
......
......@@ -78,45 +78,45 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
return true;
}
#ifdef LITE_WITH_TRAIN
bool ActivationGradOp::CheckShape() const {
CHECK_OR_FALSE(param_.X_grad);
CHECK_OR_FALSE(param_.Out_grad);
return true;
}
bool ActivationGradOp::InferShape() const {
param_.X_grad->Resize(param_.Out_grad->dims());
return true;
}
bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
param_.X_grad = GetMutableVar<Tensor>(scope, X_grad_name);
if (opdesc.HasInput("X")) {
auto X_name = opdesc.Input("X").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
} else {
param_.X = param_.X_grad;
}
if (opdesc.HasInput("Out")) {
auto Out_name = opdesc.Input("Out").front();
param_.Out = GetVar<lite::Tensor>(scope, Out_name);
} else {
param_.Out = param_.Out_grad;
}
return true;
}
#endif
// #ifdef LITE_WITH_TRAIN
// bool ActivationGradOp::CheckShape() const {
// CHECK_OR_FALSE(param_.X_grad);
// CHECK_OR_FALSE(param_.Out_grad);
// return true;
// }
// bool ActivationGradOp::InferShape() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// return true;
// }
// bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc,
// lite::Scope* scope) {
// auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
// auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
// param_.Out_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
// param_.X_grad = GetMutableVar<Tensor>(scope, X_grad_name);
// if (opdesc.HasInput("X")) {
// auto X_name = opdesc.Input("X").front();
// param_.X = GetVar<lite::Tensor>(scope, X_name);
// } else {
// param_.X = param_.X_grad;
// }
// if (opdesc.HasInput("Out")) {
// auto Out_name = opdesc.Input("Out").front();
// param_.Out = GetVar<lite::Tensor>(scope, Out_name);
// } else {
// param_.Out = param_.Out_grad;
// }
// return true;
// }
// #endif
} // namespace operators
} // namespace lite
......@@ -139,6 +139,6 @@ REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(gelu, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
#endif
// #ifdef LITE_WITH_TRAIN
// REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
// #endif
......@@ -38,25 +38,26 @@ class ActivationOp : public OpLite {
mutable operators::ActivationParam param_;
};
#ifdef LITE_WITH_TRAIN
class ActivationGradOp : public OpLite {
public:
explicit ActivationGradOp(const std::string& type) : OpLite(type) {}
// #ifdef LITE_WITH_TRAIN
// class ActivationGradOp : public OpLite {
// public:
// explicit ActivationGradOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
// bool CheckShape() const override;
bool InferShape() const override;
// bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_);
// }
std::string DebugString() const override { return "activation_grad_op"; }
// std::string DebugString() const override { return "activation_grad_op"; }
private:
mutable operators::ActivationGradParam param_;
};
#endif
// private:
// mutable operators::ActivationGradParam param_;
// };
// #endif
} // namespace operators
} // namespace lite
......
......@@ -96,39 +96,39 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
return true;
}
#ifdef LITE_WITH_TRAIN
bool ElementwiseGradExplicitOp::CheckShape() const {
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.X_grad);
CHECK_OR_FALSE(param_.Out_grad);
return true;
}
// #ifdef LITE_WITH_TRAIN
// bool ElementwiseGradExplicitOp::CheckShape() const {
// CHECK_OR_FALSE(param_.Y);
// CHECK_OR_FALSE(param_.X_grad);
// CHECK_OR_FALSE(param_.Out_grad);
// return true;
//}
bool ElementwiseGradExplicitOp::InferShape() const {
param_.X_grad->Resize(param_.Out_grad->dims());
if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
return true;
}
// bool ElementwiseGradExplicitOp::InferShape() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// return true;
// }
bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL);
auto Y_name = opdesc.Input("Y").front();
auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_grad = opdesc.Output(framework::GradVarName("X")).front();
// bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc,
// lite::Scope* scope) {
// CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL);
// auto Y_name = opdesc.Input("Y").front();
// auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
// auto X_grad = opdesc.Output(framework::GradVarName("X")).front();
if (opdesc.Output(framework::GradVarName("Y")).size() > 0) {
auto Y_grad = opdesc.Output(framework::GradVarName("Y")).front();
param_.Y_grad = GetMutableVar<Tensor>(scope, Y_grad);
}
param_.Y = GetVar<lite::Tensor>(scope, Y_name);
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_grad);
param_.axis = opdesc.GetAttr<int>("axis");
// if (opdesc.Output(framework::GradVarName("Y")).size() > 0) {
// auto Y_grad = opdesc.Output(framework::GradVarName("Y")).front();
// param_.Y_grad = GetMutableVar<Tensor>(scope, Y_grad);
// }
// param_.Y = GetVar<lite::Tensor>(scope, Y_name);
// param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
// param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_grad);
// param_.axis = opdesc.GetAttr<int>("axis");
return true;
}
#endif
// return true;
// }
// #endif
} // namespace operators
} // namespace lite
......@@ -141,7 +141,9 @@ REGISTER_LITE_OP(elementwise_mul, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_max, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_div, paddle::lite::operators::ElementwiseOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(elementwise_sub_grad,
paddle::lite::operators::ElementwiseGradExplicitOp);
#endif
// #ifdef LITE_WITH_TRAIN
// REGISTER_LITE_OP(elementwise_sub_grad,
// paddle::lite::operators::ElementwiseGradExplicitOp);
// REGISTER_LITE_OP(elementwise_add_grad,
// paddle::lite::operators::ElementwiseGradExplicitOp);
// #endif
......@@ -39,27 +39,29 @@ class ElementwiseOp : public OpLite {
mutable operators::ElementwiseParam param_;
};
#ifdef LITE_WITH_TRAIN
class ElementwiseGradExplicitOp : public OpLite {
public:
explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
// #ifdef LITE_WITH_TRAIN
// class ElementwiseGradExplicitOp : public OpLite {
// public:
// explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type)
// {}
bool CheckShape() const override;
// bool CheckShape() const override;
bool InferShape() const override;
// bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_);
// }
std::string DebugString() const override {
return "elementwise_grad_explicit_op";
}
// std::string DebugString() const override {
// return "elementwise_grad_explicit_op";
// }
private:
mutable operators::ElementwiseGradParam param_;
};
#endif
// private:
// mutable operators::ElementwiseGradParam param_;
// };
// #endif
} // namespace operators
} // namespace lite
......
......@@ -100,8 +100,8 @@ REGISTER_LITE_OP(fusion_elementwise_max_activation,
REGISTER_LITE_OP(fusion_elementwise_div_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(
fusion_elementwise_sub_activation_grad,
paddle::lite::operators::FusionElementwiseActivationGradExplicitOp);
#endif
// #ifdef LITE_WITH_TRAIN
// REGISTER_LITE_OP(
// fusion_elementwise_sub_activation_grad,
// paddle::lite::operators::FusionElementwiseActivationGradExplicitOp);
// #endif
......@@ -43,28 +43,29 @@ class FusionElementwiseActivationOp : public OpLite {
mutable operators::FusionElementwiseActivationParam param_;
};
#ifdef LITE_WITH_TRAIN
class FusionElementwiseActivationGradExplicitOp : public OpLite {
public:
explicit FusionElementwiseActivationGradExplicitOp(const std::string& type)
: OpLite(type) {}
// #ifdef LITE_WITH_TRAIN
// class FusionElementwiseActivationGradExplicitOp : public OpLite {
// public:
// explicit FusionElementwiseActivationGradExplicitOp(const std::string& type)
// : OpLite(type) {}
bool CheckShape() const override;
// bool CheckShape() const override;
bool InferShape() const override;
// bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_);
// }
std::string DebugString() const override {
return "fusion_elementwise_activation_grad_explicit_op";
}
// std::string DebugString() const override {
// return "fusion_elementwise_activation_grad_explicit_op";
// }
private:
mutable operators::FusionElementwiseActivationGradParam param_;
};
#endif
// private:
// mutable operators::FusionElementwiseActivationGradParam param_;
// };
// #endif
} // namespace operators
} // namespace lite
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/mean_grad_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool MeanGradOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out_grad);
CHECK_OR_FALSE(param_.X_grad);
return true;
}
bool MeanGradOp::InferShape() const {
param_.X_grad->Resize(param_.X->dims());
return true;
}
bool MeanGradOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL);
auto X_name = opdesc.Input("X").front();
auto Out_grad_name = opdesc.Input("Out@GRAD").front();
auto X_grad_name = opdesc.Output("X@GRAD").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
param_.X_grad = GetMutableVar<Tensor>(scope, X_grad_name);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class MeanGradOp : public OpLite {
public:
explicit MeanGradOp(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "mean_grad"; }
private:
mutable operators::MeanGradParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/mean_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
......@@ -19,82 +21,28 @@ namespace paddle {
namespace lite {
namespace operators {
class MeanOp : public OpLite {
public:
explicit MeanOp(const std::string& type) : OpLite(type) {}
bool MeanOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool MeanOp::InferShape() const {
param_.Out->Resize(std::vector<int64_t>{1});
return true;
}
bool InferShape() const override {
param_.Out->Resize(std::vector<int64_t>{1});
return true;
}
bool MeanOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Output("Out").front();
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name);
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "mean"; }
private:
mutable operators::MeanParam param_;
};
#ifdef LITE_WITH_TRAIN
class MeanGradOp : public OpLite {
public:
explicit MeanGradOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out_grad);
CHECK_OR_FALSE(param_.X_grad);
return true;
}
bool InferShape() const override {
param_.X_grad->Resize(param_.X->dims());
// param_.X_grad->set_lod(param_.X->lod());
return true;
}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL);
auto X_name = opdesc.Input("X").front();
auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front();
auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out_grad = GetVar<lite::Tensor>(scope, Out_grad_name);
param_.X_grad = GetMutableVar<Tensor>(scope, X_grad_name);
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "mean_grad"; }
private:
mutable operators::MeanGradParam param_;
};
#endif
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(mean, paddle::lite::operators::MeanOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp);
#endif
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class MeanOp : public OpLite {
public:
explicit MeanOp(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "mean"; }
private:
mutable operators::MeanParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -60,6 +60,9 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_dequant_compute SRCS lookup_table_dequant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
# for training kernel
lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class MeanComputeTester : public arena::TestCase {
protected:
DDim input_dims_{{2, 5}};
std::string input_ = "x";
std::string output_ = "out";
public:
MeanComputeTester(const Place& place,
const std::string& alias,
const DDim& input_dims)
: TestCase(place, alias), input_dims_(input_dims) {}
void RunBaseline(Scope* scope) override {
auto input = scope->FindTensor(input_);
auto output = scope->NewTensor(output_);
std::vector<int64_t> out_dims{1};
output->Resize(out_dims);
auto input_data = input->data<float>();
auto output_data = output->mutable_data<float>();
int x_size = input_dims_.production();
float sum = 0;
for (int i = 0; i < x_size; i++) {
sum += input_data[i];
}
output_data[0] = sum / x_size;
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("mean");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
}
void PrepareData() override {
std::vector<float> input(input_dims_.production());
fill_data_rand(input.data(), -1.f, 1.f, input_dims_.production());
SetCommonTensor(input_, input_dims_, input.data());
}
};
void TestNormalCase(Place place, float abs_error = 2e-5) {
LOG(INFO) << "Test Mean";
for (std::vector<int64_t> dims : std::vector<std::vector<int64_t>>{
{5}, {4, 5}, {3, 4, 5}, {2, 3, 4, 5}}) {
std::unique_ptr<arena::TestCase> tester(
new MeanComputeTester(place, "def", DDim(dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
#ifdef LITE_WITH_TRAIN
class MeanGradComputeTester : public arena::TestCase {
protected:
DDim input_dims_{{2, 5}};
DDim output_grad_dims_{{1}};
std::string input_ = "x";
std::string input_grad_ = "x_grad";
std::string output_grad_ = "out_grad";
public:
MeanGradComputeTester(const Place& place,
const std::string& alias,
const DDim& input_dims)
: TestCase(place, alias), input_dims_(input_dims) {}
void RunBaseline(Scope* scope) override {
auto input = scope->FindTensor(input_);
auto output_grad = scope->FindTensor(output_grad_);
auto input_grad = scope->NewTensor(input_grad_);
input_grad->Resize(input_dims_);
auto input_data = input->data<float>();
auto output_grad_data = output_grad->data<float>();
auto input_grad_data = input_grad->mutable_data<float>();
int x_size = input_dims_.production();
float d_x = output_grad_data[0] / x_size;
for (int i = 0; i < x_size; i++) {
input_grad_data[i] = d_x;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("mean_grad");
op_desc->SetInput("X", {input_});
op_desc->SetInput("Out@GRAD", {output_grad_});
op_desc->SetOutput("X@GRAD", {input_grad_});
}
void PrepareData() override {
std::vector<float> input(input_dims_.production());
fill_data_rand(input.data(), -1.f, 1.f, input_dims_.production());
SetCommonTensor(input_, input_dims_, input.data());
std::vector<float> output_grad(1);
fill_data_rand(output_grad.data(), -1.f, 1.f, 1);
SetCommonTensor(output_grad_, output_grad_dims_, output_grad.data());
}
};
void TestGradNormalCase(Place place, float abs_error = 2e-5) {
LOG(INFO) << "Test Mean Grad";
for (std::vector<int64_t> dims : std::vector<std::vector<int64_t>>{
{5}, {4, 5}, {3, 4, 5}, {2, 3, 4, 5}}) {
std::unique_ptr<arena::TestCase> tester(
new MeanGradComputeTester(place, "def", DDim(dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
#endif
TEST(Mean, precision) {
#ifdef LITE_WITH_ARM
float abs_error = 2e-5;
Place place(TARGET(kARM));
TestNormalCase(place, abs_error);
#ifdef LITE_WITH_TRAIN
TestGradNormalCase(place, abs_error);
#endif
#endif
}
} // namespace lite
} // namespace paddle
......@@ -594,6 +594,7 @@ function cmake_arm {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=ON \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_TRAIN=ON \
-DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3
}
......
......@@ -45,8 +45,6 @@ for path in paths:
op_parser = RegisterLiteOpParser(str_info)
ops = op_parser.parse()
for op in ops:
if "_grad" in op:
continue
if tailored == "ON":
if op not in minlines: continue
out = "USE_LITE_OP(%s);" % op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册