提交 92eeabeb 编写于 作者: J juncaipeng 提交者: Xiaoyang LI

add assign_value and hard_sigmoid, add fluid_type (#1983)

* add assign_value op, arm kernel and test, add fluid_type, test=develop

* add hard_sigmoid, test=develop
上级 6d1da405
......@@ -110,6 +110,8 @@ USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(box_clip, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(stack, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(assign_value, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(hard_sigmoid, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
......
......@@ -123,3 +123,5 @@ USE_LITE_OP(squeeze2) // for x2paddle
USE_LITE_OP(expand) // for x2paddle
USE_LITE_OP(roi_align)
USE_LITE_OP(box_clip)
USE_LITE_OP(assign_value)
USE_LITE_OP(hard_sigmoid)
......@@ -471,7 +471,7 @@ void act_prelu<float>(const float* din,
}
template <>
void act_sigmoid(const float* din, float* dout, int size, int threads) {
void act_sigmoid<float>(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
......@@ -595,7 +595,7 @@ void act_swish<float>(
}
template <>
void act_log(const float* din, float* dout, int size, int threads) {
void act_log<float>(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
......@@ -633,7 +633,7 @@ void act_log(const float* din, float* dout, int size, int threads) {
}
template <>
void act_exp(const float* din, float* dout, int size, int threads) {
void act_exp<float>(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
......@@ -677,6 +677,21 @@ void act_floor<float>(const float* din, float* dout, int size, int threads) {
}
}
// Float specialisation of the hard-sigmoid activation.
//
// For each of the `size` elements this writes
//   dout[i] = din[i] * slope + offset, limited to at most 1.0f and at least
//   0.0f.
// The comparisons are written so that a value that is not strictly below 1
// (which includes NaN) is replaced by 1.0f.
//
// `threads` is accepted for signature parity with the other activations but
// is not used here: the elements are processed in a single sequential loop.
template <>
void act_hard_sigmoid<float>(const float* din,
                             float* dout,
                             const int64_t size,
                             const float slope,
                             const float offset,
                             int threads) {
  for (int64_t idx = 0; idx < size; ++idx) {
    float val = din[idx] * slope + offset;
    // Upper bound first, then lower bound.
    if (!(val < 1.0f)) {
      val = 1.0f;
    }
    if (!(val > 0.0f)) {
      val = 0.0f;
    }
    dout[idx] = val;
  }
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -58,6 +58,13 @@ void act_exp(const T* din, T* dout, int size, int threads);
template <typename T>
void act_floor(const T* din, T* dout, int size, int threads);
// Hard-sigmoid activation over a buffer of `size` elements.
//
// Only the declaration lives here; behaviour is supplied by explicit
// specialisations. The float specialisation computes, per element,
// din[i] * slope + offset limited to the range [0, 1].
//
//   din     input buffer (read only)
//   dout    output buffer
//   size    number of elements to process
//   slope   multiplier applied to each input element
//   offset  constant added after the multiplication
//   threads thread-count hint (the float specialisation does not use it)
template <typename T>
void act_hard_sigmoid(const T* din,
T* dout,
const int64_t size,
const float slope,
const float offset,
int threads);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -45,6 +45,37 @@ enum class Type {
__num__,
};
// Numeric type codes as used by Fluid models. Enumerators are listed in
// ascending numeric order; the explicit values are what matters, since the
// codes are compared against integer attributes (e.g. an op's `dtype`).
// NOTE(review): presumably these mirror the variable-type codes of the Fluid
// framework's proto definition -- confirm before adding or changing a value.
enum class FluidType {
  // Plain-old-data element types.
  BOOL = 0,
  INT16 = 1,
  INT32 = 2,
  INT64 = 3,
  FP16 = 4,
  FP32 = 5,
  FP64 = 6,

  // Other types that may need additional descriptions.
  LOD_TENSOR = 7,
  SELECTED_ROWS = 8,
  FEED_MINIBATCH = 9,
  FETCH_LIST = 10,
  STEP_SCOPES = 11,
  LOD_RANK_TABLE = 12,
  LOD_TENSOR_ARRAY = 13,
  PLACE_LIST = 14,
  READER = 15,

  // A variable whose type is decided at runtime is "raw"; raw variables
  // should manage their own allocations (in operators like nccl_op).
  RAW = 17,
  TUPLE = 18,

  // More plain-old-data types, numbered after the ones above.
  // Tensor<size_t> is used in C++.
  SIZE_T = 19,
  UINT8 = 20,
  INT8 = 21,
};
template <typename T>
Type StdTypeToRepr() {
return Type::_unk;
......
......@@ -52,6 +52,7 @@ add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.
add_kernel(generate_proposals_compute_arm ARM basic SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(roi_align_compute_arm ARM basic SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(box_clip_compute_arm ARM basic SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -147,6 +147,18 @@ void FloorCompute::Run() {
x_data, output_data, x_dims.production(), ctx.threads());
}
void HardSigmoidCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
float slope = param.hard_sigmoid_slope;
float offset = param.hard_sigmoid_offset;
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_hard_sigmoid<float>(
x_data, output_data, x_dims.production(), slope, offset, ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -224,3 +236,12 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Registers HardSigmoidCompute as the kernel for op "hard_sigmoid" with
// target kARM, precision kFloat, layout kNCHW and alias "def". Input "X" and
// output "Out" are both bound to ARM tensors.
REGISTER_LITE_KERNEL(hard_sigmoid,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::HardSigmoidCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -121,6 +121,15 @@ class FloorCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~FloorCompute() = default;
};
// ARM float-precision kernel class for the hard_sigmoid activation.
class HardSigmoidCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
// Parameter struct this kernel reads (shared with the other activations).
using param_t = operators::ActivationParam;
// Runs the kernel; defined out of line.
void Run() override;
virtual ~HardSigmoidCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/assign_value_compute.h"
#include <algorithm>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <class T>
void TensorFromVector(const std::vector<T>& src, lite::Tensor* dst) {
auto* src_ptr = static_cast<const void*>(src.data());
auto* dst_ptr = static_cast<void*>(dst->mutable_data<T>());
auto size = src.size() * sizeof(T);
std::memcpy(dst_ptr, src_ptr, size);
}
// Fills param.Out with the constant values carried by the op attributes.
//
// `param.dtype` selects the source: FluidType::INT32 copies
// `param.int32_values`, FluidType::FP32 copies `param.fp32_values`. Any other
// dtype is reported with LOG(FATAL).
void AssignValueCompute::Run() {
  auto& param = Param<operators::AssignValueParam>();
  const int dtype = param.dtype;
  auto* out = param.Out;
  // The value vectors are passed straight from the parameter struct; taking
  // local copies of both on every run would be wasted work.
  if (dtype == static_cast<int>(lite::core::FluidType::INT32)) {
    TensorFromVector(param.int32_values, out);
  } else if (dtype == static_cast<int>(lite::core::FluidType::FP32)) {
    TensorFromVector(param.fp32_values, out);
  } else {
    LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype;
  }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers AssignValueCompute as the kernel for op "assign_value" with
// target kARM, precision kFloat, layout kNCHW and alias "def". The op has no
// tensor inputs; only output "Out" is bound (to an ARM tensor).
REGISTER_LITE_KERNEL(assign_value,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::AssignValueCompute,
def)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/assign_value_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// ARM kernel class for the assign_value op, registered under float precision.
class AssignValueCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
// Parameter struct this kernel reads.
using param_t = operators::AssignValueParam;
// Runs the kernel; defined out of line in the matching .cc file.
void Run() override;
virtual ~AssignValueCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -70,6 +70,7 @@ add_operator(roi_align_op basic SRCS roi_align_op.cc DEPS ${op_DEPS})
add_operator(box_clip_op basic SRCS box_clip_op.cc DEPS ${op_DEPS})
add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......@@ -51,6 +51,11 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
if (opdesc.Type() == "swish") {
param_.Swish_beta = opdesc.GetAttr<float>("beta");
}
if (opdesc.Type() == "hard_sigmoid") {
param_.hard_sigmoid_slope = opdesc.GetAttr<float>("slope");
param_.hard_sigmoid_offset = opdesc.GetAttr<float>("offset");
}
param_.Out = scope->FindVar(out_name)->GetMutable<lite::Tensor>();
return true;
}
......@@ -111,6 +116,7 @@ REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/assign_value_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached parameters.
//
// Returns false (through CHECK_OR_FALSE) when the output tensor is missing or
// when the product of the `shape` entries matches neither the number of int32
// values nor the number of fp32 values; returns true otherwise.
//
// NOTE(review): the size test accepts a match on either value vector without
// looking at `dtype`, so the vector the kernel actually copies is not
// necessarily the one that was validated -- consider checking per dtype.
bool AssignValueOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.Out);
  // Read the attributes in place; no copies of the vectors are needed.
  size_t shape_num = 1;
  for (const int dim : param_.shape) {
    shape_num *= dim;
  }
  CHECK_OR_FALSE(shape_num == param_.int32_values.size() ||
                 shape_num == param_.fp32_values.size());
  return true;
}
// Resizes the output tensor to the dimensions given by the `shape` attribute,
// widening each int entry to int64_t. Always returns true.
bool AssignValueOpLite::InferShape() const {
  const std::vector<int>& shape = param_.shape;
  // Range construction converts element-wise and allocates exactly once.
  std::vector<int64_t> out_shape(shape.begin(), shape.end());
  param_.Out->Resize(out_shape);
  return true;
}
// Populates param_ from the op description: reads the "shape", "dtype",
// "fp32_values" and "int32_values" attributes, then looks up the variable
// named by the first "Out" output in `scope` and stores its tensor.
// Always returns true.
bool AssignValueOpLite::AttachImpl(const cpp::OpDesc &op_desc,
                                   lite::Scope *scope) {
  // Attributes.
  param_.shape = op_desc.GetAttr<std::vector<int>>("shape");
  param_.dtype = op_desc.GetAttr<int>("dtype");
  param_.fp32_values = op_desc.GetAttr<std::vector<float>>("fp32_values");
  param_.int32_values = op_desc.GetAttr<std::vector<int>>("int32_values");

  // Output tensor.
  const auto out_name = op_desc.Output("Out").front();
  auto *out_var = scope->FindVar(out_name);
  param_.Out = out_var->GetMutable<lite::Tensor>();
  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Registers AssignValueOpLite as the implementation of op "assign_value".
REGISTER_LITE_OP(assign_value, paddle::lite::operators::AssignValueOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator class for assign_value: holds an AssignValueParam and hands it to
// the kernel that gets attached.
class AssignValueOpLite : public OpLite {
public:
AssignValueOpLite() {}
explicit AssignValueOpLite(const std::string &op_type) : OpLite(op_type) {}
// Validates the attached parameters; defined out of line.
bool CheckShape() const override;
// Sets the output tensor's shape; defined out of line.
bool InferShape() const override;
// Fills param_ from the op description and scope; defined out of line.
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
// Passes the parameter struct to the kernel.
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "assign value"; }
private:
// mutable so that the const CheckShape()/InferShape() can reach the
// non-const output tensor through it.
mutable AssignValueParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -214,6 +214,8 @@ struct ActivationParam {
"channel"}; // prelu param, can be "all", "channel" or "element"
lite::Tensor* Prelu_alpha{}; // prelu param
float Swish_beta; // swish param
float hard_sigmoid_slope{0.2};
float hard_sigmoid_offset{0.5};
lite::Tensor* Out{};
bool has_active{false};
lite_api::ActivationType active_type;
......@@ -791,6 +793,7 @@ struct AssignParam {
lite::Tensor* Out{};
};
/// ----------------------- roi_align operators -----------------------
struct RoiAlignParam {
lite::Tensor* X{};
lite::Tensor* ROIs{};
......@@ -801,12 +804,22 @@ struct RoiAlignParam {
int sampling_ratio{-1};
};
/// ----------------------- box_clip operators -----------------------
// Parameters of the box_clip operator: two read-only tensors and one
// writable result tensor. All pointers default to null.
struct BoxClipParam {
const lite::Tensor* Input{};
const lite::Tensor* ImInfo{};
lite::Tensor* Output{};
};
/// ----------------------- assign_value operators -----------------------
// Parameters of the assign_value operator, filled from the op attributes of
// the same names.
struct AssignValueParam {
// Dimensions the output tensor is resized to.
std::vector<int> shape{};
// Element type code; the ARM kernel compares it against FluidType::INT32 and
// FluidType::FP32.
int dtype{};
// Values copied to the output when dtype is FP32.
std::vector<float> fp32_values{};
// Values copied to the output when dtype is INT32.
std::vector<int> int32_values{};
// Output tensor.
lite::Tensor* Out{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -15,6 +15,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class AssignValueComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string out_ = "out";
int dtype_{};
std::vector<int> shape_{};
std::vector<int> int32_values_{};
std::vector<float> fp32_values_{};
size_t num_ = 1;
public:
AssignValueComputeTester(const Place& place,
const std::string& alias,
int dtype,
int n,
int c,
int h,
int w)
: TestCase(place, alias) {
dtype_ = dtype;
shape_.push_back(n);
shape_.push_back(c);
shape_.push_back(h);
shape_.push_back(w);
num_ = n * c * h * w;
}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(out_);
CHECK(out);
std::vector<int64_t> out_shape(shape_.begin(), shape_.end());
out->Resize(out_shape);
if (dtype_ == 2) {
auto* out_data = out->mutable_data<int>();
for (int i = 0; i < out->numel(); i++) {
out_data[i] = int32_values_[i];
}
} else if (dtype_ == 5) {
auto* out_data = out->mutable_data<float>();
for (int i = 0; i < out->numel(); i++) {
out_data[i] = fp32_values_[i];
}
} else {
LOG(FATAL) << "unsuport dtype_:" << dtype_;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("assign_value");
op_desc->SetAttr("shape", shape_);
op_desc->SetAttr("dtype", dtype_);
op_desc->SetAttr("fp32_values", fp32_values_);
op_desc->SetAttr("int32_values", int32_values_);
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
// int32
if (dtype_ == 2) {
int32_values_.resize(num_);
for (int i = 0; i < num_; i++) {
int32_values_[i] = i;
}
} else if (dtype_ == 5) {
fp32_values_.resize(num_);
for (int i = 0; i < num_; i++) {
fp32_values_[i] = i / 1.23f;
}
} else {
LOG(FATAL) << "unsupport dtype_:" << dtype_;
}
}
};
// Precision test for assign_value on ARM: for each supported dtype code
// (2 and 5) builds a tester with shape 1x2x1x2 and runs the arena precision
// check. Compiled to a log-only body when LITE_WITH_ARM is not defined.
TEST(AssignValue, precision) {
  // The log lines name the op actually under test (they previously said
  // "argmax", left over from another test).
  LOG(INFO) << "test assign_value op";
#ifdef LITE_WITH_ARM
  LOG(INFO) << "test assign_value arm";
  Place place(TARGET(kARM));

  for (int dtype : {2, 5}) {
    for (int n : {1}) {
      for (int c : {2}) {
        for (int h : {1}) {
          for (int w : {2}) {
            std::unique_ptr<arena::TestCase> tester(
                new AssignValueComputeTester(place, "def", dtype, n, c, h, w));
            arena::Arena arena(std::move(tester), place, 2e-5);
            arena.TestPrecision();
          }
        }
      }
    }
  }
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册