Commit 52a72e72 authored by Y yiicy

add compute api, test=develop

Parent 30616633
......@@ -106,6 +106,7 @@ lite_option(LITE_BUILD_EXTRA "Enable extra algorithm support in Lite, both kerne
lite_option(LITE_BUILD_TAILOR "Enable tailoring library according to model" OFF)
# cv build options
lite_option(LITE_WITH_CV "Enable build cv image in lite" OFF)
lite_option(LITE_WITH_COMPUTE_API "Enable build compute api in lite" OFF)
lite_option(LITE_WITH_STATIC_CUDA "Statically link cuda libraries." ON)
lite_option(LITE_WITH_ARM_CLANG "when arm lang is clang, its ON." OFF)
......
......@@ -129,6 +129,9 @@ if (LITE_WITH_ARM)
if (LITE_WITH_CV)
add_definitions("-DLITE_WITH_CV")
endif()
if (LITE_WITH_COMPUTE_API)
add_definitions("-DLITE_WITH_COMPUTE_API")
endif()
endif()
if (LITE_WITH_TRAIN)
......
......@@ -280,6 +280,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_BINARY_DIR}/lite/gen_code/paddle_code_generator" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/test_model_bin" "${INFER_LITE_PUBLISH_ROOT}/bin"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/compute_api.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/compute_param.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
)
if(NOT IOS)
add_dependencies(publish_inference_cxx_lib paddle_code_generator)
......@@ -323,6 +325,9 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_BINARY_DIR}/libpaddle_api_light_bundled.a" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/libpaddle_light_api_shared.so" "${INFER_LITE_PUBLISH_ROOT}/cxx/lib"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/utils/cv/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/compute_api.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/compute_param.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/compute_utils.h" "${INFER_LITE_PUBLISH_ROOT}/cxx/include"
)
add_dependencies(tiny_publish_cxx_lib paddle_light_api_shared)
add_dependencies(tiny_publish_cxx_lib bundle_light_api)
......@@ -380,6 +385,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mask_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mask_detection/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_libs" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_libs/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_libs/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_compute_api" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_compute_api/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_compute_api/Makefile"
)
add_dependencies(publish_inference_android_cxx_demos logging gflags)
add_dependencies(publish_inference_cxx_lib publish_inference_android_cxx_demos)
......@@ -401,6 +408,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_cv/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_cv/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mask_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mask_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mask_detection/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/test_compute_api" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/test_compute_api/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/test_compute_api/Makefile"
)
add_dependencies(tiny_publish_cxx_lib publish_inference_android_cxx_demos)
endif()
......
......@@ -64,7 +64,11 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH
else()
if ((ARM_TARGET_OS STREQUAL "android") OR (ARM_TARGET_OS STREQUAL "armlinux"))
add_library(paddle_light_api_shared SHARED "")
target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc)
if (LITE_WITH_COMPUTE_API)
target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc compute_param.cc compute_api.cc compute_utils.cc)
else()
target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc)
endif()
set(TARGET_COMIPILE_FLAGS "-fdata-sections")
if (NOT (ARM_TARGET_LANG STREQUAL "clang")) #gcc
set(TARGET_COMIPILE_FLAGS "${TARGET_COMIPILE_FLAGS} -flto")
......@@ -308,7 +312,11 @@ lite_cc_library(paddle_api SRCS paddle_api.cc DEPS op_params tensor device_info)
#-----------------------------------------------------------------------------------------------------
# The final inference library for both CxxConfig and MobileConfig.
if (LITE_ON_TINY_PUBLISH)
lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api stream)
if (LITE_WITH_COMPUTE_API)
lite_cc_library(paddle_api_light SRCS light_api_impl.cc compute_param.cc compute_api.cc compute_utils.cc DEPS light_api paddle_api stream)
else()
lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api stream)
endif()
else()
lite_cc_library(paddle_api_light SRCS light_api_impl.cc DEPS light_api paddle_api)
endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compute_api.h" // NOLINT
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
#include "log_lite.h" // NOLINT
namespace paddle {
namespace lite_api {
class InstructionWrapper {
public:
InstructionWrapper(
std::shared_ptr<lite::OpLite>& op, // NOLINT
std::vector<std::unique_ptr<lite::KernelBase>>& kernels) { // NOLINT
op_ = op;
for (auto& kernel : kernels) {
kernels_.emplace_back(std::move(kernel));
}
}
lite::OpLite* get_op() { return op_.get(); }
lite::KernelBase* get_kernel() {
if (kernel_idx < 0 || static_cast<size_t>(kernel_idx) >= kernels_.size()) {
LOGF("Error! kernel index out of range\n");
}
return kernels_[kernel_idx].get();
}
void set_kernel_idx(int idx) { kernel_idx = idx; }
~InstructionWrapper() = default;
private:
std::shared_ptr<lite::OpLite> op_;
std::vector<std::unique_ptr<lite::KernelBase>> kernels_;
int kernel_idx{0};
};
void ComputeEngine<TARGET(kARM)>::env_init(PowerMode power_mode, int threads) {
lite::DeviceInfo::Init();
lite::DeviceInfo::Global().SetRunMode(power_mode, threads);
}
bool ComputeEngine<TARGET(kARM)>::CreateOperator(const char* op_type,
PrecisionType precision,
DataLayoutType layout) {
auto op = lite::LiteOpRegistry::Global().Create(op_type);
LCHECK(op, "no Op found for %s\n", op_type);
LOGI("Create %s Operator Success\n", op_type);
lite_api::Place place(TARGET(kARM), precision, layout);
auto kernels = op->CreateKernels({place});
LCHECK_GT(kernels.size(), 0, "no kernel found for: %s\n", op_type);
LOGI("Create %s kernel Success\n", op_type);
instruction_ = new InstructionWrapper(op, kernels);
return true;
}
// param must set input and output
void ComputeEngine<TARGET(kARM)>::SetParam(ParamBase* param) {
delete static_cast<lite::operators::ParamBase*>(param_);
// generate raw param
param_ = param->AttachRawParam();
auto* ins = static_cast<InstructionWrapper*>(instruction_);
// pick kernel
ins->set_kernel_idx(param->GetKernelIndex());
// get raw kernel and op
auto* kernel = ins->get_kernel();
LCHECK(kernel, "SetParam, pick kernel error\n");
auto* op = ins->get_op();
// set context
std::unique_ptr<lite::KernelContext> ctx(new lite::KernelContext);
kernel->SetContext(std::move(ctx));
op->SetParam(static_cast<lite::operators::ParamBase*>(param_));
op->CheckShape();
op->AttachKernel(kernel);
LOGI("SetParam Success\n");
}
void ComputeEngine<TARGET(kARM)>::Launch() {
auto* ins = static_cast<InstructionWrapper*>(instruction_);
auto* kernel = ins->get_kernel();
LCHECK(kernel, "Launch, pick kernel error\n");
auto* op = ins->get_op();
op->InferShapeImpl();
kernel->Launch();
LOGI("Run Success\n");
}
ComputeEngine<TARGET(kARM)>::~ComputeEngine() {
delete static_cast<InstructionWrapper*>(instruction_);
delete static_cast<lite::operators::ParamBase*>(param_);
instruction_ = nullptr;
param_ = nullptr;
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "compute_param.h" // NOLINT
#include "paddle_place.h" // NOLINT
namespace paddle {
namespace lite_api {
// currently ComputeEngine only supports TARGET(kARM)
template <TargetType Type>
class LITE_API ComputeEngine {
public:
ComputeEngine() = default;
bool CreateOperator(const char* op_type,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW)) {
return false;
}
void SetParam(ParamBase* param) {}
void Launch() {}
~ComputeEngine() = default;
private:
void* instruction_{nullptr};
void* param_{nullptr};
};
template <>
class LITE_API ComputeEngine<TARGET(kARM)> {
public:
ComputeEngine() = default;
static void env_init(PowerMode power_mode, int threads);
bool CreateOperator(const char* op_type,
PrecisionType precision = PRECISION(kFloat),
DataLayoutType layout = DATALAYOUT(kNCHW));
void SetParam(ParamBase* param);
void Launch();
~ComputeEngine();
private:
void* instruction_{nullptr};
void* param_{nullptr};
};
} // namespace lite_api
} // namespace paddle
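For reference, a minimal usage sketch of the ComputeEngine API declared above, mirroring the activation demo added later in this change (built with -DLITE_WITH_COMPUTE_API); the tensor shape, power mode and input values are illustrative:

#include "compute_api.h"    // NOLINT
#include "compute_param.h"  // NOLINT
#include "paddle_api.h"     // NOLINT

using namespace paddle::lite_api;  // NOLINT

int main() {
  // power mode 3 (no core binding, as in the demos), single thread
  ComputeEngine<TARGET(kARM)>::env_init(static_cast<PowerMode>(3), 1);
  Tensor x, out;
  x.Resize({1, 3, 224, 224});
  x.set_precision(PRECISION(kFloat));
  float* xdata = x.mutable_data<float>();
  for (int i = 0; i < 3 * 224 * 224; ++i) {
    xdata[i] = 1.f;  // illustrative input data
  }
  ActivationParam param;
  param.X = &x;
  param.Out = &out;
  param.active_type = ActivationType::kRelu;
  ComputeEngine<TARGET(kARM)> relu;
  relu.CreateOperator("relu");  // picks an ARM float/NCHW kernel
  relu.SetParam(&param);        // attaches the raw param and kernel context
  relu.Launch();                // infers the output shape and runs the kernel
  x.ReleaseRawTensor();         // free the raw tensors owned by the wrappers
  out.ReleaseRawTensor();
  return 0;
}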
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compute_param.h" // NOLINT
#include "lite/operators/op_params.h"
#include "log_lite.h" // NOLINT
namespace paddle {
namespace lite_api {
void *ActivationParam::AttachRawParam() {
//! necessary check
LCHECK(X, "ActivationParam must set input tensor: X\n");
LCHECK(Out, "ActivationParam must set output tensor: Out\n");
auto *raw_act_param = new lite::operators::ActivationParam();
// Tensor
raw_act_param->X = static_cast<const lite::Tensor *>(X->GetRawTensor());
raw_act_param->Out = static_cast<lite::Tensor *>(Out->GetRawTensor());
raw_act_param->Prelu_alpha =
Prelu_alpha ? static_cast<lite::Tensor *>(Prelu_alpha->GetRawTensor())
: nullptr;
raw_act_param->active_type = active_type;
raw_act_param->has_active = has_active;
raw_act_param->Leaky_relu_alpha = Leaky_relu_alpha;
raw_act_param->Relu_clipped_coef = Relu_clipped_coef;
raw_act_param->Prelu_mode = Prelu_mode;
raw_act_param->Swish_beta = Swish_beta;
raw_act_param->hard_sigmoid_slope = hard_sigmoid_slope;
raw_act_param->hard_sigmoid_offset = hard_sigmoid_offset;
raw_act_param->hard_swish_scale = hard_swish_scale;
raw_act_param->hard_swish_offset = hard_swish_offset;
raw_act_param->hard_swish_threshold = hard_swish_threshold;
return raw_act_param;
}
void *ConvParam::AttachRawParam() {
//! necessary check
LCHECK(x, "ConvParam must set input tensor: x\n");
LCHECK(filter, "ConvParam must set filter tensor: filter\n");
LCHECK(output, "ConvParam must set output tensor: output\n");
if (enable_int8 && out_ptype == PRECISION(kFloat)) {
LCHECK_NE(input_scale, 0.f, "int8 conv out float, must has input scale\n");
LCHECK(!weight_scale.empty(),
"int8 conv out float, must has weights scale\n");
} else if (enable_int8 && out_ptype == PRECISION(kInt8)) {
LCHECK_NE(input_scale, 0.f, "int8 conv out int8, must has input scale\n");
LCHECK_NE(output_scale, 0.f, "int8 conv out int8, must has output scale\n");
LCHECK(!weight_scale.empty(),
"int8 conv out int8, must has weights scale\n");
}
auto *raw_conv_param = new lite::operators::ConvParam();
// Tensor
raw_conv_param->x = static_cast<lite::Tensor *>(x->GetRawTensor());
raw_conv_param->filter = static_cast<lite::Tensor *>(filter->GetRawTensor());
raw_conv_param->output = static_cast<lite::Tensor *>(output->GetRawTensor());
raw_conv_param->bias =
bias ? static_cast<lite::Tensor *>(bias->GetRawTensor()) : nullptr;
raw_conv_param->residualData =
residualData ? static_cast<lite::Tensor *>(residualData->GetRawTensor())
: nullptr;
// activation param
raw_conv_param->activation_param.active_type = activation_param.active_type;
raw_conv_param->activation_param.has_active = activation_param.has_active;
raw_conv_param->activation_param.Relu_clipped_coef =
activation_param.Relu_clipped_coef;
raw_conv_param->activation_param.Leaky_relu_alpha =
activation_param.Leaky_relu_alpha;
raw_conv_param->activation_param.Swish_beta = activation_param.Swish_beta;
raw_conv_param->activation_param.hard_sigmoid_slope =
activation_param.hard_sigmoid_slope;
raw_conv_param->activation_param.hard_sigmoid_offset =
activation_param.hard_sigmoid_offset;
raw_conv_param->activation_param.hard_swish_scale =
activation_param.hard_swish_scale;
raw_conv_param->activation_param.hard_swish_offset =
activation_param.hard_swish_offset;
raw_conv_param->activation_param.hard_swish_threshold =
activation_param.hard_swish_threshold;
// for int8
raw_conv_param->enable_int8 = enable_int8;
raw_conv_param->input_scale = input_scale;
raw_conv_param->weight_scale = weight_scale;
raw_conv_param->output_scale = output_scale;
raw_conv_param->bit_length = bit_length;
raw_conv_param->strides = strides;
raw_conv_param->paddings = paddings;
raw_conv_param->groups = groups;
raw_conv_param->dilations = dilations;
raw_conv_param->fuse_residual_connection = fuse_residual_connection;
raw_conv_param->data_format = data_format;
raw_conv_param->output_size = output_size;
return raw_conv_param;
}
int ConvParam::GetKernelIndex() {
if (enable_int8) {
if (out_ptype == PRECISION(kFloat)) {
return 1;
} else if (out_ptype == PRECISION(kInt8)) {
return 0;
} else {
LOGF("conv only supports float and int8 precision\n");
}
}
return 0;
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle_api.h" // NOLINT
#include "paddle_place.h" // NOLINT
namespace paddle {
namespace lite_api {
class LITE_API ParamBase {
public:
PrecisionType out_ptype{PRECISION(kFloat)};
virtual int GetKernelIndex() { return 0; }
virtual void* AttachRawParam() { return nullptr; }
virtual ~ParamBase() = default;
};
class LITE_API ActivationParam : public ParamBase {
public:
Tensor* X{};
Tensor* Out{};
ActivationType active_type{ActivationType::kIndentity};
bool has_active{false};
float Leaky_relu_alpha{0}; // leaky_relu param
float Relu_clipped_coef{6}; // relu_clipped param
const char* Prelu_mode{
"channel"}; // prelu param, can be "all", "channel" or "element"
Tensor* Prelu_alpha{}; // prelu param
float Swish_beta{0.f}; // swish param
// hard_sigmoid param
float hard_sigmoid_slope{0.2f};
float hard_sigmoid_offset{0.5f};
// hard_swish param
float hard_swish_threshold{6.0};
float hard_swish_scale{6.0};
float hard_swish_offset{3.0};
ActivationParam() = default;
virtual ~ActivationParam() = default;
void* AttachRawParam() override;
};
class LITE_API ConvParam : public ParamBase {
public:
Tensor* x{};
Tensor* filter{};
Tensor* bias{nullptr};
Tensor* residualData{nullptr};
Tensor* output{};
std::vector<int> strides{1, 1};
std::shared_ptr<std::vector<int>> paddings;
int groups{1};
std::shared_ptr<std::vector<int>> dilations;
bool fuse_residual_connection{false};
const char* data_format{"Anylayout"};
// for activation
ActivationParam activation_param;
// only used in conv_transpose.
std::vector<int> output_size;
// for int8
bool enable_int8{false};
float input_scale{1.0f};
std::vector<float> weight_scale{};
float output_scale{1.0f};
int bit_length{8};
ConvParam() = default;
virtual ~ConvParam() = default;
void* AttachRawParam() override;
int GetKernelIndex() override;
};
} // namespace lite_api
} // namespace paddle
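The int8-related fields above interact with the checks in ConvParam::AttachRawParam and with GetKernelIndex. A hedged sketch of configuring an int8 conv2d that produces float output; all names, shapes and scale values here are illustrative, not part of this change:

#include <memory>
#include <vector>
#include "compute_param.h"  // NOLINT
#include "paddle_api.h"     // NOLINT

using namespace paddle::lite_api;  // NOLINT

// Fill a ConvParam for an int8 conv with float output. Per AttachRawParam,
// input_scale and weight_scale must be set; output_scale is additionally
// required only when out_ptype is PRECISION(kInt8).
void fill_int8_conv_param(ConvParam* param,
                          Tensor* x,
                          Tensor* filter,
                          Tensor* output,
                          int chout) {
  param->x = x;            // int8 input tensor
  param->filter = filter;  // int8 weights, e.g. quantized via ComputeUtils
  param->output = output;  // float output tensor
  param->strides = {1, 1};
  std::vector<int> pads = {1, 1, 1, 1};
  std::vector<int> dilas = {1, 1};
  param->paddings = std::make_shared<std::vector<int>>(pads);
  param->dilations = std::make_shared<std::vector<int>>(dilas);
  param->enable_int8 = true;
  param->out_ptype = PRECISION(kFloat);  // GetKernelIndex() -> 1 (float out)
  param->input_scale = 0.05f;            // illustrative activation scale
  param->weight_scale = std::vector<float>(chout, 0.01f);  // one per chout
  // param->output_scale is only needed when out_ptype == PRECISION(kInt8),
  // in which case GetKernelIndex() returns 0 instead.
}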
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compute_utils.h" // NOLINT
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/tensor.h"
#include "log_lite.h" // NOLINT
#include "paddle_place.h" // NOLINT
namespace paddle {
namespace lite_api {
// clang-format off
void ComputeUtils::TensorFloatToInt8(Tensor& tin, Tensor& tout, float scale) {
lite::Tensor* raw_tin = static_cast<lite::Tensor*>(tin.GetRawTensor());
lite::Tensor* raw_tout = static_cast<lite::Tensor*>(tout.GetRawTensor());
LCHECK(raw_tin, "tensor in must have raw tensor\n");
tout.Resize(tin.shape());
int outer_size = 1;
int axis_size = 1;
int inner_size = raw_tin->numel();
const float* din = raw_tin->data<float>();
int8_t* dout = raw_tout->mutable_data<int8_t>();
paddle::lite::arm::math::fp32_to_int8(
din, dout, &scale, axis_size, outer_size, inner_size);
}
void ComputeUtils::TensorFloatToInt8Inplace(Tensor& tin, float scale) {
lite::Tensor* raw_tin = static_cast<lite::Tensor*>(tin.GetRawTensor());
LCHECK(raw_tin, "tensor in must have raw tensor\n");
LCHECK_GT(raw_tin->numel(), 0, "tensor in shape must greater than zero\n");
LCHECK_EQ(raw_tin->precision(),
PRECISION(kFloat),
"tensor in precision must be float\n");
int outer_size = 1;
int axis_size = 1;
int inner_size = raw_tin->numel();
const float* din = raw_tin->data<float>();
int8_t* dout = raw_tin->mutable_data<int8_t>();
paddle::lite::arm::math::fp32_to_int8(
din, dout, &scale, axis_size, outer_size, inner_size);
}
void ComputeUtils::TensorInt8ToFloat(Tensor& tin, Tensor& tout, float scale) {
lite::Tensor* raw_tin = static_cast<lite::Tensor*>(tin.GetRawTensor());
lite::Tensor* raw_tout = static_cast<lite::Tensor*>(tout.GetRawTensor());
LCHECK(raw_tin, "tensor in must have raw tensor\n");
LCHECK_GT(raw_tin->numel(), 0, "tensor in shape must greater than zero\n");
LCHECK_EQ(raw_tin->precision(),
PRECISION(kInt8),
"tensor in precision must be int8");
tout.Resize(tin.shape());
int outer_size = 1;
int axis_size = 1;
int inner_size = raw_tin->numel();
const int8_t* din = raw_tin->data<int8_t>();
float* dout = raw_tout->mutable_data<float>();
paddle::lite::arm::math::int8_to_fp32(
din, dout, &scale, axis_size, outer_size, inner_size);
}
void ComputeUtils::TensorInt8ToFloatInplace(Tensor& tin, float scale) {
lite::Tensor* raw_tin = static_cast<lite::Tensor*>(tin.GetRawTensor());
lite::Tensor tmp_out;
LCHECK(raw_tin, "tensor in must have raw tensor\n");
LCHECK_GT(raw_tin->numel(), 0, "tensor in shape must greater than zero\n");
LCHECK_EQ(raw_tin->precision(),
PRECISION(kInt8),
"tensor in precision must be int8");
tmp_out.Resize(tin.shape());
int outer_size = 1;
int axis_size = 1;
int inner_size = raw_tin->numel();
const int8_t* din = raw_tin->data<int8_t>();
float* tmp_dout = tmp_out.mutable_data<float>();
paddle::lite::arm::math::int8_to_fp32(
din, tmp_dout, &scale, axis_size, outer_size, inner_size);
float* dout = raw_tin->mutable_data<float>();
memcpy(dout, tmp_dout, raw_tin->numel() * sizeof(float));
}
void ComputeUtils::ConvWeightsFloatToInt8(Tensor& weightin,
Tensor& weightout,
std::vector<float> scale) {
lite::Tensor* raw_win = static_cast<lite::Tensor*>(weightin.GetRawTensor());
lite::Tensor* raw_wout = static_cast<lite::Tensor*>(weightout.GetRawTensor());
LCHECK(raw_win, "weights in must have raw tensor\n");
LCHECK_GT(raw_win->numel(), 0, "weights in shape must greater than zero\n");
LCHECK_EQ(raw_win->precision(),
PRECISION(kFloat),
"weights in precision must be float");
weightout.Resize(weightin.shape());
int outer_size = 1;
int axis_size = raw_win->dims()[0]; // chout
int inner_size =
raw_win->numel() / axis_size; // chin / group * ksize_w * ksize_h
const float* din = raw_win->data<float>();
int8_t* dout = raw_wout->mutable_data<int8_t>();
paddle::lite::arm::math::fp32_to_int8(
din, dout, scale.data(), axis_size, outer_size, inner_size);
}
void ComputeUtils::ConvWeightsFloatToInt8Inplace(Tensor& weightin,
std::vector<float> scale) {
lite::Tensor* raw_win = static_cast<lite::Tensor*>(weightin.GetRawTensor());
LCHECK(raw_win, "weights in must have raw tensor\n");
LCHECK_GT(raw_win->numel(), 0, "weights in shape must greater than zero\n");
LCHECK_EQ(raw_win->precision(),
PRECISION(kFloat),
"weights in precision must be float");
int outer_size = 1;
int axis_size = raw_win->dims()[0]; // chout
int inner_size =
raw_win->numel() / axis_size; // chin / group * ksize_w * ksize_h
const float* din = raw_win->data<float>();
int8_t* dout = raw_win->mutable_data<int8_t>();
paddle::lite::arm::math::fp32_to_int8(
din, dout, scale.data(), axis_size, outer_size, inner_size);
}
void ComputeUtils::ConvWeightsInt8ToFloat(Tensor& weightin,
Tensor& weightout,
std::vector<float> scale) {
lite::Tensor* raw_win = static_cast<lite::Tensor*>(weightin.GetRawTensor());
lite::Tensor* raw_wout = static_cast<lite::Tensor*>(weightout.GetRawTensor());
LCHECK(raw_win, "weights in must have raw tensor\n");
LCHECK_GT(raw_win->numel(), 0, "weights in shape must greater than zero\n");
LCHECK_EQ(raw_win->precision(),
PRECISION(kInt8),
"weights in precision must be int8");
weightout.Resize(weightin.shape());
int outer_size = 1;
int axis_size = raw_win->dims()[0]; // chout
int inner_size =
raw_win->numel() / axis_size; // chin / group * ksize_w * ksize_h
const int8_t* din = raw_win->data<int8_t>();
float* dout = raw_wout->mutable_data<float>();
paddle::lite::arm::math::int8_to_fp32(
din, dout, scale.data(), axis_size, outer_size, inner_size);
}
void ComputeUtils::ConvWeightsInt8ToFloatInplace(Tensor& weightin,
std::vector<float> scale) {
lite::Tensor* raw_win = static_cast<lite::Tensor*>(weightin.GetRawTensor());
lite::Tensor tmp_out;
LCHECK(raw_win, "weights in must have raw tensor\n");
LCHECK_GT(raw_win->numel(), 0, "weights in shape must greater than zero\n");
LCHECK_EQ(raw_win->precision(),
PRECISION(kInt8),
"weights in precision must be int8");
tmp_out.Resize(weightin.shape());
int outer_size = 1;
int axis_size = raw_win->dims()[0]; // chout
int inner_size =
raw_win->numel() / axis_size; // chin / group * ksize_w * ksize_h
const int8_t* din = raw_win->data<int8_t>();
float* dout_tmp = tmp_out.mutable_data<float>();
paddle::lite::arm::math::int8_to_fp32(
din, dout_tmp, scale.data(), axis_size, outer_size, inner_size);
float* dout = raw_win->mutable_data<float>();
memcpy(dout, dout_tmp, raw_win->numel() * sizeof(float));
}
// clang-format on
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle_api.h" // NOLINT
namespace paddle {
namespace lite_api {
struct LITE_API ComputeUtils {
static void TensorFloatToInt8(Tensor& tin, // NOLINT
Tensor& tout, // NOLINT
float scale);
static void TensorFloatToInt8Inplace(Tensor& tin, float scale); // NOLINT
static void TensorInt8ToFloat(Tensor& tin, // NOLINT
Tensor& tout, // NOLINT
float scale);
static void TensorInt8ToFloatInplace(Tensor& tin, float scale); // NOLINT
static void ConvWeightsFloatToInt8(Tensor& weightin, // NOLINT
Tensor& weightout, // NOLINT
std::vector<float> scale);
static void ConvWeightsFloatToInt8Inplace(Tensor& weightin, // NOLINT
std::vector<float> scale);
static void ConvWeightsInt8ToFloat(Tensor& weightin, // NOLINT
Tensor& weightout, // NOLINT
std::vector<float> scale);
static void ConvWeightsInt8ToFloatInplace(Tensor& weightin, // NOLINT
std::vector<float> scale);
};
} // namespace lite_api
} // namespace paddle
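A hedged usage sketch of the helpers above: quantize an activation tensor with a single scale and conv weights with per-output-channel scales. The scale values are illustrative, and the inputs are assumed to be float tensors that already hold data:

#include <vector>
#include "compute_utils.h"  // NOLINT
#include "paddle_api.h"     // NOLINT

using namespace paddle::lite_api;  // NOLINT

void quantize_for_int8_conv(Tensor& din, Tensor& weights) {  // NOLINT
  Tensor din_int8, weights_int8;
  float in_scale = 0.05f;  // illustrative activation scale
  // one scale per output channel (weights dim 0)
  std::vector<float> w_scales(weights.shape()[0], 0.01f);
  ComputeUtils::TensorFloatToInt8(din, din_int8, in_scale);
  ComputeUtils::ConvWeightsFloatToInt8(weights, weights_int8, w_scales);
  // ... hand din_int8 / weights_int8 to an int8 ConvParam ...
  din_int8.ReleaseRawTensor();
  weights_int8.ReleaseRawTensor();
}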
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdio>
#include <cstdlib>
#define LOGI(fmt, ...) printf(fmt, ##__VA_ARGS__)
#define LOGE(fmt, ...) printf(fmt, ##__VA_ARGS__)
#define LOGF(fmt, ...) \
do { \
printf(fmt, ##__VA_ARGS__); \
exit(1); \
} while (0)
#define LCHECK(a, fmt, ...) \
do { \
if (!(a)) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
#define LCHECK_EQ(a, b, fmt, ...) \
do { \
if (a != b) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
#define LCHECK_NE(a, b, fmt, ...) \
do { \
if (a == b) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
#define LCHECK_GE(a, b, fmt, ...) \
do { \
if (a < b) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
#define LCHECK_GT(a, b, fmt, ...) \
do { \
if (a <= b) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
#define LCHECK_LE(a, b, fmt, ...) \
do { \
if (a > b) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
#define LCHECK_LT(a, b, fmt, ...) \
do { \
if (a >= b) { \
LOGF(fmt, ##__VA_ARGS__); \
} \
} while (0)
......@@ -35,6 +35,22 @@ const lite::Tensor *ctensor(void *x) {
return static_cast<const lite::Tensor *>(x);
}
#ifdef LITE_WITH_COMPUTE_API
lite::Tensor *mtensor(void *x) { return static_cast<lite::Tensor *>(x); }
Tensor::Tensor() : raw_tensor_(new lite::Tensor()) {}
void Tensor::ReleaseRawTensor() {
delete static_cast<lite::Tensor *>(raw_tensor_);
raw_tensor_ = nullptr;
}
void Tensor::set_precision(PrecisionType ptype) {
mtensor(raw_tensor_)->set_precision(ptype);
}
void *Tensor::GetRawTensor() { return raw_tensor_; }
#endif
void Tensor::Resize(const shape_t &shape) {
tensor(raw_tensor_)->Resize(shape);
}
......
......@@ -36,6 +36,12 @@ struct LITE_API Tensor {
explicit Tensor(void* raw);
explicit Tensor(const void* raw);
#ifdef LITE_WITH_COMPUTE_API
Tensor();
void ReleaseRawTensor();
void set_precision(PrecisionType ptype);
void* GetRawTensor();
#endif
void Resize(const shape_t& shape);
/// Readonly data.
......
......@@ -86,8 +86,8 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
auto pick_kernel = [&](const Place &place) {
auto ks = KernelRegistry::Global().Create(
op_type_, place.target, place.precision, place.layout);
VLOG(5) << "pick kernel for " << op_info()->Type() << " "
<< place.DebugString() << " get " << ks.size() << " kernels";
VLOG(5) << "pick kernel for " << op_type_ << " " << place.DebugString()
<< " get " << ks.size() << " kernels";
for (auto &&it : ks) {
AttachKernel(it.get());
kernels.emplace_back(std::move(it));
......
......@@ -68,6 +68,7 @@ class OpLite : public Registry {
// Inference the outputs' shape.
virtual bool InferShapeImpl() const { return true; }
virtual bool InferShape();
virtual bool SetParam(operators::ParamBase *param) { return false; }
// Run this operator.
virtual bool Run();
// Indicate whether the Op runs only once or not
......
ARM_ABI = arm7
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
CXX_LIBS = -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
CXX_DEFINES += -DLITE_WITH_COMPUTE_API
###############################################################
# How to use one of the static libraries: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: the lite shared library is used by default. #
###############################################################
# 1. Comment out the line above that uses `libpaddle_light_api_shared.so`
# 2. Uncomment the line below that uses `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
#activation
test_activation: test_activation.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_activation.o -o test_activation $(CXX_LIBS) $(LDFLAGS)
test_activation.o: activation_test.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_activation.o -c activation_test.cc
# conv
test_conv: test_conv.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_conv.o -o test_conv $(CXX_LIBS) $(LDFLAGS)
test_conv.o: conv_test.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_conv.o -c conv_test.cc
# int8 conv
test_conv_int8: test_conv_int8.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_conv_int8.o -o test_conv_int8 $(CXX_LIBS) $(LDFLAGS)
test_conv_int8.o: conv_int8_test.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_conv_int8.o -c conv_int8_test.cc
.PHONY: clean
clean:
rm -f test_activation.o
rm -f test_activation
rm -f test_conv.o
rm -f test_conv
rm -f test_conv_int8.o
rm -f test_conv_int8
ARM_ABI = arm8
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
CXX_LIBS = -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
CXX_DEFINES += -DLITE_WITH_COMPUTE_API
###############################################################
# How to use one of the static libraries: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: the lite shared library is used by default. #
###############################################################
# 1. Comment out the line above that uses `libpaddle_light_api_shared.so`
# 2. Uncomment the line below that uses `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
#activation
test_activation: test_activation.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_activation.o -o test_activation $(CXX_LIBS) $(LDFLAGS)
test_activation.o: activation_test.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_activation.o -c activation_test.cc
# conv
test_conv: test_conv.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_conv.o -o test_conv $(CXX_LIBS) $(LDFLAGS)
test_conv.o: conv_test.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_conv.o -c conv_test.cc
# int8 conv
test_conv_int8: test_conv_int8.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) test_conv_int8.o -o test_conv_int8 $(CXX_LIBS) $(LDFLAGS)
test_conv_int8.o: conv_int8_test.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o test_conv_int8.o -c conv_int8_test.cc
.PHONY: clean
clean:
rm -f test_activation.o
rm -f test_activation
rm -f test_conv.o
rm -f test_conv
rm -f test_conv_int8.o
rm -f test_conv_int8
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <memory>
#include <vector>
#include "compute_api.h" // NOLINT
#include "compute_param.h" // NOLINT
#include "compute_utils.h" // NOLINT
#include "paddle_api.h" // NOLINT
#include "utils.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
void activation_naive_impl(const float* din,
float* dout,
int64_t len,
ActivationType act_type,
float leaky_relu_alpha) {
switch (act_type) {
case ActivationType::kRelu: {
for (int i = 0; i < len; i++) {
dout[i] = std::max(0.f, din[i]);
}
break;
}
case ActivationType::kRelu6: {
for (int i = 0; i < len; i++) {
dout[i] = std::max(0.f, din[i]);
dout[i] = std::min(6.f, dout[i]);
}
break;
}
case ActivationType::kLeakyRelu: {
for (int i = 0; i < len; i++) {
dout[i] = din[i] > 0.f ? din[i] : din[i] * leaky_relu_alpha;
}
break;
}
case ActivationType::kSigmoid: {
for (int i = 0; i < len; i++) {
dout[i] = 1.f / (1.f + std::exp(-din[i]));
}
break;
}
case ActivationType::kTanh: {
for (int i = 0; i < len; i++) {
dout[i] = (std::exp(din[i]) - std::exp(-din[i])) /
(std::exp(din[i]) + std::exp(-din[i]));
}
break;
}
default:
std::cerr << "the type of activation is unknow." << std::endl;
assert(0);
}
}
void activation_func(int n,
int c,
int h,
int w,
ActivationType act_type,
float leaky_relu_alpha,
int warmup,
int repeats,
bool check_result,
int threads,
PowerMode power_mode) {
Tensor input, output, output_ref;
input.Resize({n, c, h, w});
input.set_precision(PRECISION(kFloat));
output_ref.Resize({n, c, h, w});
output_ref.set_precision(PRECISION(kFloat));
fill_tensor_rand(input, -1.f, 1.f);
ComputeEngine<TARGET(kARM)>::env_init(power_mode, threads);
ComputeEngine<TARGET(kARM)> act;
ActivationParam act_param;
act_param.active_type = act_type;
act_param.X = &input;
act_param.Out = &output;
act_param.Leaky_relu_alpha = leaky_relu_alpha;
std::string act_str;
if (act_type == ActivationType::kRelu) {
act_str = "relu";
} else if (act_type == ActivationType::kRelu6) {
act_str = "relu6";
} else if (act_type == ActivationType::kLeakyRelu) {
act_str = "leaky_relu";
} else if (act_type == ActivationType::kSigmoid) {
act_str = "sigmoid";
} else if (act_type == ActivationType::kTanh) {
act_str = "tanh";
} else {
std::cerr << "act type: " << static_cast<int>(act_type)
<< "is not support now." << std::endl;
assert(0);
}
act.CreateOperator(act_str.c_str());
act.SetParam(&act_param);
act.Launch();
if (output.shape() != output_ref.shape()) {
std::cerr << "act op infer shape error." << std::endl;
assert(0);
}
Timer t;
for (int i = 0; i < warmup; ++i) {
act.Launch();
}
for (int i = 0; i < repeats; ++i) {
t.Start();
act.Launch();
t.Stop();
}
auto shape = input.shape();
std::cout << "act input shape: " << shape[0] << ", " << shape[1] << ", "
<< shape[2] << ", " << shape[3]
<< ", act_type: " << static_cast<int>(act_type)
<< ", warmup: " << warmup << ", repeats: " << repeats
<< ", power mode: " << static_cast<int>(power_mode)
<< ", threads: " << threads << ", avg time: " << t.LapTimes().Avg()
<< " ms" << std::endl;
if (check_result) {
const float* din = input.data<float>();
float* dout_ref = output_ref.mutable_data<float>();
int64_t len = dim_production(input);
activation_naive_impl(din, dout_ref, len, act_type, leaky_relu_alpha);
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(output, output_ref, max_ratio, max_diff);
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
std::cout << "basic result" << std::endl;
print_tensor(output_ref);
std::cout << "lite result" << std::endl;
print_tensor(output);
Tensor tdiff;
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(output_ref, output, tdiff);
std::cout << "diff result" << std::endl;
print_tensor(tdiff);
tdiff.ReleaseRawTensor();
}
}
}
input.ReleaseRawTensor();
output.ReleaseRawTensor();
output_ref.ReleaseRawTensor();
}
static int basic_test = 1;
static int n = 1;
static int c = 3;
static int h = 224;
static int w = 224;
static int act_type = 1;
static float leaky_relu_alpha = 2.f;
static int warmup = 0;
static int repeats = 1;
static int check_result = 1;
static int power_mode = 3;
static int threads = 1;
int main(int argc, const char** argv) {
if (argc < 2) {
std::cout << "usage: ./" << argv[0]
<< "basic_test n c h w act_type leaky_relu_alpha"
" warmup repeats check_result power_mode threads"
<< std::endl;
return 0;
}
if (argc >= 2) {
basic_test = atoi(argv[1]) > 0;
}
if (argc >= 3) {
n = atoi(argv[2]);
}
if (argc >= 4) {
c = atoi(argv[3]);
}
if (argc >= 5) {
h = atoi(argv[4]);
}
if (argc >= 6) {
w = atoi(argv[5]);
}
if (argc >= 7) {
act_type = atoi(argv[6]);
}
if (argc >= 8) {
leaky_relu_alpha = atof(argv[7]);
}
if (argc >= 9) {
warmup = atoi(argv[8]);
}
if (argc >= 10) {
repeats = atoi(argv[9]);
}
if (argc >= 11) {
check_result = atoi(argv[10]);
}
if (argc >= 12) {
power_mode = atoi(argv[11]);
}
if (argc >= 13) {
threads = atoi(argv[12]);
}
// basic test
if (basic_test) {
std::cout << "RUN BASIC TEST BEGIN: " << std::endl;
for (auto& n : {1, 3, 4}) {
for (auto& c : {1, 3, 32}) {
for (auto& h : {5, 64, 112, 224}) {
for (auto& w : {5, 64, 112, 224}) {
for (auto& act_type : {1, 2, 4, 5, 6}) {
for (auto& threads : {1, 2, 4}) {
activation_func(n,
c,
h,
w,
static_cast<ActivationType>(act_type),
leaky_relu_alpha,
0,
1,
1,
threads,
static_cast<PowerMode>(3));
}
}
}
}
}
}
std::cout << "RUN BASIC TEST END: " << std::endl;
}
// custom test
std::cout << "RUN CUSTOM TEST BEGIN: " << std::endl;
activation_func(n,
c,
h,
w,
static_cast<ActivationType>(act_type),
leaky_relu_alpha,
warmup,
repeats,
check_result,
threads,
static_cast<PowerMode>(power_mode));
std::cout << "RUN CUSTOM TEST END: " << std::endl;
return 0;
}
This diff is collapsed.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <memory>
#include <vector>
#include "compute_api.h" // NOLINT
#include "compute_param.h" // NOLINT
#include "compute_utils.h" // NOLINT
#include "paddle_api.h" // NOLINT
#include "utils.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
static int basic_test = 1;
static int batch = 1;
static int in_channel = 32;
static int in_height = 112;
static int in_width = 112;
static int out_channel = 32;
static int group = 1;
static int kernel_h = 3;
static int kernel_w = 3;
static int pad_h0 = 1;
static int pad_h1 = 1;
static int pad_w0 = 1;
static int pad_w1 = 1;
static int stride_h = 1;
static int stride_w = 1;
static int dila_h = 1;
static int dila_w = 1;
static int flag_act = 0;
static int flag_bias = 1;
static float leaky_relu_alpha = 2.f;
static int warmup = 0;
static int repeats = 1;
static int check_result = 1;
static int power_mode = 3;
static int threads = 1;
template <typename Dtype1, typename Dtype2>
static void conv_basic(const Dtype1* din,
Dtype2* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const Dtype1* weights,
const Dtype2* bias,
int group,
int kernel_w,
int kernel_h,
int stride_w,
int stride_h,
int dila_w,
int dila_h,
int pad_w,
int pad_h,
bool flag_bias,
int act_type,
float six = 6.f,
float scale = 1.f) {
Dtype2 beta = 0;
auto src_data = din;
auto dst_data_ref = dout;
auto weights_data = weights;
auto with_bias = flag_bias;
auto bias_data = bias;
int in_num = num;
int out_channels = chout;
int out_h = hout;
int out_w = wout;
int in_channel = chin;
int in_h = hin;
int in_w = win;
int out_c_group = out_channels / group;
int in_c_group = in_channel / group;
for (int n = 0; n < in_num; ++n) {
#pragma omp parallel for collapse(4)
for (int g = 0; g < group; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * group * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w + oc * out_h * out_w +
oh * out_w + ow;
Dtype2 bias_d = with_bias ? (bias_data[g * out_c_group + oc]) : 0;
dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int iidx = n * in_channel * in_h * in_w +
g * in_c_group * in_h * in_w + ic * in_h * in_w +
ih * in_w + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx];
}
}
}
if (act_type > 0) {
// 1-relu 2-relu6 4-leakyrelu
if (act_type == 1) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
} else if (act_type == 2) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
? dst_data_ref[out_idx]
: (Dtype2)six;
} else if (act_type == 4) {
dst_data_ref[out_idx] =
dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)(dst_data_ref[out_idx] * scale);
} else {
printf("this act type: %d does not support \n", act_type);
}
}
}
}
}
}
}
}
shape_t compute_out_dim(const shape_t& dim_in, const ConvParam& param) {
shape_t dim_out = dim_in;
auto paddings = *param.paddings;
auto dilations = *param.dilations;
auto filter_shape = param.filter->shape();
dim_out[1] = filter_shape[0];
auto kernel_h = filter_shape[2];
auto kernel_w = filter_shape[3];
auto h = dim_in[2];
auto w = dim_in[3];
int dila_h = dilations[0];
int dila_w = dilations[1];
int pad_top = paddings[0];
int pad_bottom = paddings[1];
int pad_left = paddings[2];
int pad_right = paddings[3];
int stride_h = param.strides[0];
int stride_w = param.strides[1];
auto kernel_exten = dila_h * (kernel_h - 1) + 1;
auto hout = (h + pad_top + pad_bottom - kernel_exten) / stride_h + 1;
kernel_exten = dila_w * (kernel_w - 1) + 1;
auto wout = (w + pad_left + pad_right - kernel_exten) / stride_w + 1;
dim_out[2] = hout;
dim_out[3] = wout;
return dim_out;
}
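// Worked example of the formula above (illustrative numbers): for an input of
// {1, 32, 112, 112}, a 32 x 32 x 3 x 3 filter, pad {1, 1, 1, 1}, stride {1, 1}
// and dilation {1, 1}:
//   kernel_exten = 1 * (3 - 1) + 1 = 3
//   hout = (112 + 1 + 1 - 3) / 1 + 1 = 112 (wout likewise)
// so compute_out_dim returns {1, 32, 112, 112}.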
void test_conv_fp32(const std::vector<shape_t>& input_dims,
const shape_t& weight_dim,
int group,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
int flag_act,
const int thread_num,
const int power_mode,
const float leakey_relu_scale) {
ComputeEngine<TARGET(kARM)>::env_init(static_cast<PowerMode>(power_mode),
thread_num);
ConvParam param;
param.x = new Tensor;
param.x->set_precision(PRECISION(kFloat));
param.filter = new Tensor;
param.filter->Resize(weight_dim);
param.filter->set_precision(PRECISION(kFloat));
if (flag_bias) {
param.bias = new Tensor;
param.bias->Resize({weight_dim[0]});
param.bias->set_precision(PRECISION(kFloat));
}
param.strides = strides;
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilas);
param.groups = group;
const float six = 6.f;
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type =
static_cast<ActivationType>(flag_act); // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
// param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.activation_param = act_param;
}
param.output = new Tensor;
param.output->set_precision(PRECISION(kFloat));
fill_tensor_rand(*param.filter, -1.f, 1.f);
// fill_tensor_const(*param.filter, 1.f);
if (flag_bias) {
fill_tensor_rand(*param.bias, -1.f, 1.f);
// fill_tensor_const(*param.bias, 1.f);
}
auto wptr = param.filter->data<float>();
auto bias_ptr = flag_bias ? param.bias->data<float>() : nullptr;
ComputeEngine<TARGET(kARM)> conv;
conv.CreateOperator("conv2d");
for (auto& dim_in : input_dims) {
param.x->Resize(dim_in);
shape_t out_tmp_dims = compute_out_dim(dim_in, param);
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
continue;
}
param.output->Resize(out_tmp_dims);
break;
}
conv.SetParam(&param);
for (auto& dim_in : input_dims) {
if (weight_dim[1] * group != dim_in[1]) {
"input channel must equal to weights channel\n";
exit(1);
}
shape_t dim_out = compute_out_dim(dim_in, param);
if (dim_out[2] < 1 || dim_out[3] < 1) {
continue;
}
param.x->Resize(dim_in);
param.output->Resize(dim_out);
fill_tensor_rand(*param.x, -1.f, 1.f);
// fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>();
Tensor tout_basic;
if (check_result) {
tout_basic.set_precision(PRECISION(kFloat));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<float>();
conv_basic<float, float>(din,
dout_basic,
dim_in[0],
dim_out[1],
dim_out[2],
dim_out[3],
dim_in[1],
dim_in[2],
dim_in[3],
wptr,
bias_ptr,
group,
weight_dim[3],
weight_dim[2],
strides[1],
strides[0],
dilas[1],
dilas[0],
pads[2],
pads[0],
flag_bias,
flag_act,
six,
leakey_relu_scale);
}
/// warm up
for (int i = 0; i < warmup; ++i) {
conv.Launch();
}
/// compute
Timer t0;
for (int i = 0; i < repeats; ++i) {
t0.Start();
conv.Launch();
t0.Stop();
}
double gops = 2.0 * dim_production(*param.output) * dim_in[1] *
weight_dim[2] * weight_dim[3] / param.groups;
std::cout << "conv fp32: input shape: (" << dim_in[0] << ", " << dim_in[1]
<< ", " << dim_in[2] << ", " << dim_in[3] << "), output shape: ("
<< dim_out[0] << ", " << dim_out[1] << ", " << dim_out[2] << ", "
<< dim_out[3] << "),running time, avg: " << t0.LapTimes().Avg()
<< ", min time: " << t0.LapTimes().Min()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min()
<< std::endl;
if (check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff);
std::cout << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio << std::endl;
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
std::cout << "basic result\n";
print_tensor(tout_basic);
std::cout << "lite result\n";
print_tensor(*param.output);
Tensor tdiff;
tdiff.Resize(tout_basic.shape());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic, *param.output, tdiff);
print_tensor(tdiff);
std::cerr << "test fp32 conv: input: (" << dim_in[0] << ", "
<< dim_in[1] << ", " << dim_in[2] << ", " << dim_in[3]
<< "), output: (" << dim_out[0] << ", " << dim_out[1]
<< ", " << dim_out[2] << ", " << dim_out[3]
<< "), weight dim: (" << weight_dim[0] << ", "
<< weight_dim[1] << ", " << weight_dim[2] << ", "
<< weight_dim[3] << "), pad: " << pads[0] << ", " << pads[1]
<< ", " << pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << thread_num
<< ", power_mode: " << power_mode << " failed!!\n";
exit(1);
}
}
}
std::cout << "test fp32 conv: input: (" << dim_in[0] << ", " << dim_in[1]
<< ", " << dim_in[2] << ", " << dim_in[3] << "), output: ("
<< dim_out[0] << ", " << dim_out[1] << ", " << dim_out[2] << ", "
<< dim_out[3] << "), weight dim: (" << weight_dim[0] << ", "
<< weight_dim[1] << ", " << weight_dim[2] << ", " << weight_dim[3]
<< "), pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", act: " << flag_act << ", threads: " << thread_num
<< ", power_mode: " << power_mode << " success!!\n";
}
param.x->ReleaseRawTensor();
param.filter->ReleaseRawTensor();
param.output->ReleaseRawTensor();
if (flag_bias) {
param.bias->ReleaseRawTensor();
}
delete param.x;
delete param.filter;
delete param.output;
delete param.bias;
}
int main(int argc, const char** argv) {
if (argc < 2) {
std::cout << "usage: ./" << argv[0]
<< "basic_test check_result batch in_channel in_height in_width "
"out_channel group kernel_h pad_h0 stride_h dila_h flag_act "
"flag_bias warmup repeats threads power_mode."
<< std::endl;
return 0;
}
if (argc >= 2) {
basic_test = atoi(argv[1]);
}
if (argc >= 3) {
check_result = atoi(argv[2]);
}
if (argc >= 4) {
batch = atoi(argv[3]);
}
if (argc >= 5) {
in_channel = atoi(argv[4]);
}
if (argc >= 6) {
in_height = atoi(argv[5]);
}
if (argc >= 7) {
in_width = atoi(argv[6]);
}
if (argc >= 8) {
out_channel = atoi(argv[7]);
}
if (argc >= 9) {
group = atoi(argv[8]);
}
if (argc >= 10) {
if (argc >= 13) {
kernel_h = atoi(argv[9]);
kernel_w = kernel_h;
pad_h0 = atoi(argv[10]);
pad_h1 = pad_h0;
pad_w0 = pad_h0;
pad_w1 = pad_h0;
stride_h = atoi(argv[11]);
stride_w = stride_h;
dila_h = atoi(argv[12]);
dila_w = dila_h;
} else {
std::cout
<< "kernel_h padh0 stride_h dila_h must be set at the same time."
<< std::endl;
}
}
if (argc >= 14) {
flag_act = atoi(argv[13]);
}
if (argc >= 15) {
flag_bias = atoi(argv[14]);
}
if (argc >= 16) {
warmup = atoi(argv[15]);
}
if (argc >= 17) {
repeats = atoi(argv[16]);
}
if (argc >= 18) {
threads = atoi(argv[17]);
}
if (argc >= 19) {
power_mode = atoi(argv[18]);
}
if (argc >= 20) {
leaky_relu_alpha = atof(argv[19]);
}
// basic test
if (basic_test) {
std::cout << "RUN BASIC TEST BEGIN: " << std::endl;
for (auto& cin : {1, 3, 8}) {
for (auto& cout : {1, 5, 16}) {
for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) {
for (auto& pad_left : {0, 2}) {
for (auto& pad_right : {0, 2}) {
for (auto& pad_top : {0, 2}) {
for (auto& pad_bottom : {0, 2}) {
for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& threads : {1, 2, 4}) {
if (cin % g != 0 || cout % g != 0) {
continue;
}
std::vector<shape_t> dims;
shape_t weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32}) {
dims.push_back(shape_t({batch, cin, h, h}));
}
}
// skip 3x3 depthwise conv
if (g == cin && cin == cout && kw == 3 &&
kh == 3) {
break;
}
// skip 3x3s1 direct conv
if (g == 1 && (cin != 1 || cout != 1) &&
kw == 3 && kh == 3 && stride == 1) {
break;
}
const float leakey_relu_scale = 2.22;
test_conv_fp32(
dims,
weights_dim,
g,
{stride, stride},
{pad_top, pad_bottom, pad_left, pad_right},
{dila, dila},
flag_bias,
flag_act,
threads,
3,
leakey_relu_scale);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
std::cout << "RUN BASIC TEST END: " << std::endl;
}
// custom test
std::cout << "RUN CUSTOM TEST BEGIN: " << std::endl;
std::vector<shape_t> dims;
dims.emplace_back(shape_t({batch, in_channel, in_height, in_width}));
shape_t weights_dim({out_channel, in_channel / group, kernel_h, kernel_w});
test_conv_fp32(dims,
weights_dim,
group,
{stride_h, stride_w},
{pad_h0, pad_h1, pad_w0, pad_w1},
{dila_h, dila_w},
flag_bias,
flag_act,
threads,
3,
leaky_relu_alpha);
std::cout << "RUN CUSTOM TEST END: " << std::endl;
return 0;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <iostream>
#include <random>
#include <vector>
#include "compute_api.h" // NOLINT
#include "paddle_api.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
template <typename Dtype>
void fill_tensor_host_const_impl(Dtype* dio, Dtype value, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = value;
}
}
int64_t dim_production(const Tensor& t) {
shape_t s = t.shape();
int64_t n = 1;
for (int i = 0; i < s.size(); ++i) {
n *= s[i];
}
return n;
}
/**
* \brief Fill the host tensor buffer with a constant value.
* \param tensor The reference of input tensor.
*/
void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
int64_t size = dim_production(tensor);
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
fill_tensor_host_const_impl(
tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
break;
case PRECISION(kInt32):
fill_tensor_host_const_impl(
tensor.mutable_data<int>(), static_cast<int>(value), size);
break;
case PRECISION(kFloat):
fill_tensor_host_const_impl(
tensor.mutable_data<float>(), static_cast<float>(value), size);
break;
default:
std::cerr << "data type is unsupported now." << std::endl;
assert(0);
}
}
template <typename Dtype>
void fill_tensor_host_rand_impl(Dtype* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
Dtype rand_x = static_cast<Dtype>(rand() % 256); // NOLINT
dio[i] = (rand_x - 128) / 128;
}
}
template <>
void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = rand() % 256 - 128; // NOLINT
}
}
template <>
void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = rand() % 256; // NOLINT
}
}
/**
* \brief Fill the host tensor buffer with rand value.
* \param tensor The reference of input tensor.
*/
void fill_tensor_rand(Tensor& tensor) { // NOLINT
int64_t size = dim_production(tensor);
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
break;
case PRECISION(kInt32):
fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
break;
case PRECISION(kFloat):
fill_tensor_host_rand_impl(tensor.mutable_data<float>(), size);
break;
default:
std::cerr << "data type: is unsupported now" << std::endl;
assert(0);
}
}
template <typename Dtype>
void fill_tensor_host_rand_impl2(Dtype* dio,
Dtype vstart,
Dtype vend,
int64_t size) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0, 1.f);
for (int64_t i = 0; i < size; ++i) {
Dtype random_num = static_cast<Dtype>(vstart + (vend - vstart) * dis(gen));
dio[i] = random_num;
}
}
/**
* \brief Fill the host tensor buffer with random values in [vstart, vend).
* \param tensor The reference of input tensor.
*/
void fill_tensor_rand(Tensor& tensor, float vstart, float vend) { // NOLINT
int64_t size = dim_production(tensor);
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
fill_tensor_host_rand_impl2(tensor.mutable_data<int8_t>(),
static_cast<signed char>(vstart),
static_cast<signed char>(vend),
size);
break;
case PRECISION(kInt32):
fill_tensor_host_rand_impl2(tensor.mutable_data<int>(),
static_cast<int>(vstart),
static_cast<int>(vend),
size);
break;
case PRECISION(kFloat):
fill_tensor_host_rand_impl2(
tensor.mutable_data<float>(), vstart, vend, size);
break;
default:
std::cerr << "data type: is unsupported now" << std::endl;
assert(0);
}
}
template <typename Dtype>
void print_tensor_host_impl(const Dtype* din, int64_t size, int64_t width);
template <>
void print_tensor_host_impl(const float* din, int64_t size, int64_t width) {
for (int i = 0; i < size; ++i) {
printf("%.6f ", din[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
}
template <>
void print_tensor_host_impl(const int* din, int64_t size, int64_t width) {
for (int i = 0; i < size; ++i) {
printf("%d ", din[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
}
template <>
void print_tensor_host_impl(const signed char* din,
int64_t size,
int64_t width) {
for (int i = 0; i < size; ++i) {
printf("%d ", din[i]);
if ((i + 1) % width == 0) {
printf("\n");
}
}
printf("\n");
}
/**
* \brief Print the data in host tensor.
* \param tensor The reference of input tensor.
*/
void print_tensor(const Tensor& tensor) {
printf("host tensor data size: %ld\n", dim_production(tensor));
int64_t size = dim_production(tensor);
int64_t width = tensor.shape()[tensor.shape().size() - 1];
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
print_tensor_host_impl(tensor.data<int8_t>(), size, width);
break;
case PRECISION(kInt32):
print_tensor_host_impl(tensor.data<int>(), size, width);
break;
case PRECISION(kFloat):
print_tensor_host_impl(tensor.data<float>(), size, width);
break;
default:
std::cerr << "data type: is unsupported now" << std::endl;
assert(0);
}
}
template <typename Dtype>
double tensor_mean_value_host_impl(const Dtype* din, int64_t size) {
double sum = 0.0;
for (int64_t i = 0; i < size; ++i) {
sum += din[i];
}
return sum / size;
}
double tensor_mean(const Tensor& tensor) {
int64_t size = dim_production(tensor);
PrecisionType type = tensor.precision();
switch (type) {
case PRECISION(kInt8):
return tensor_mean_value_host_impl(tensor.data<int8_t>(), size);
case PRECISION(kInt32):
return tensor_mean_value_host_impl(tensor.data<int>(), size);
case PRECISION(kFloat):
return tensor_mean_value_host_impl(tensor.data<float>(), size);
default:
std::cerr << "data type: is unsupported now" << std::endl;
assert(0);
}
return 0.0;
}
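// Scans two equally sized buffers and reports the maximum relative error
// max_i |ref_i - out_i| / (|ref_i| + eps), together with the absolute
// difference at that element; used by tensor_cmp_host below.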
template <typename dtype>
void data_diff_kernel(const dtype* src1_truth,
const dtype* src2,
int size,
double& max_ratio, // NOLINT
double& max_diff) { // NOLINT
const double eps = 1e-6f;
max_diff = fabs(src1_truth[0] - src2[0]);
max_ratio = fabs(max_diff) / (std::abs(src1_truth[0]) + eps);
for (int i = 1; i < size; ++i) {
double diff = fabs(src1_truth[i] - src2[i]);
double ratio = fabs(diff) / (std::abs(src1_truth[i]) + eps);
if (max_ratio < ratio) {
max_diff = diff;
max_ratio = ratio;
}
}
}
void tensor_cmp_host(const Tensor& src1_basic,
const Tensor& src2,
double& max_ratio, // NOLINT
double& max_diff) { // NOLINT
max_ratio = 0.;
max_diff = 0.;
int64_t size = dim_production(src1_basic);
int64_t size2 = dim_production(src2);
if (size != size2) {
std::cerr << "ERROR: tensor_cmp_host: wrong shape" << std::endl;
assert(0);
}
auto ptype1 = src1_basic.precision();
auto ptype2 = src2.precision();
if (ptype1 != ptype2) {
std::cerr << "ERROR: tensor_cmp_host: wrong data type" << std::endl;
assert(0);
}
if (size == 0) return;
switch (src1_basic.precision()) {
case PRECISION(kFloat):
data_diff_kernel(src1_basic.data<float>(),
src2.data<float>(),
size,
max_ratio,
max_diff);
return;
case PRECISION(kInt32):
data_diff_kernel(
src1_basic.data<int>(), src2.data<int>(), size, max_ratio, max_diff);
return;
case PRECISION(kInt8):
data_diff_kernel(src1_basic.data<int8_t>(),
src2.data<int8_t>(),
size,
max_ratio,
max_diff);
return;
default:
std::cerr << "data type: is unsupported now" << std::endl;
assert(0);
}
}
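// Element-wise difference dst[i] = src1[i] - src2[i]; used by tensor_diff
// below to materialize the per-element error between two tensors.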
template <typename dtype>
void tensor_diff_kernel(const dtype* src1,
const dtype* src2,
dtype* dst,
int64_t size) {
for (int i = 0; i < size; ++i) {
dst[i] = src1[i] - src2[i];
}
}
void tensor_diff(const Tensor& t1, const Tensor& t2, Tensor& tdiff) { // NOLINT
int64_t size1 = dim_production(t1);
int64_t size2 = dim_production(t2);
if (size1 != size2) {
std::cerr << "ERROR: tensor_diff: wrong shape" << std::endl;
assert(0);
}
auto ptype1 = t1.precision();
auto ptype2 = t2.precision();
if (ptype1 != ptype2) {
std::cerr << "ERROR: tensor_diff: wrong data type" << std::endl;
assert(0);
}
tdiff.Resize(t1.shape());
switch (t1.precision()) {
case PRECISION(kFloat):
tensor_diff_kernel(t1.data<float>(),
t2.data<float>(),
tdiff.mutable_data<float>(),
size1);
return;
case PRECISION(kInt32):
tensor_diff_kernel(
t1.data<int>(), t2.data<int>(), tdiff.mutable_data<int>(), size1);
return;
case PRECISION(kInt8):
tensor_diff_kernel(t1.data<int8_t>(),
t2.data<int8_t>(),
tdiff.mutable_data<int8_t>(),
size1);
return;
default:
std::cerr << "data type: is unsupported now" << std::endl;
assert(0);
}
}
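// Collects per-lap timing samples; the `offset` argument of the statistics
// below skips the first `offset` laps (e.g. warm-up runs) when aggregating.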
template <typename T>
class TimeList {
public:
void Clear() { laps_t_.clear(); }
void Add(T t) { laps_t_.push_back(t); }
T Last(size_t offset = 0) const {
if (!Size(offset)) {
return 0;
}
return laps_t_.back();
}
T Max(size_t offset = 0) const {
if (!Size(offset)) {
return 0;
}
return *std::max_element((laps_t_.begin() + offset), laps_t_.end());
}
T Min(size_t offset = 0) const {
if (!Size(offset)) {
return 0;
}
return *std::min_element((laps_t_.begin() + offset), laps_t_.end());
}
T Sum(size_t offset = 0) const {
if (!Size(offset)) {
return 0;
}
return std::accumulate((laps_t_.begin() + offset), laps_t_.end(), 0.0);
}
size_t Size(size_t offset = 0) const {
size_t size = (laps_t_.size() <= offset) ? 0 : (laps_t_.size() - offset);
return size;
}
T Avg(size_t offset = 0) const {
if (!Size(offset)) {
return 0;
}
return Sum(offset) / Size(offset);
}
const std::vector<T>& Raw() const { return laps_t_; }
private:
std::vector<T> laps_t_;
};
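// Wall-clock timer based on std::chrono; each Start()/Stop() pair records one
// lap, in milliseconds, into the TimeList above.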
class Timer {
public:
Timer() = default;
virtual ~Timer() = default;
void Reset() { laps_t_.Clear(); }
void Start() { t_start_ = std::chrono::system_clock::now(); }
float Stop() {
t_stop_ = std::chrono::system_clock::now();
auto ts = std::chrono::duration_cast<std::chrono::microseconds>(t_stop_ -
t_start_);
float elapse_ms = 1000.f * static_cast<float>(ts.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den;
this->laps_t_.Add(elapse_ms);
return elapse_ms;
}
float AvgLapTimeMs() const { return laps_t_.Avg(); }
const TimeList<float>& LapTimes() const { return laps_t_; }
protected:
TimeList<float> laps_t_;
private:
std::chrono::time_point<std::chrono::system_clock> t_start_, t_stop_;
};
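// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not shipped code): it assumes three Tensor
// objects t_in, t_out and t_ref have already been created and shaped
// elsewhere, e.g. the input, the compute-API result and a reference result.
// It shows how the helpers above combine to randomize the input, time a run
// and check the output against the reference.
//
//   Timer timer;
//   fill_tensor_rand(t_in, -1.f, 1.f);      // random input in [-1, 1)
//   timer.Start();
//   /* run the kernel under test here, writing into t_out */
//   timer.Stop();
//   double max_ratio = 0.0;
//   double max_diff = 0.0;
//   tensor_cmp_host(t_ref, t_out, max_ratio, max_diff);
//   std::cout << "avg time (ms): " << timer.AvgLapTimeMs()
//             << ", max_diff: " << max_diff
//             << ", max_ratio: " << max_ratio << std::endl;
// ---------------------------------------------------------------------------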
......@@ -93,6 +93,11 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
return true;
}
bool ActivationOp::SetParam(ParamBase* param) {
param_ = *static_cast<ActivationParam*>(param);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
......
......@@ -31,6 +31,8 @@ class ActivationOp : public OpLite {
bool InferShapeImpl() const override;
bool SetParam(ParamBase* param) override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
......
......@@ -39,6 +39,11 @@ bool ConvOpLite::CheckShape() const {
return true;
}
bool ConvOpLite::SetParam(ParamBase* param) {
param_ = *static_cast<ConvParam*>(param);
return true;
}
inline int ConvOutputSize(int input_size,
int filter_size,
int dilation,
......
......@@ -61,6 +61,7 @@ class ConvOpLite : public OpLite {
output_dims.production() * input_dims[1] / param_.groups;
}
#endif
bool SetParam(ParamBase* param) override;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
......
......@@ -21,6 +21,7 @@ BUILD_DIR=$(pwd)
OPTMODEL_DIR=""
BUILD_TAILOR=OFF
BUILD_CV=OFF
BUILD_COMPUTE_API=OFF
WITH_LOG=ON
WITH_PROFILE=OFF
BUILD_NPU=OFF
......@@ -130,6 +131,7 @@ function make_tiny_publish_so {
-DANDROID_STL_TYPE=$android_stl \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_CV=$BUILD_CV \
-DLITE_WITH_COMPUTE_API=$BUILD_COMPUTE_API \
-DLITE_BUILD_TAILOR=$BUILD_TAILOR \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DLITE_WITH_NPU=$BUILD_NPU \
......