Commit c6c1d514 authored by qnqinan

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into develop

......@@ -180,13 +180,16 @@ function(add_cuda_static_lib alias cuda_lib_paths file_name)
add_library(${alias} STATIC IMPORTED GLOBAL)
set_property(TARGET ${alias} PROPERTY IMPORTED_LOCATION ${ABS_PATH})
set(CUDA_STATIC_MODULES ${CUDA_STATIC_MODULES} ${alias} PARENT_SCOPE)
if (NOT ABS_PATH)
message(FATAL_ERROR "Can not find CUDA static library: ${file_name}")
endif()
endfunction()
add_cuda_static_lib(cudart_static CUDNN_CHECK_LIBRARY_DIRS libcudart_static.a)
add_cuda_static_lib(cublas_static CUDNN_CHECK_LIBRARY_DIRS libcublas_static.a)
add_cuda_static_lib(curand_static CUDNN_CHECK_LIBRARY_DIRS libcurand_static.a)
add_cuda_static_lib(culibos_static CUDNN_CHECK_LIBRARY_DIRS libculibos.a)
if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0))
if(NOT ${CUDA_VERSION} LESS 10.1)
add_cuda_static_lib(cublasLt_static CUDNN_CHECK_LIBRARY_DIRS libcublasLt_static.a)
endif()
......
......@@ -216,6 +216,8 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_full/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_full/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_light" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_light/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_light/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_detection/Makefile"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/api/paddle_*.h" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/include"
)
add_dependencies(publish_inference_android_cxx_demos logging gflags)
......@@ -228,6 +230,9 @@ if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/README.md" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_light" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_light/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_light/Makefile"
COMMAND cp -r "${CMAKE_SOURCE_DIR}/lite/demo/cxx/mobile_detection" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx"
COMMAND cp "${CMAKE_SOURCE_DIR}/lite/demo/cxx/makefiles/mobile_detection/Makefile.${ARM_TARGET_OS}.${ARM_TARGET_ARCH_ABI}" "${INFER_LITE_PUBLISH_ROOT}/demo/cxx/mobile_detection/Makefile"
)
add_dependencies(tiny_publish_cxx_lib publish_inference_android_cxx_demos)
endif()
......
......@@ -44,9 +44,10 @@ void OutputOptModel(const std::string& load_model_dir,
const std::vector<std::vector<int64_t>>& input_shapes) {
lite_api::CxxConfig config;
config.set_model_dir(load_model_dir);
std::vector<Place> vaild_places = {Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kOpenCL), PRECISION(kFloat)}};
std::vector<Place> vaild_places = {
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)},
};
if (FLAGS_is_quantized_model) {
vaild_places.insert(vaild_places.begin(),
Place{TARGET(kARM), PRECISION(kInt8)});
......
......@@ -21,14 +21,14 @@
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/device_info.h"
#include "lite/tests/utils/timer.h"
#include "lite/core/profile/timer.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
using paddle::lite::Timer;
using paddle::lite::profile::Timer;
DEFINE_string(input_shape,
"1,3,224,224",
......@@ -102,20 +102,20 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
Timer ti;
for (int j = 0; j < repeat; ++j) {
ti.start();
ti.Start();
predictor->Run();
ti.end();
LOG(INFO) << "iter: " << j << ", time: " << ti.latest_time() << " ms";
float t = ti.Stop();
LOG(INFO) << "iter: " << j << ", time: " << t << " ms";
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << model_dir
<< ", power_mode: " << static_cast<int>(power_mode)
<< ", threads num " << thread_num << ", warmup: " << warmup_times
<< ", repeats: " << repeat << ", avg time: " << ti.get_average_ms()
<< ", repeats: " << repeat << ", avg time: " << ti.LapTimes().Avg()
<< " ms"
<< ", min time: " << ti.get_min_time() << " ms"
<< ", max time: " << ti.get_max_time() << " ms.";
<< ", min time: " << ti.LapTimes().Min() << " ms"
<< ", max time: " << ti.LapTimes().Max() << " ms.";
auto output = predictor->GetOutput(0);
auto out = output->data<float>();
......
......@@ -121,6 +121,7 @@ template void Tensor::CopyFromCpu<int, TargetType::kARM>(const int *);
template void Tensor::CopyFromCpu<float, TargetType::kARM>(const float *);
template void Tensor::CopyFromCpu<int8_t, TargetType::kARM>(const int8_t *);
template void Tensor::CopyFromCpu<int, TargetType::kCUDA>(const int *);
template void Tensor::CopyFromCpu<int64_t, TargetType::kCUDA>(const int64_t *);
template void Tensor::CopyFromCpu<float, TargetType::kCUDA>(const float *);
template void Tensor::CopyFromCpu<int8_t, TargetType::kCUDA>(const int8_t *);
......
......@@ -12,20 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -79,6 +79,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s1_depthwise_int8.cc
conv5x5s1_depthwise_fp32.cc
conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc
conv_winograd_3x3.cc
conv_impl.cc
softmax.cc
......
......@@ -32,8 +32,10 @@ void col2im<float>(const float* data_col,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int pad_h0,
const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h,
const int stride_w,
const int dilation_h,
......@@ -41,19 +43,22 @@ void col2im<float>(const float* data_col,
float* data_im) {
memset(data_im, 0, height * width * channels * sizeof(float));
const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
(height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
stride_h +
1;
const int output_w =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
(width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h + kernel_row * dilation_h;
int input_row = -pad_h0 + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
data_col += output_w;
} else {
int input_col = -pad_w + kernel_col * dilation_w;
int input_col = -pad_w0 + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
data_im[input_row * width + input_col] += *data_col;
......
......@@ -26,8 +26,10 @@ void col2im(const Dtype* data_col,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int pad_h0,
const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h,
const int stride_w,
const int dilation_h,
......
(The diff for this file has been collapsed and is not shown.)
......@@ -254,6 +254,7 @@ inline void prepack_input_nxwc4_dw(const float* din,
LOG(FATAL) << "prepack_dw_input, valid height must > zero";
}
float32x4_t vzero = vdupq_n_f32(0.f);
auto out_data = dout;
int size_w = we - ws;
int w0 = ws < 0 ? 0 : ws;
......@@ -269,6 +270,7 @@ inline void prepack_input_nxwc4_dw(const float* din,
bool flag_ext_l = left_remain > 0;
int left_sl = 4 - left_remain;
int left_valid_sl = left_sl > width ? width : left_sl;
uint32x4_t vmask_padl;
bool flag_mask_l = false;
if (flag_ext_l) {
......@@ -290,6 +292,7 @@ inline void prepack_input_nxwc4_dw(const float* din,
}
int size_c = width * height;
for (int h = hs; h < he; ++h) {
dout = out_data + (h - hs) * 4 * size_w;
auto ptr_c0 = din + cs * size_c + h * width;
auto ptr_c1 = ptr_c0 + size_c;
auto ptr_c2 = ptr_c1 + size_c;
......@@ -351,10 +354,10 @@ inline void prepack_input_nxwc4_dw(const float* din,
}
transpose_4x4(vc0, vc1, vc2, vc3, dout);
dout += 16;
ptr_c0 += left_sl;
ptr_c1 += left_sl;
ptr_c2 += left_sl;
ptr_c3 += left_sl;
ptr_c0 += left_valid_sl;
ptr_c1 += left_valid_sl;
ptr_c2 += left_valid_sl;
ptr_c3 += left_valid_sl;
}
/// valid
for (int i = 0; i < cnt_valid; ++i) {
......@@ -986,7 +989,9 @@ inline bool write_to_output_c4_fp32(const float* din,
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int cnt = (width - ws) / w4;
int valid_we = we > width ? width : we;
int cnt = (valid_we - ws) / w4;
int remain = valid_we - ws - cnt * w4;
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
......@@ -1087,12 +1092,12 @@ inline bool write_to_output_c4_fp32(const float* din,
#endif
}
}
if (we > width) {
if (remain > 0) {
int offset = i * w_round * c4 + c4 * w4 * cnt;
din_hei_ptr = ptr_din + offset;
int j = we - w4;
int j = 0;
if (flag_relu) {
for (; j < width; ++j) {
for (; j < remain; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
......@@ -1100,7 +1105,7 @@ inline bool write_to_output_c4_fp32(const float* din,
din_hei_ptr += w4;
}
} else {
for (; j < width; ++j) {
for (; j < remain; ++j) {
*(doutc0_ptr++) = din_hei_ptr[0];
*(doutc1_ptr++) = din_hei_ptr[1];
*(doutc2_ptr++) = din_hei_ptr[2];
......
......@@ -314,7 +314,23 @@ void fill_bias_int8(int* tensor,
const int* bias,
int channel,
int channel_size);
// new winograd
void weight_trans_c4(
float* dest, const float* src, int ic, int oc, void* workspace);
void conv_compute_6x6_3x3(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -37,6 +37,16 @@ void sgemm_prepack_c4(int M,
bool has_bias,
bool has_relu,
ARMContext* ctx);
void sgemm_prepack_c4_small(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
const float* bias,
bool has_bias,
bool has_relu,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -11,6 +11,7 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
......@@ -22,6 +23,7 @@ set (
cuda_type_trans
cuda_transpose
cuda_elementwise
cudnn_pool
cuda_gemm
cuda_batched_gemm
)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/cudnn_pool.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
inline void UpdatePadding(std::vector<int>* paddings,
const bool global_pooling,
const bool adaptive,
const std::vector<int>& data_dims,
const std::vector<int>& strides,
const std::vector<int>& ksize) {
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
CHECK(data_dims.size() * 2 == paddings->size())
<< "Paddings size should be the same or twice as the pooling size.";
}
if (global_pooling || adaptive) {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
inline void UpdateKsize(std::vector<int>* ksize,
const std::vector<int>& data_dims) {
ksize->resize(static_cast<size_t>(data_dims.size()));
for (size_t i = 0; i < ksize->size(); ++i) {
*(ksize->begin() + i) = static_cast<int>(data_dims[i]);
}
}
template <>
bool CudnnPool2DNHWC<PRECISION(kFloat)>::create(
const operators::PoolParam& param, Context<TARGET(kCUDA)>* ctx) {
return true;
}
template <>
bool CudnnPool2DNHWC<PRECISION(kFloat)>::init(const operators::PoolParam& param,
Context<TARGET(kCUDA)>* ctx) {
this->stream_ = ctx->exec_stream();
CUDNN_CHECK(cudnnCreate(&this->handle_));
CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));
cudnnCreateTensorDescriptor(&this->input_desc_);
cudnnCreateTensorDescriptor(&this->output_desc_);
cudnnCreatePoolingDescriptor(&this->pooling_desc_);
return create(param, ctx);
}
template <>
bool CudnnPool2DNHWC<PRECISION(kFloat)>::run(
const operators::PoolParam& param) {
auto x_dims = param.x->dims();
auto o_dims = param.output->dims();
int batch = x_dims[0];
const float* in_data = param.x->data<float>();
float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
int ih = x_dims[1];
int iw = x_dims[2]; // nchw
int ic = x_dims[3];
int oh = o_dims[1];
int ow = o_dims[2];
int oc = o_dims[3];
std::vector<int> ksize = param.ksize;
std::vector<int> strides = param.strides;
std::vector<int> paddings = *(param.paddings.get());
std::string pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
std::vector<int> data_dims = {ih, iw};
UpdatePadding(&paddings, global_pooling, adaptive, data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_FLOAT,
batch,
ic,
ih,
iw));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
CUDNN_TENSOR_NHWC,
CUDNN_DATA_FLOAT,
batch,
oc,
oh,
ow));
cudnnPoolingMode_t mode;
if (pooling_type == "max") {
mode = CUDNN_POOLING_MAX;
} else {
mode = exclusive ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
}
CUDNN_CHECK(cudnnSetPoolingNdDescriptor(this->pooling_desc_,
mode,
CUDNN_NOT_PROPAGATE_NAN,
ksize.size(),
ksize.data(),
paddings.data(),
strides.data()));
float alpha = 1.0f;
float beta = 0.0f;
CUDNN_CHECK(cudnnPoolingForward(this->handle_,
this->pooling_desc_,
&alpha,
this->input_desc_,
in_data,
&beta,
this->output_desc_,
out_data));
return true;
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <PrecisionType Ptype_out>
class CudnnPool2DBase {
public:
CudnnPool2DBase()
: handle_(NULL),
input_desc_(NULL),
output_desc_(NULL),
pooling_desc_(NULL) {}
~CudnnPool2DBase() {
if (handle_ != NULL) {
CUDNN_CHECK(cudnnDestroy(handle_));
}
if (input_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
}
if (output_desc_) {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
}
if (pooling_desc_) {
cudnnDestroyPoolingDescriptor(pooling_desc_);
}
}
protected:
cudaStream_t stream_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t input_desc_;
cudnnTensorDescriptor_t output_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
};
template <PrecisionType Ptype_out>
class CudnnPool2DNHWC : public CudnnPool2DBase<Ptype_out> {
public:
CudnnPool2DNHWC() : CudnnPool2DBase<Ptype_out>() {}
virtual ~CudnnPool2DNHWC() = default;
virtual bool init(const operators::PoolParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool create(const operators::PoolParam& param,
Context<TARGET(kCUDA)>* ctx);
virtual bool run(const operators::PoolParam& param);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -13,13 +13,55 @@
// limitations under the License.
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
__global__ void elementwise_kernel(const size_t total,
const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total) {
int idx = tid / post % n;
#if __CUDA_ARCH__ >= 350
out_data[tid] = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
#else
out_data[tid] = binary_calc(x_data[tid], y_data[idx], type);
#endif
}
}
template <typename Dtype>
__global__ void elementwise_relu_kernel(const size_t total,
const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < total) {
int idx = tid / post % n;
Dtype temp;
#if __CUDA_ARCH__ >= 350
temp = binary_calc(__ldg(x_data + tid), __ldg(y_data + idx), type);
#else
temp = binary_calc(x_data[tid], y_data[idx], type);
#endif
out_data[tid] = temp > 0 ? temp : 0;
}
}
template <typename Dtype>
__global__ void elementwise_add_kernel(const size_t total,
const Dtype* x_data,
......@@ -76,6 +118,56 @@ __global__ void elementwise_add_nhwc4_int8_kernel(const size_t total,
}
}
template <typename Dtype>
void elementwise(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream) {
int num = pre * n * post;
int thread = 256;
int block = (num + thread - 1) / thread;
elementwise_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
}
template <typename Dtype>
void elementwise_relu(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream) {
int num = pre * n * post;
int thread = 256;
int block = (num + thread - 1) / thread;
elementwise_relu_kernel<<<block, thread, 0, stream>>>(
num, x_data, y_data, out_data, pre, n, post, type);
}
template void elementwise(const float*,
const float*,
float*,
int,
int,
int,
BinaryOperation,
cudaStream_t);
template void elementwise_relu(const float*,
const float*,
float*,
int,
int,
int,
BinaryOperation,
cudaStream_t);
template <typename Dtype>
void elementwise_add(int num,
const Dtype* x_data,
......
......@@ -15,12 +15,33 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
void elementwise(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream);
template <typename Dtype>
void elementwise_relu(const Dtype* x_data,
const Dtype* y_data,
Dtype* out_data,
int pre,
int n,
int post,
BinaryOperation type,
cudaStream_t stream);
template <typename Dtype>
void elementwise_add(int num,
const Dtype* x_data,
......
......@@ -25,6 +25,24 @@ namespace lite {
namespace cuda {
namespace math {
enum class BinaryOperation {
kADD = 0,
kMUL = 1,
kDIV = 2,
};
template <typename T>
__device__ T binary_calc(T x, T y, BinaryOperation type);
template <>
__device__ __forceinline__ float binary_calc(float x,
float y,
BinaryOperation type) {
if (type == BinaryOperation::kADD) return x + y;
if (type == BinaryOperation::kMUL) return x * y;
if (type == BinaryOperation::kDIV) return x / y;
}
template <typename T>
__device__ T from_float(float x);
......
......@@ -142,21 +142,25 @@ ge::TensorPtr CvtTensor(lite::Tensor* in_tensor,
int CvtActMode(std::string act_type) {
int act_mode = 1;
if (act_type == "sigmod") {
if (act_type == "sigmoid") {
act_mode = 0;
} else if (act_type == "relu") {
act_mode = 1;
} else if (act_type == "tanh") {
act_mode = 2;
} else if (act_type == "relu_clipped") {
act_mode = 3;
} else if (act_type == "elu") {
act_mode = 4;
} else if (act_type == "leaky_relu") {
act_mode = 5;
} else if (act_type == "abs") {
act_mode = 6;
} else if (act_type == "softsign") {
act_mode = 8;
} else if (act_type == "softplus") {
act_mode = 9;
} else if (act_type == "hardsigmoid") {
} else if (act_type == "hard_sigmoid") {
act_mode = 10;
} else {
// TODO(hong19860320) support more activation mode
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "lite/backends/x86/math/beam_search.h"
#include <algorithm>
#include <cmath>
#include <map>
#include "lite/fluid/lod.h"
......
......@@ -99,7 +99,7 @@ add_custom_target(all_kernel_faked_cc DEPENDS all_kernel_faked.cc)
#----------------------------------------------- NOT CHANGE -----------------------------------------------
lite_cc_library(kernel SRCS kernel.cc
DEPS context type_system target_wrapper any op_params tensor
PROFILE_DEPS basic_profiler
PROFILE_DEPS lite_profiler
)
lite_cc_library(op SRCS op_lite.cc DEPS scope op_registry target_wrapper kernel
cpp_op_desc tensor
......@@ -113,7 +113,7 @@ lite_cc_library(type_system SRCS type_system.cc DEPS tensor target_wrapper)
lite_cc_library(program SRCS program.cc
DEPS op kernel model_parser ${ops} ${cpp_wrapper}
PROFILE_DEPS basic_profiler)
PROFILE_DEPS lite_profiler)
if (NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(optimizer SRCS optimizer.cc DEPS mir_pass_manager model_parser program)
......
......@@ -37,6 +37,9 @@ void TestCase::CreateInstruction() {
// prepare context
(*it)->SetContext(std::move(ctx_));
instruction_.reset(new Instruction(op, std::move(*it)));
#ifdef LITE_WITH_PROFILE
instruction_->set_profiler(new profile::Profiler());
#endif
}
void TestCase::PrepareInputsForInstruction() {
......
......@@ -31,7 +31,7 @@
#include "lite/utils/replace_stl/stream.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#include "lite/core/profile/profiler.h"
#endif // LITE_WITH_PROFILE
namespace paddle {
......@@ -58,7 +58,10 @@ class KernelBase {
virtual void Run() = 0;
#ifdef LITE_WITH_PROFILE
void SetProfileID(uint32_t id) { profile_id_ = id; }
void SetProfiler(profile::Profiler* profiler, int id) {
profiler_ = profiler;
profile_id_ = id;
}
#endif
void Launch() {
......@@ -82,10 +85,12 @@ class KernelBase {
#endif
#ifdef LITE_WITH_PROFILE
if (profile_id_ >= 0) {
profile::ProfileBlock x(profile_id_, "kernel");
Run();
}
CHECK(profiler_) << "Profiler pointer of kernel can not be nullptr. "
"When LITE_WITH_PROFILE is defined, please set a "
"Profiler for Instruction.";
profiler_->StartTiming(profile_id_, ctx_.get());
Run();
profiler_->StopTiming(profile_id_, ctx_.get());
#else
Run();
#endif
......@@ -175,6 +180,7 @@ class KernelBase {
bool is_first_epoch_{true};
#ifdef LITE_WITH_PROFILE
profile::Profiler* profiler_{nullptr};
int profile_id_{-1};
#endif
};
......
......@@ -396,6 +396,8 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc->SetAttr<float>("input_scale", scale_value);
op_desc->SetInput("X", {input_act_node->arg()->name});
IR_NODE_LINK_TO(input_act_node, quantized_node)
auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_node,
......@@ -440,6 +442,8 @@ void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc->SetInput("Y", {input_act_right_node->arg()->name});
IR_NODE_LINK_TO(input_act_left_node, quantized_node)
IR_NODE_LINK_TO(input_act_right_node, quantized_node)
auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
......
......@@ -49,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< instruct.op_type();
VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
VLOG(4) << "kernel->summary():" << kernel->summary()
<< " score:" << score;
scored.emplace_back(score, std::move(kernel));
......@@ -99,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.ResetOp(update_desc, graph->valid_places());
scored.clear();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
scored.emplace_back(score, std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
......
......@@ -48,7 +48,8 @@ class StaticKernelPickPass : public mir::StmtPass {
private:
// Score the kernel.
size_t KernelGrade(const lite::KernelBase& kernel,
size_t KernelGrade(const lite::mir::Node::Stmt& instruct,
const lite::KernelBase& kernel,
const std::vector<Place>& places) {
CHECK_GT(places.size(), 0) << "valid_places is empty.";
float final_score{-1.};
......@@ -66,7 +67,7 @@ class StaticKernelPickPass : public mir::StmtPass {
// valid_places.size() as default.
// where i is the place's index in valid_places array.
// score: score is the weighted sum of target, precision and layout
for (int i = 0; i < place_size; ++i) {
for (size_t i = 0; i < place_size; ++i) {
const auto& place = places[i];
float weight = static_cast<float>(place_size - i) / place_size;
size_t score{};
......@@ -83,8 +84,12 @@ class StaticKernelPickPass : public mir::StmtPass {
(place.precision == kernel.precision() ||
kernel.precision() == PRECISION(kAny) ||
place.precision == PRECISION(kAny))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst);
// score skipped, if kernel is int8, but op is not int8
if (!(kernel.precision() == PRECISION(kInt8) &&
!instruct.op_info()->HasAttr("enable_int8"))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::PrecisionFirst);
}
}
VLOG(4) << "[score s2]:" << score;
if (kernel_pick_factors_.IsDataLayoutConsidered() &&
......
......@@ -160,8 +160,8 @@ TEST(NPUSubgraph, compare) {
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kARM), PRECISION(kFloat)},
lite_api::Place{TARGET(kNPU), PRECISION(kFloat)}},
{lite_api::Place{TARGET(kNPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/NPU");
// verify results
......
......@@ -115,6 +115,8 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kCUDA, kAny, kAny);
INIT_FOR(kCUDA, kInt8, kNHWC);
INIT_FOR(kCUDA, kInt64, kNCHW);
INIT_FOR(kCUDA, kInt64, kNHWC);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
......
......@@ -73,7 +73,7 @@ class Optimizer {
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA)
"lite_elementwise_add_activation_fuse_pass", //
#endif
"static_kernel_pick_pass", // pick original kernel from graph
......
......@@ -5,4 +5,5 @@ endif()
lite_cc_library(basic_profiler SRCS basic_profiler.cc DEPS gflags)
lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler)
lite_cc_library(lite_profiler SRCS profiler.cc DEPS context)
lite_cc_test(test_lite_timer SRCS test_timer.cc DEPS lite_profiler)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/profile/profiler.h"
#include <map>
#include <string>
#include <utility>
namespace paddle {
namespace lite {
namespace profile {
int Profiler::NewTimer(const OpCharacter& ch) {
StatisUnit unit;
unit.character = ch;
if (ch.target == TargetType::kCUDA) {
#ifdef LITE_WITH_CUDA
unit.timer.reset(new DeviceTimer<TargetType::kCUDA>());
#else
LOG(ERROR) << "The timer type specified as cuda is uninitialized, so the "
"default x86 timer is used instead.";
#endif
} else {
unit.timer.reset(new DeviceTimer<TargetType::kHost>());
}
units_.push_back(std::move(unit));
return units_.size() - 1;
}
void Profiler::StartTiming(const int index, KernelContext* ctx) {
CHECK_LT(index, units_.size())
<< "The timer index in the profiler is out of range.";
units_[index].timer->Start(ctx);
}
float Profiler::StopTiming(const int index, KernelContext* ctx) {
CHECK_LT(index, units_.size())
<< "The timer index in the profiler is out of range.";
return units_[index].timer->Stop(ctx);
}
std::string Profiler::Summary(bool concise) {
STL::stringstream ss;
auto cout_title = [&ss](const std::string& title, const std::string& name) {
// clang-format off
ss << "===== " << title << ": " << name << " =====" << std::endl;
ss << std::setw(25) << std::left << "Operator Type" \
<< std::setw(40) << std::left << "Kernel Name" \
<< std::setw(10) << std::left << "Remark" \
<< std::setw(10) << std::left << "Avg (ms)" \
<< std::setw(10) << std::left << "Min (ms)" \
<< std::setw(10) << std::left << "Max (ms)" \
<< std::endl;
// clang-format on
};
if (concise) {
auto op_comp = [](const OpCharacter& c1, const OpCharacter& c2) {
return (c1.target < c2.target) || (c1.op_type < c2.op_type) ||
(c1.kernel_name < c2.kernel_name) || (c1.remark < c2.remark);
};
std::map<OpCharacter, TimeInfo, decltype(op_comp)> summary(op_comp);
for (auto& unit : units_) {
auto ch = summary.find(unit.character);
if (ch != summary.end()) {
ch->second.avg += unit.timer->LapTimes().Avg();
ch->second.min += unit.timer->LapTimes().Min();
ch->second.max += unit.timer->LapTimes().Max();
} else {
TimeInfo info({unit.timer->LapTimes().Avg(),
unit.timer->LapTimes().Min(),
unit.timer->LapTimes().Max()});
summary.insert({unit.character, info});
}
}
cout_title("Concise Profiler Summary", name_);
for (const auto& item : summary) {
// clang-format off
ss << std::setw(25) << std::left << item.first.op_type \
<< std::setw(40) << std::left << item.first.kernel_name \
<< std::setw(10) << std::left << item.first.remark \
<< std::setw(10) << std::left << item.second.avg \
<< std::setw(10) << std::left << item.second.min \
<< std::setw(10) << std::left << item.second.max \
<< std::endl;
// clang-format on
}
} else {
cout_title("Detailed Profiler Summary", name_);
for (auto& unit : units_) {
// clang-format off
ss << std::setw(25) << std::left << unit.character.op_type \
<< std::setw(40) << std::left << unit.character.kernel_name \
<< std::setw(10) << std::left << unit.character.remark \
<< std::setw(10) << std::left << unit.timer->LapTimes().Avg() \
<< std::setw(10) << std::left << unit.timer->LapTimes().Min() \
<< std::setw(10) << std::left << unit.timer->LapTimes().Max() \
<< std::endl;
// clang-format on
}
}
return ss.str();
}
} // namespace profile
} // namespace lite
} // namespace paddle
......@@ -11,42 +11,49 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/relu_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/core/profile/timer.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
namespace profile {
template <typename T>
class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
struct TimeInfo {
float avg;
float min;
float max;
};
struct OpCharacter {
TargetType target;
std::string op_type{std::string("N/A")};
std::string kernel_name{std::string("N/A")};
std::string remark{std::string("N/A")};
};
struct StatisUnit {
std::unique_ptr<Timer> timer;
OpCharacter character;
};
class Profiler final {
public:
using param_t = operators::ActivationParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto n = param.X->dims().production();
const float* input = param.X->data<float>();
float* output = param.Out->mutable_data<float>();
for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]);
}
}
virtual ~ReluCompute() = default;
Profiler() = default;
explicit Profiler(const std::string& name) : name_(name) {}
int NewTimer(const OpCharacter& ch);
void StartTiming(const int index, KernelContext* ctx);
float StopTiming(const int index, KernelContext* ctx);
std::string Summary(bool concise = true);
private:
std::string name_{std::string("N/A")};
std::vector<StatisUnit> units_;
};
} // namespace x86
} // namespace kernels
} // namespace profile
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "lite/core/context.h"
#include "lite/core/profile/profiler.h"
#include "lite/core/profile/timer.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace profile {
TEST(timer, real_latency) {
Timer timer;
timer.Start();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
timer.Stop();
timer.Start();
std::this_thread::sleep_for(std::chrono::milliseconds(50));
timer.Stop();
LOG(INFO) << "LapTimes().Avg() = " << timer.LapTimes().Avg();
}
#ifdef LITE_WITH_CUDA
TEST(gpu_timer, real_latency) {
DeviceTimer<TargetType::kCUDA> timer;
KernelContext ctx;
cudaStream_t exec_stream;
cudaStreamCreate(&exec_stream);
(&ctx.As<CUDAContext>())->SetExecStream(exec_stream);
timer.Start(&ctx);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
timer.Stop(&ctx);
(&timer)->Start(&ctx);
std::this_thread::sleep_for(std::chrono::milliseconds(50));
timer.Stop(&ctx);
LOG(INFO) << "LapTimes().Avg() = " << timer.LapTimes().Avg();
}
TEST(profiler, real_latency) {
KernelContext ctx;
cudaStream_t exec_stream;
cudaStreamCreate(&exec_stream);
(&ctx.As<CUDAContext>())->SetExecStream(exec_stream);
Profiler profiler("name");
profile::OpCharacter ch;
ch.target = TargetType::kCUDA;
ch.op_type = "operator/1";
ch.kernel_name = "kernel/1";
int idx = profiler.NewTimer(ch);
profiler.StartTiming(idx, &ctx);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
profiler.StopTiming(idx, &ctx);
std::cout << profiler.Summary();
}
#endif
} // namespace profile
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <chrono> // NOLINT
#include <list>
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/cuda_utils.h"
#endif
#include "lite/core/context.h"
namespace paddle {
namespace lite {
namespace profile {
template <typename T>
class TimeList {
public:
void Clear() { laps_t_.clear(); }
void Add(T t) { laps_t_.push_back(t); }
T Max() const { return *std::max_element(laps_t_.begin(), laps_t_.end()); }
T Min() const { return *std::min_element(laps_t_.begin(), laps_t_.end()); }
T Sum() const { return std::accumulate(laps_t_.begin(), laps_t_.end(), 0.0); }
size_t Size() const { return laps_t_.size(); }
T Avg() const {
if (!Size()) {
return 0;
}
return Sum() / Size();
}
const std::list<T>& Raw() const { return laps_t_; }
private:
std::list<T> laps_t_;
};
class Timer {
public:
Timer() = default;
virtual ~Timer() = default;
void Reset() { laps_t_.Clear(); }
void Start() { t_start_ = std::chrono::system_clock::now(); }
float Stop() {
t_stop_ = std::chrono::system_clock::now();
auto ts = std::chrono::duration_cast<std::chrono::microseconds>(t_stop_ -
t_start_);
float elapse_ms = 1000.f * static_cast<float>(ts.count()) *
std::chrono::microseconds::period::num /
std::chrono::microseconds::period::den;
this->laps_t_.Add(elapse_ms);
return elapse_ms;
}
virtual void Start(KernelContext* ctx) { return Start(); }
virtual float Stop(KernelContext* ctx) { return Stop(); }
float AvgLapTimeMs() const { return laps_t_.Avg(); }
const TimeList<float>& LapTimes() const { return laps_t_; }
protected:
std::chrono::time_point<std::chrono::system_clock> t_start_, t_stop_;
TimeList<float> laps_t_;
};
template <TargetType Target>
class DeviceTimer final : public Timer {};
#ifdef LITE_WITH_CUDA
template <>
class DeviceTimer<TargetType::kCUDA> final : public Timer {
public:
DeviceTimer() {
CUDA_CALL(cudaEventCreate(&e_start_));
CUDA_CALL(cudaEventCreate(&e_stop_));
}
~DeviceTimer() {
CUDA_CALL(cudaEventDestroy(e_start_));
CUDA_CALL(cudaEventDestroy(e_stop_));
}
void Start(KernelContext* ctx) {
cudaStream_t stream;
stream = ctx->As<CUDAContext>().exec_stream();
CUDA_CALL(cudaEventRecord(e_start_, stream));
}
float Stop(KernelContext* ctx) {
cudaStream_t stream;
stream = ctx->As<CUDAContext>().exec_stream();
CUDA_CALL(cudaEventRecord(e_stop_, stream));
CUDA_CALL(cudaEventSynchronize(e_stop_));
float elapse_ms = 1.f;
CUDA_CALL(cudaEventElapsedTime(&elapse_ms, e_start_, e_stop_));
this->laps_t_.Add(elapse_ms);
return elapse_ms;
}
private:
cudaEvent_t e_start_, e_stop_;
};
#endif
} // namespace profile
} // namespace lite
} // namespace paddle
......@@ -122,6 +122,9 @@ void RuntimeProgram::Run() {
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
}
#ifdef LITE_WITH_PROFILE
LOG(INFO) << "\n" << profiler_.Summary();
#endif // LITE_WITH_PROFILE
}
void Program::Build(const cpp::ProgramDesc& prog) {
......@@ -183,11 +186,6 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
void Instruction::Run() {
CHECK(op_) << "op null";
CHECK(kernel_) << "kernel null";
#ifdef LITE_WITH_PROFILE
if (profile_id_ >= 0) {
profile::ProfileBlock x(profile_id_, "instruction");
}
#endif // LITE_WITH_PROFILE
if (first_epoch_) {
first_epoch_ = false;
CHECK(op_->CheckShape());
......
......@@ -22,9 +22,6 @@
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/model_parser/cpp/program_desc.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
namespace paddle {
namespace lite {
......@@ -87,22 +84,7 @@ struct Program {
struct Instruction {
Instruction(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
: op_(op), kernel_(std::move(kernel)) {
#ifdef LITE_WITH_PROFILE
if (op_->Type() != "feed" && op_->Type() != "fetch") {
profile_id_ = profile::BasicProfiler<profile::BasicTimer>::Global()
.NewRcd(kernel_->SerializedKernelType())
.id();
kernel_->SetProfileID(profile_id_);
// Set profile custom info
auto& profiler =
*profile::BasicProfiler<profile::BasicTimer>::Global().mutable_record(
profile_id_);
profiler.SetCustomInfo("op_type", op_->Type());
profiler.SetCustomInfo("op_info", op_->SerializedOpInfo());
}
#endif // LITE_WITH_PROFILE
}
: op_(op), kernel_(std::move(kernel)) {}
// Run the instruction.
void Run();
......@@ -113,6 +95,20 @@ struct Instruction {
const KernelBase* kernel() const { return kernel_.get(); }
KernelBase* mutable_kernel() { return kernel_.get(); }
#ifdef LITE_WITH_PROFILE
void set_profiler(profile::Profiler* profiler) {
profiler_ = profiler;
if (op_->Type() != "feed" && op_->Type() != "fetch") {
profile::OpCharacter ch;
ch.target = kernel()->target();
ch.op_type = op_->Type();
ch.kernel_name = kernel()->name();
profile_id_ = profiler->NewTimer(ch);
kernel_->SetProfiler(profiler_, profile_id_);
}
}
#endif
private:
std::shared_ptr<OpLite> op_;
std::unique_ptr<KernelBase> kernel_;
......@@ -120,7 +116,7 @@ struct Instruction {
bool has_run_{false};
#ifdef LITE_WITH_PROFILE
// for profiler
profile::Profiler* profiler_;
int profile_id_{-1};
#endif // LITE_WITH_PROFILE
};
......@@ -135,6 +131,9 @@ class LITE_API RuntimeProgram {
if (instructions_.empty()) {
LOG(FATAL) << "no instructions";
}
#ifdef LITE_WITH_PROFILE
set_profiler();
#endif
}
void Run();
......@@ -159,6 +158,15 @@ class LITE_API RuntimeProgram {
RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruction> instructions_;
lite::Scope* exec_scope_{};
#ifdef LITE_WITH_PROFILE
profile::Profiler profiler_;
void set_profiler() {
for (auto i = instructions_.begin(); i != instructions_.end(); ++i) {
i->set_profiler(&profiler_);
}
}
#endif
};
} // namespace lite
......
CXX_DEFINES = -DARM_WITH_OMP -DHPPL_STUB_FUNC -DLITE_WITH_ARM -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK \
-DLITE_WITH_LINUX -DPADDLE_DISABLE_PROFILER -DPADDLE_NO_PYTHON -DPADDLE_WITH_TESTING
LDFLAGS = -latomic -pthread -ldl -llog
LDFLAGS = -latomic -pthread -ldl -llog -lz
SYSROOT_COMPLILE = --sysroot=/opt/android-ndk-r17c/sysroot
THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a
SYSTEM_INCLUDES = -I/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/include \
-I/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++abi/include \
-I/opt/android-ndk-r17c/sources/android/support/include \
-I/opt/android-ndk-r17c/sysroot/usr/include \
THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include
ifeq ($(ARM_ABI), arm8)
CC = /opt/android-ndk-r17c/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/bin/aarch64-linux-android-g++
CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE
CXX_FLAGS = -funwind-tables -no-canonical-prefixes -D__ANDROID_API__=23 -fexceptions -frtti -std=c++11 -fopenmp -O3 -DNDEBUG -fPIE
CXXFLAGS_LINK = $(CXX_FLAGS) -pie -Wl,--gc-sections
SYSROOT_LINK = --sysroot=/opt/android-ndk-r17c/platforms/android-24/arch-arm64
SYSTEM_LIBS = /opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++_static.a \
/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/arm64-v8a/libc++abi.a
INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/aarch64-linux-android $(THIRD_PARTY_INCLUDES)
INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/aarch64-linux-android
else
CC = /opt/android-ndk-r17c/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin/arm-linux-androideabi-g++
CXX_FLAGS = -march=armv7-a -mthumb -mfpu=neon -mfloat-abi=softfp -funwind-tables -no-canonical-prefixes \
......@@ -31,5 +27,5 @@ else
/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libc++abi.a \
/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libandroid_support.a \
/opt/android-ndk-r17c/sources/cxx-stl/llvm-libc++/libs/armeabi-v7a/libunwind.a
INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/arm-linux-androideabi $(THIRD_PARTY_INCLUDES)
INCLUDES = $(SYSTEM_INCLUDES) -I/opt/android-ndk-r17c/sysroot/usr/include/arm-linux-androideabi
endif
# C++ Demo
1. Use `lite/tools/Dockerfile.mobile` to build the docker image
2. Run and enter the docker image environment, then execute `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/r0.1/inference_lite_lib.android.armv8.tar.gz` to download the required demo environment. (The armv7 demo can be downloaded with `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/r0.1/inference_lite_lib.android.armv7.tar.gz`.)
2. Run and enter the docker image environment, then execute `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/v2.1.0/inference_lite_lib.android.armv8.tar.gz` to download the required demo environment. (The armv7 demo can be downloaded with `wget http://paddle-inference-dist.bj.bcebos.com/lite_release/v2.1.0/inference_lite_lib.android.armv7.tar.gz`.)
3. Extract the downloaded archive: `tar zxvf inference_lite_lib.android.armv8.tar.gz`
4. Run the following commands to prepare the emulator environment
```shell
......@@ -27,8 +27,10 @@ tar zxvf mobilenet_v1.tar.gz
make
adb -s emulator-5554 push mobilenet_v1 /data/local/tmp/
adb -s emulator-5554 push mobilenetv1_full_api /data/local/tmp/
adb -s emulator-5554 push ../../../cxx/lib/libpaddle_full_api_shared.so /data/local/tmp/
adb -s emulator-5554 shell chmod +x /data/local/tmp/mobilenetv1_full_api
adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_full_api --model_dir=/data/local/tmp/mobilenet_v1 --optimized_model_dir=/data/local/tmp/mobilenet_v1.opt"
adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
/data/local/tmp/mobilenetv1_full_api --model_dir=/data/local/tmp/mobilenet_v1 --optimized_model_dir=/data/local/tmp/mobilenet_v1.opt"
```
On success, the predicted probabilities of the top 10 classes will be printed to the console.
......@@ -37,6 +39,24 @@ adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_full_api --model_dir=/da
cd ../mobile_light
make
adb -s emulator-5554 push mobilenetv1_light_api /data/local/tmp/
adb -s emulator-5554 push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
adb -s emulator-5554 shell chmod +x /data/local/tmp/mobilenetv1_light_api
adb -s emulator-5554 shell "/data/local/tmp/mobilenetv1_light_api --model_dir=/data/local/tmp/mobilenet_v1.opt"
adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
/data/local/tmp/mobilenetv1_light_api /data/local/tmp/mobilenet_v1.opt"
```
7. Build and run the object detection demo
```shell
cd ../mobile_detection
wget https://paddle-inference-dist.bj.bcebos.com/mobilenetv1-ssd.tar.gz
tar zxvf mobilenetv1-ssd.tar.gz
make
adb -s emulator-5554 push mobile_detection /data/local/tmp/
adb -s emulator-5554 push test.jpg /data/local/tmp/
adb -s emulator-5554 push ../../../cxx/lib/libpaddle_light_api_shared.so /data/local/tmp/
adb -s emulator-5554 shell chmod +x /data/local/tmp/mobile_detection
adb -s emulator-5554 shell "export LD_LIBRARY_PATH=/data/local/tmp/:$LD_LIBRARY_PATH &&
/data/local/tmp/mobile_detection /data/local/tmp/mobilenetv1-ssd /data/local/tmp/test.jpg"
adb -s emulator-5554 pull /data/local/tmp/test_detection_result.jpg ./
```
On success, the generated detection result image test_detection_result.jpg will appear in the mobile_detection directory.
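
For reference, the light-API demos above are thin wrappers around the `MobileConfig` light API. Below is a minimal sketch of such a program, assuming the v2.x C++ light API from `paddle_api.h`; the input shape and score printing are illustrative and not the exact demo source:
```cpp
#include <iostream>
#include "paddle_api.h"  // NOLINT

using namespace paddle::lite_api;  // NOLINT

int main(int argc, char** argv) {
  // argv[1]: directory of the optimized model, e.g. /data/local/tmp/mobilenet_v1.opt
  MobileConfig config;
  config.set_model_dir(argv[1]);

  // Create the light predictor and feed a dummy 1x3x224x224 input (illustrative shape).
  auto predictor = CreatePaddlePredictor<MobileConfig>(config);
  auto input = predictor->GetInput(0);
  input->Resize({1, 3, 224, 224});
  auto* in_data = input->mutable_data<float>();
  for (int i = 0; i < 1 * 3 * 224 * 224; ++i) in_data[i] = 1.0f;

  predictor->Run();

  // Print the first 10 scores of output 0.
  auto output = predictor->GetOutput(0);
  for (int i = 0; i < 10; ++i) {
    std::cout << "score[" << i << "]: " << output->data<float>()[i] << std::endl;
  }
  return 0;
}
```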
ARM_ABI = arm7
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
THIRD_PARTY_DIR=${LITE_ROOT}/third_party
OPENCV_VERSION=opencv4.1.0
OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgcodecs.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_imgproc.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/libs/libopencv_core.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtegra_hal.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjpeg-turbo.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibwebp.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibpng.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibjasper.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/liblibtiff.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libIlmImf.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libtbb.a \
../../../third_party/${OPENCV_VERSION}/armeabi-v7a/3rdparty/libs/libcpufeatures.a
OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of static libraries: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_light_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
mobile_detection: fetch_opencv mobile_detection.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobile_detection.o -o mobile_detection $(CXX_LIBS) $(LDFLAGS)
mobile_detection.o: mobile_detection.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobile_detection.o -c mobile_detection.cc
fetch_opencv:
@ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR}
@ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \
(echo "fetch opencv libs" && \
wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz)
@ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \
tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR}
.PHONY: clean
clean:
rm -f mobile_detection.o
rm -f mobile_detection
ARM_ABI = arm8
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
THIRD_PARTY_DIR=${LITE_ROOT}/third_party
OPENCV_VERSION=opencv4.1.0
OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a
OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of static libraries: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_light_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
mobile_detection: fetch_opencv mobile_detection.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobile_detection.o -o mobile_detection $(CXX_LIBS) $(LDFLAGS)
mobile_detection.o: mobile_detection.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o mobile_detection.o -c mobile_detection.cc
fetch_opencv:
@ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR}
@ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \
(echo "fetch opencv libs" && \
wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz)
@ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \
tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR}
.PHONY: clean
clean:
rm -f mobile_detection.o
rm -f mobile_detection
......@@ -5,28 +5,25 @@ include ../Makefile.def
LITE_ROOT=../../../
CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include
CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS)
THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a
CXX_INCLUDES = $(INCLUDES) ${THIRD_PARTY_INCLUDES} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of shared libraries: #
# `libpaddle_full_api_shared.so` #
# `libpaddle_light_api_shared.so` #
# How to use one of static libraries: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's static library. #
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_api_full_bundled.a`;
# 2. Undo comment below line and execute
# `export LD_LIBRARY_PATH=<libpaddle_full_api_shared.so dir>`
# in command line before `make`;
# 3. After `adb push` model and mobilenetv1_full_api files to
# android devices, execute
# `export LD_LIBRARY_PATH=<libpaddle_full_api_shared.so android dir>`
# and `mobilenetv1_full_api` in android `adb shell`;
# 4. Get executed result of `mobilenetv1_full_api` in android.
# CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
# 1. Comment above line using `libpaddle_full_api_shared.so`
# 2. Undo comment below line using `libpaddle_api_full_bundled.a`
#CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS)
mobilenetv1_full_api: mobilenetv1_full_api.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_full_api.o -o mobilenetv1_full_api $(CXX_LIBS) $(LDFLAGS)
......
......@@ -5,28 +5,25 @@ include ../Makefile.def
LITE_ROOT=../../../
CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
THIRD_PARTY_INCLUDES = -I../../../third_party/gflags/include
CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS)
THIRD_PARTY_LIBS = ../../../third_party/gflags/lib/libgflags.a
CXX_INCLUDES = $(INCLUDES) ${THIRD_PARTY_INCLUDES} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of shared libraries: #
# `libpaddle_full_api_shared.so` #
# `libpaddle_light_api_shared.so` #
# How to use one of static library: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's static library. #
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_api_full_bundled.a`;
# 2. Undo comment below line and execute
# `export LD_LIBRARY_PATH=<libpaddle_full_api_shared.so dir>`
# in command line before `make`;
# 3. After `adb push` model and mobilenetv1_full_api files to
# android devices, execute
# `export LD_LIBRARY_PATH=<libpaddle_full_api_shared.so android dir>`
# and `mobilenetv1_full_api` in android `adb shell`;
# 4. Get executed result of `mobilenetv1_full_api` in android.
# CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS)
# 1. Comment out the line above that uses `libpaddle_full_api_shared.so`
# 2. Uncomment the line below to use `libpaddle_api_full_bundled.a`
#CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_full_bundled.a $(SYSTEM_LIBS)
mobilenetv1_full_api: mobilenetv1_full_api.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_full_api.o -o mobilenetv1_full_api $(CXX_LIBS) $(LDFLAGS)
......
......@@ -7,26 +7,19 @@ LITE_ROOT=../../../
CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
CXX_LIBS = -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of shared libraries: #
# `libpaddle_light_api_shared.so` #
# `libpaddle_full_api_shared.so` #
# How to use one of static library: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's static library. #
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_api_light_bundled.a`;
# 2. Undo comment below line and execute
# `export LD_LIBRARY_PATH=<libpaddle_light_api_shared.so dir>`
# in command line before `make`;
# 3. After `adb push` model and mobilenetv1_light_api files to
# android devices, execute
# `export LD_LIBRARY_PATH=<libpaddle_light_api_shared.so android dir>`
# and `mobilenetv1_light_api` in android `adb shell`;
# 4. Get executed result of `mobilenetv1_light_api` in android.
# CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
# 1. Comment out the line above that uses `libpaddle_light_api_shared.so`
# 2. Uncomment the line below to use `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
mobilenetv1_light_api: mobilenetv1_light_api.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_light_api.o -o mobilenetv1_light_api $(CXX_LIBS) $(LDFLAGS)
......
......@@ -7,26 +7,19 @@ LITE_ROOT=../../../
CXX_INCLUDES = $(INCLUDES) -I$(LITE_ROOT)/cxx/include
CXX_LIBS = $(THIRD_PARTY_LIBS) $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
CXX_LIBS = -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of shared library: #
# `libpaddle_light_api_shared.so` #
# `libpaddle_full_api_shared.so` #
# How to use one of static library: #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: default use lite's static library. #
# Note: default use lite's shared library. #
###############################################################
# 1. Comment above line using `libpaddle_api_light_bundled.a`;
# 2. Undo comment below line and execute
# `export LD_LIBRARY_PATH=<libpaddle_light_api_shared dir>`
# in command line before `make`;
# 3. After `adb push` model and mobilenetv1_light_api files to
# android devices, execute
# `export LD_LIBRARY_PATH=<libpaddle_light_api_shared android dir>` and
# `mobilenetv1_light_api` in android `adb shell`;
# 4. Get executed result of `mobilenetv1_light_api` in android.
# CXX_LIBS = $(THIRD_PARTY_LIBS) -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
# 1. Comment out the line above that uses `libpaddle_light_api_shared.so`
# 2. Uncomment the line below to use `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
mobilenetv1_light_api: mobilenetv1_light_api.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) mobilenetv1_light_api.o -o mobilenetv1_light_api $(CXX_LIBS) $(LDFLAGS)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
struct Object {
int batch_id;
cv::Rect rec;
int class_id;
float prob;
};
int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1;
for (auto i : shape) res *= i;
return res;
}
const char* class_names[] = {
"background", "aeroplane", "bicycle", "bird", "boat",
"bottle", "bus", "car", "cat", "chair",
"cow", "diningtable", "dog", "horse", "motorbike",
"person", "pottedplant", "sheep", "sofa", "train",
"tvmonitor"};
// fill the tensor with mean/scale normalization and transpose layout NHWC -> NCHW, with NEON speed-up
void neon_mean_scale(const float* din,
float* dout,
int size,
const std::vector<float> mean,
const std::vector<float> scale) {
if (mean.size() != 3 || scale.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(1.f / scale[0]);
float32x4_t vscale1 = vdupq_n_f32(1.f / scale[1]);
float32x4_t vscale2 = vdupq_n_f32(1.f / scale[2]);
float* dout_c0 = dout;
float* dout_c1 = dout + size;
float* dout_c2 = dout + size * 2;
int i = 0;
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - mean[0]) / scale[0];
*(dout_c1++) = (*(din++) - mean[1]) / scale[1];
*(dout_c2++) = (*(din++) - mean[2]) / scale[2];
}
}
void pre_process(const cv::Mat& img, int width, int height, float* data) {
cv::Mat rgb_img;
cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB);
cv::resize(rgb_img, rgb_img, cv::Size(width, height), 0.f, 0.f);
cv::Mat imgf;
rgb_img.convertTo(imgf, CV_32FC3, 1 / 255.f);
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {0.5f, 0.5f, 0.5f};
const float* dimg = reinterpret_cast<const float*>(imgf.data);
neon_mean_scale(dimg, data, width * height, mean, scale);
}
std::vector<Object> detect_object(const float* data,
int count,
float thresh,
cv::Mat& image) { // NOLINT
if (data == nullptr) {
std::cerr << "[ERROR] data can not be nullptr\n";
exit(1);
}
std::vector<Object> rect_out;
for (int iw = 0; iw < count; iw++) {
int oriw = image.cols;
int orih = image.rows;
if (data[1] > thresh && static_cast<int>(data[0]) > 0) {
Object obj;
int x = static_cast<int>(data[2] * oriw);
int y = static_cast<int>(data[3] * orih);
int w = static_cast<int>(data[4] * oriw) - x;
int h = static_cast<int>(data[5] * orih) - y;
cv::Rect rec_clip =
cv::Rect(x, y, w, h) & cv::Rect(0, 0, image.cols, image.rows);
obj.batch_id = 0;
obj.class_id = static_cast<int>(data[0]);
obj.prob = data[1];
obj.rec = rec_clip;
if (w > 0 && h > 0 && obj.prob <= 1) {
rect_out.push_back(obj);
cv::rectangle(image, rec_clip, cv::Scalar(0, 0, 255), 2, cv::LINE_AA);
std::string str_prob = std::to_string(obj.prob);
std::string text = std::string(class_names[obj.class_id]) + ": " +
str_prob.substr(0, str_prob.find(".") + 4);
int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL;
double font_scale = 1.f;
int thickness = 2;
cv::Size text_size =
cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
float new_font_scale = w * 0.35 * font_scale / text_size.width;
text_size = cv::getTextSize(
text, font_face, new_font_scale, thickness, nullptr);
cv::Point origin;
origin.x = x + 10;
origin.y = y + text_size.height + 10;
cv::putText(image,
text,
origin,
font_face,
new_font_scale,
cv::Scalar(0, 255, 255),
thickness,
cv::LINE_AA);
std::cout << "detection, image size: " << image.cols << ", "
<< image.rows
<< ", detect object: " << class_names[obj.class_id]
<< ", score: " << obj.prob << ", location: x=" << x
<< ", y=" << y << ", width=" << w << ", height=" << h
<< std::endl;
}
}
data += 6;
}
return rect_out;
}
void RunModel(std::string model_dir, std::string img_path) {
// 1. Set MobileConfig
MobileConfig config;
config.set_model_dir(model_dir);
// 2. Create PaddlePredictor by MobileConfig
std::shared_ptr<PaddlePredictor> predictor =
CreatePaddlePredictor<MobileConfig>(config);
// 3. Prepare input data from image
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
const int in_width = 300;
const int in_height = 300;
input_tensor->Resize({1, 3, in_height, in_width});
auto* data = input_tensor->mutable_data<float>();
cv::Mat img = imread(img_path, cv::IMREAD_COLOR);
pre_process(img, in_width, in_height, data);
// 4. Run predictor
predictor->Run();
// 5. Get output and post process
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
auto* outptr = output_tensor->data<float>();
auto shape_out = output_tensor->shape();
int64_t cnt = 1;
for (auto& i : shape_out) {
cnt *= i;
}
auto rec_out = detect_object(outptr, static_cast<int>(cnt / 6), 0.6f, img);
std::string result_name =
img_path.substr(0, img_path.rfind(".")) + "_detection_result.jpg";
cv::imwrite(result_name, img);
}
int main(int argc, char** argv) {
if (argc < 3) {
std::cerr << "[ERROR] usage: " << argv[0] << " model_dir image_path\n";
exit(1);
}
std::string model_dir = argv[1];
std::string img_path = argv[2];
RunModel(model_dir, img_path);
return 0;
}
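As a reading aid for the mean/scale normalization in neon_mean_scale and pre_process above, here is a minimal standalone sketch (not part of the demo) of what a single pixel goes through; mean = scale = 0.5 and the 1/255 rescale come straight from the demo code:
// Illustrative sketch only: a pixel p in [0, 255] is rescaled to [0, 1] by
// convertTo(), then normalized as (v - mean) / scale, mapping it to ~[-1, 1].
#include <cstdio>
#include <initializer_list>
int main() {
  const float mean = 0.5f, scale = 0.5f;
  for (float p : {0.f, 128.f, 255.f}) {
    float v = (p / 255.f - mean) / scale;
    std::printf("p = %5.1f -> %6.3f\n", p, v);  // 0 -> -1.000, 128 -> 0.004, 255 -> 1.000
  }
  return 0;
}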
......@@ -13,12 +13,10 @@
// limitations under the License.
#include <gflags/gflags.h>
#include <stdio.h>
#include <iostream>
#include <vector>
#include "paddle_api.h" // NOLINT
#include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT
#include "paddle_use_passes.h" // NOLINT
#include "paddle_api.h" // NOLINT
#include "paddle_use_passes.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
......@@ -32,27 +30,6 @@ int64_t ShapeProduction(const shape_t& shape) {
return res;
}
void CheckInput(char*** argv) {
if (FLAGS_model_dir == "") {
printf(
"Usage: %s --model_dir=<your-model-directory> "
"--optimized_model_dir=<your-optmized-model-directory> "
"--prefer_int8_kernel=[true|false]\n",
*argv[0]);
exit(1);
}
if (FLAGS_optimized_model_dir == "") {
FLAGS_optimized_model_dir = FLAGS_model_dir;
printf(
"[WARN] no `optimized_model_dir` provided. set `optimized_model_dir` "
":= `model_dir`:%s\n",
FLAGS_optimized_model_dir.c_str());
}
printf("[WARN] model_dir:%s\n", FLAGS_model_dir.c_str());
printf("[WARN] optimized_model_dir:%s\n", FLAGS_optimized_model_dir.c_str());
printf("[WARN] prefer_int8_kernel:%s\n", FLAGS_prefer_int8_kernel);
}
// 0. Enable OpenCL, if needed
// Enable the `DEMO_WITH_OPENCL` macro below if you need to run on the GPU (OpenCL)
// #define DEMO_WITH_OPENCL
......@@ -99,15 +76,22 @@ void RunModel() {
// 6. Get output
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
printf("Output dim: %d\n", output_tensor->shape()[1]);
std::cout << "Output shape " << output_tensor->shape()[1] << std::endl;
for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
std::cout << "Output[" << i << "]: " << output_tensor->data<float>()[i]
<< std::endl;
}
}
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
CheckInput(&argv);
if (FLAGS_model_dir == "" || FLAGS_optimized_model_dir == "") {
std::cerr << "[ERROR] usage: " << argv[0]
<< " --model_dir=<your-model-directory>"
<< " --optimized_model_dir=<your-optmized-model-directory> "
<< " --prefer_int8_kernel=[true|false]\n";
exit(1);
}
RunModel();
return 0;
}
......@@ -12,35 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <stdio.h>
#include <iostream>
#include <vector>
#include "paddle_api.h" // NOLINT
#include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT
#include "paddle_api.h" // NOLINT
using namespace paddle::lite_api; // NOLINT
DEFINE_string(model_dir, "", "Model dir path.");
int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1;
for (auto i : shape) res *= i;
return res;
}
void CheckInput(char*** argv) {
if (FLAGS_model_dir == "") {
printf("Usage: %s --model_dir=<your-nb-model-directory>\n", *argv[0]);
exit(1);
}
printf("[WARN] model_dir:%s\n", FLAGS_model_dir.c_str());
}
void RunModel() {
void RunModel(std::string model_dir) {
// 1. Set MobileConfig
MobileConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_model_dir(model_dir);
// 2. Create PaddlePredictor by MobileConfig
std::shared_ptr<PaddlePredictor> predictor =
......@@ -60,15 +47,19 @@ void RunModel() {
// 5. Get output
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
printf("Output dim: %d\n", output_tensor->shape()[1]);
std::cout << "Output shape " << output_tensor->shape()[1] << std::endl;
for (int i = 0; i < ShapeProduction(output_tensor->shape()); i += 100) {
printf("Output[%d]: %f\n", i, output_tensor->data<float>()[i]);
std::cout << "Output[" << i << "]: " << output_tensor->data<float>()[i]
<< std::endl;
}
}
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
CheckInput(&argv);
RunModel();
if (argc < 2) {
std::cerr << "[ERROR] usage: ./" << argv[0] << " naive_buffer_model_dir\n";
exit(1);
}
std::string model_dir = argv[1];
RunModel(model_dir);
return 0;
}
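For reference, the light-API flow that mobilenetv1_light_api.cc walks through (parts of it are elided by the hunk markers above) reduces to the sketch below; the model directory and input shape are placeholders, and only calls already used in this patch appear:
// Minimal MobileConfig (light API) sketch; placeholder model path and input
// shape, no error handling.
#include <iostream>
#include "paddle_api.h"  // NOLINT
using namespace paddle::lite_api;  // NOLINT
int main() {
  MobileConfig config;
  config.set_model_dir("mobilenet_v1_opt");  // placeholder naive-buffer model dir
  auto predictor = CreatePaddlePredictor<MobileConfig>(config);
  auto input = predictor->GetInput(0);
  input->Resize({1, 3, 224, 224});  // assumed input shape
  auto* in_data = input->mutable_data<float>();
  for (int i = 0; i < 3 * 224 * 224; ++i) in_data[i] = 1.f;
  predictor->Run();
  auto output = predictor->GetOutput(0);
  std::cout << "Output[0]: " << output->data<float>()[0] << std::endl;
  return 0;
}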
......@@ -105,7 +105,6 @@ lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_comput
lite_cc_test(test_concat_compute_arm SRCS concat_compute_test.cc DEPS concat_compute_arm)
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra)
lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm)
lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
if(LITE_BUILD_EXTRA)
lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm)
......
......@@ -56,6 +56,12 @@ void CastCompute::Run() {
float* out_data = param.Out->mutable_data<float>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<unsigned char, float>);
} else if (param.in_dtype == 3 && param.out_dtype == 2) {
const int64_t* x_data_begin = param.X->data<int64_t>();
const int64_t* x_data_end = x_data_begin + param.X->numel();
int32_t* out_data = param.Out->mutable_data<int32_t>();
std::transform(
x_data_begin, x_data_end, out_data, TransOp<int64_t, int32_t>);
} else {
LOG(FATAL) << "other has not been implemented";
}
......
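A note on the magic numbers in the new cast branch: in_dtype and out_dtype follow the framework's VarType::Type codes, so 3 -> 2 is an int64-to-int32 cast. The codes below are an assumption taken from the proto definition, not from this patch:
// Reading aid only: assumed VarType::Type codes (verify against framework.proto).
enum class VarDType : int {
  kBool = 0,
  kInt16 = 1,
  kInt32 = 2,  // out_dtype == 2 in the new branch above
  kInt64 = 3,  // in_dtype == 3 in the new branch above
  kFP16 = 4,
  kFP32 = 5,
};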
......@@ -40,6 +40,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int kw = w_dims[3];
int pad = paddings[0];
int stride = param.strides[0];
int threads = ctx.threads();
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
......@@ -67,7 +68,11 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
if (ic >= 32 && oc >= 32 && hout > 16 && wout > 16) {
bool use_winograd =
(threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
pads_equal) ||
(oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
if (use_winograd) {
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
......@@ -214,7 +219,7 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, ConvFp32, def)
REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
......@@ -223,7 +228,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out)
REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
......@@ -233,7 +238,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out)
REGISTER_LITE_KERNEL(
depthwise_conv2d, kARM, kInt8, kNCHW, ConvInt8_Int8, int8_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
......@@ -243,7 +248,7 @@ REGISTER_LITE_KERNEL(
REGISTER_LITE_KERNEL(
depthwise_conv2d, kARM, kInt8, kNCHW, ConvInt8_Fp32, fp32_out)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
......
......@@ -96,7 +96,8 @@ void Conv2DTransposeCompute::Run() {
int group_size_weights = ((m_roundup * k + 15) / 16) * 16;
bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) &&
(param.strides[1] == 1) && pads_all_qual &&
(dilations[0] == 1) && (dilations[1] == 1);
(paddings[0] == 0) && (dilations[0] == 1) &&
(dilations[1] == 1);
ctx.ExtendWorkspace(sizeof(float) * group * m * n);
auto din = param.x->data<float>();
......@@ -138,7 +139,9 @@ void Conv2DTransposeCompute::Run() {
kh,
kw,
paddings[0],
paddings[1],
paddings[2],
paddings[3],
param.strides[0],
param.strides[1],
dilations[0],
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conv_transpose_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename type, typename type2>
static void basic_gemm(int m,
int n,
int k,
const type* a,
const type* b,
const type2* bias,
type2* c,
type2 alpha,
type2 beta,
bool trans_a = false,
bool trans_b = false,
bool flag_bias = false,
bool flag_relu = false) {
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
type2 bias_data = (type2)0;
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
type2 sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * m + i];
} else {
av = a[i * k + l];
}
if (trans_b) {
bv = b[j * k + l];
} else {
bv = b[l * n + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * c[i * n + j] + bias_data;
if (flag_relu) {
c[i * n + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
c[i * n + j] = tmp;
}
}
}
}
//! for float, Dtype1 and Dtype2 are float
//! for int8, Dtype1 is char and Dtype2 is int
template <typename Dtype1, typename Dtype2>
bool deconv_basic(const Dtype1* din,
Dtype2* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const Dtype1* weights,
const Dtype2* bias,
int group,
int kernel_w,
int kernel_h,
int stride_w,
int stride_h,
int dila_w,
int dila_h,
int pad_w,
int pad_h,
bool flag_bias,
bool flag_relu) {
int m = chout * kernel_w * kernel_h / group;
int n = hin * win;
int k = chin / group;
if (chin != chout || group != chin) {
CHECK_OR_FALSE(chin % group == 0);
CHECK_OR_FALSE(chout % group == 0);
}
lite::Tensor workspace_tensor;
std::vector<int64_t> wt_shape = {1, 1, 1, group * m * n};
workspace_tensor.Resize(wt_shape);
auto* workspace_ptr = workspace_tensor.mutable_data<Dtype2>();
int group_size_in = win * hin * chin / group;
int group_size_out = wout * hout * chout / group;
int group_size_coldata = m * n;
int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group);
bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
(stride_w == 1) && (pad_w == 1) && (pad_h == 1) &&
(dila_w == 1) && (dila_h == 1);
for (int i = 0; i < num; ++i) {
const Dtype1* din_batch = din + i * chin * hin * win;
Dtype2* dout_batch = dout + i * chout * hout * wout;
Dtype2* col_data = workspace_ptr;
if (flag_1x1s1p1) {
col_data = dout_batch;
}
memset(col_data, 0, sizeof(Dtype2) * group_size_coldata);
for (int g = 0; g < group; ++g) {
const Dtype1* din_group = din_batch + g * group_size_in;
const Dtype1* weights_group = weights + g * group_size_weights;
Dtype2* coldata_group = col_data + g * group_size_coldata;
basic_gemm<Dtype1, Dtype2>(m,
n,
k,
weights_group,
din_group,
nullptr,
coldata_group,
(Dtype2)1,
(Dtype2)0,
true,
false,
false,
(!flag_bias && flag_relu));
}
if (!flag_1x1s1p1) {
lite::arm::math::col2im(col_data,
chout,
hout,
wout,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dila_h,
dila_w,
dout_batch);
}
if (flag_bias) {
lite::arm::math::fill_bias_relu(
dout_batch, bias, chout, wout * hout, flag_bias, flag_relu);
}
}
return true;
}
template <typename Dtype1, typename Dtype2>
void conv2d_transpose_compute_ref(const operators::ConvParam& param) {
const Dtype1* din = param.x->data<Dtype1>();
Dtype2* dout = param.output->mutable_data<Dtype2>();
int num = param.x->dims()[0];
int chout = param.output->dims()[1];
int hout = param.output->dims()[2];
int wout = param.output->dims()[3];
int chin = param.x->dims()[1];
int hin = param.x->dims()[2];
int win = param.x->dims()[3];
const Dtype1* weights = param.filter->mutable_data<Dtype1>();
Dtype2* bias = nullptr;
if (param.bias != nullptr) {
bias = param.bias->mutable_data<Dtype2>();
}
int group = param.groups;
auto paddings = *param.paddings;
auto dilations = *param.dilations;
int kernel_h = param.filter->dims()[2];
int kernel_w = param.filter->dims()[3];
int stride_h = param.strides[0];
int stride_w = param.strides[1];
int dila_h = dilations[0];
int dila_w = dilations[1];
int pad_h = paddings[0];
int pad_w = paddings[2];
bool flag_bias = (param.bias != nullptr);
bool flag_relu = param.fuse_relu;
deconv_basic<float, float>(din,
dout,
num,
chout,
hout,
wout,
chin,
hin,
win,
weights,
bias,
group,
kernel_w,
kernel_h,
stride_w,
stride_h,
dila_w,
dila_h,
pad_w,
pad_h,
flag_bias,
flag_relu);
}
TEST(conv2d_transpose_arm, retrive_op) {
auto op = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"conv2d_transpose");
ASSERT_FALSE(op.empty());
ASSERT_TRUE(op.front());
}
TEST(conv2d_transpose_arm, init) {
Conv2DTransposeCompute compute;
ASSERT_EQ(compute.precision(), PRECISION(kFloat));
ASSERT_EQ(compute.target(), TARGET(kARM));
}
TEST(conv2d_transpose_arm, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto ic : {1, 3 /*, 128*/}) {
for (auto oc : {1, 3 /*, 128*/}) {
for (auto ih : {2, 8 /*, 56 , 112, 224, 512*/}) {
for (auto iw : {2, 8 /*, 56, 112, 224, 512*/}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1, 2}) {
for (auto ks : {2, 3, 5}) {
for (auto group : {1, 2}) {
// obtain shape
if (ic % group != 0 || oc % group != 0) {
group = 1;
}
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {
oc / group, ic, ks, ks};
int oh = (ih - 1) * stride - 2 * padding +
dilation * (ks - 1) + 1;
int ow = (iw - 1) * stride - 2 * padding +
dilation * (ks - 1) + 1;
if (oh < 1 || ow < 1) {
break;
}
std::vector<int64_t> output_shape = {n, oc, oh, ow};
std::vector<int64_t> bias_shape = {1, oc, 1, 1};
// define and resize tensor
Tensor input;
Tensor filter;
Tensor filter_copy;
Tensor bias;
Tensor output;
Tensor output_ref;
input.Resize(input_shape);
filter.Resize(filter_shape);
filter_copy.Resize(filter_shape);
output.Resize(output_shape);
output_ref.Resize(output_shape);
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* filter_copy_data =
filter_copy.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
// initialize tensor
for (int i = 0; i < input.dims().production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] =
i /
static_cast<float>(filter.dims().production());
filter_copy_data[i] =
i / static_cast<float>(
filter_copy.dims().production());
}
if (flag_bias) {
bias.Resize(bias_shape);
auto* bias_data = bias.mutable_data<float>();
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i);
}
}
// prepare kernel params and run
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
Conv2DTransposeCompute conv2d_transpose;
conv2d_transpose.SetContext(std::move(ctx));
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
if (flag_bias) {
bias.Resize(bias_shape);
auto* bias_data = bias.mutable_data<float>();
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i);
}
param.bias = &bias;
}
param.fuse_relu = flag_relu;
std::vector<int> paddings = {
padding, padding, padding, padding};
param.strides = std::vector<int>({stride, stride});
std::vector<int> dilations = {dilation, dilation};
param.paddings =
std::make_shared<std::vector<int>>(paddings);
param.dilations =
std::make_shared<std::vector<int>>(dilations);
param.groups = group;
conv2d_transpose.SetParam(param);
conv2d_transpose.Launch();
// invoking ref implementation and compare results
param.filter = &filter_copy;
param.output = &output_ref;
conv2d_transpose_compute_ref<float, float>(param);
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(
output_data[i], output_ref_data[i], 1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(conv2d_transpose, kARM, kFloat, kNCHW, def);
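For reference, the oh/ow computation inside the test above is the standard transposed-convolution output-size relation; a small sketch with one worked instance:
// Sketch only: output extent of a transposed convolution along one axis.
static inline int deconv_out_size(int in, int stride, int padding, int dilation, int ks) {
  return (in - 1) * stride - 2 * padding + dilation * (ks - 1) + 1;
}
// e.g. in = 8, stride = 2, padding = 1, dilation = 1, ks = 3:
//   (8 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 = 15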
......@@ -26,6 +26,7 @@ template <>
void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
int threads = ctx.threads();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
......@@ -36,77 +37,97 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
}
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int ih = x_dims[2];
int iw = x_dims[3];
int oc = o_dims[1];
int tile_w = (ow + 5) / 6;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w;
int size_trans_channel = 8 * 8 * size_tile;
int max_ch = ic > oc ? ic : oc;
const int n_wino = size_tile;
workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int parallel_threads =
(((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
if (threads <= 2 && parallel_threads >= threads) {
if (last_kernel_is_c4_ == 1) {
return;
}
last_kernel_is_c4_ = 1;
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
auto weights_data_ = weights_.mutable_data<float>();
lite::arm::math::weight_trans_c4(
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
free(trans_tmp_ptr);
} else {
if (last_kernel_is_c4_ == 0) {
return;
}
last_kernel_is_c4_ = 0;
int tile_w = (ow + 5) / 6;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w;
int size_trans_channel = 8 * 8 * size_tile;
int max_ch = ic > oc ? ic : oc;
const int n_wino = size_tile;
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
const int m_wino = oc;
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
lite::arm::math::winograd_transform_weights(
weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
auto weights_trans = weights_.mutable_data<float>();
for (int i = 0; i < 64; ++i) {
float* packed_weights = weights_trans + i * m_round * ic;
const float* weights_wino_ptr = weights_wino + i * oc * ic;
lite::arm::math::prepackA(packed_weights,
weights_wino_ptr,
1.f,
ic,
0,
m_wino,
0,
ic,
false,
&ctx);
}
free(trans_tmp_ptr);
free(weights_wino);
}
last_shape_ = x_dims;
}
template <>
void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
last_shape_ = x_dims;
int ic = x_dims[1];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int tile_w = (ow + 5) / 6;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w;
int size_trans_channel = 8 * 8 * size_tile;
int max_ch = ic > oc ? ic : oc;
const int m_wino = oc;
const int n_wino = size_tile;
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
lite::arm::math::winograd_transform_weights(
weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
auto weights_trans = weights_.mutable_data<float>();
for (int i = 0; i < 64; ++i) {
float* packed_weights = weights_trans + i * m_round * ic;
const float* weights_wino_ptr = weights_wino + i * oc * ic;
lite::arm::math::prepackA(packed_weights,
weights_wino_ptr,
1.f,
ic,
0,
m_wino,
0,
ic,
false,
&ctx);
}
free(trans_tmp_ptr);
free(weights_wino);
ReInitWhenNeeded();
}
template <>
void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
// extend workspace
ctx.ExtendWorkspace(workspace_size_);
const auto* i_data = param.x->data<float>();
const auto* w_data = weights_.data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
......@@ -124,8 +145,42 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
int ow = o_dims[3];
int oc = o_dims[1];
lite::arm::math::conv_winograd3x3(
i_data, o_data, bs, oc, oh, ow, ic, ih, iw, w_data, b_data, param, &ctx);
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int threads = ctx.threads();
int parallel_threads =
(((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
if (threads <= 2 && parallel_threads >= threads) {
lite::arm::math::conv_compute_6x6_3x3(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
} else {
lite::arm::math::conv_winograd3x3(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
}
}
} // namespace arm
......
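To make the kernel-selection heuristic above concrete, a standalone sketch with assumed sizes showing how the tile count and parallel_threads are derived:
// Illustrative sketch only; oh, ow and threads are assumed values, not from the patch.
#include <cstdio>
int main() {
  const int oh = 112, ow = 112;  // assumed output spatial size
  const int threads = 2;         // assumed ARM context thread count
#ifdef __aarch64__
  const int tile_block = 16;
#else
  const int tile_block = 8;
#endif
  // F(6, 3) winograd works on 6x6 output tiles, hence the (x + 5) / 6 rounding.
  int tiles = ((ow + 5) / 6) * ((oh + 5) / 6);
  int parallel_threads = (tiles + tile_block - 1) / tile_block;
  // With at most two worker threads and enough tile batches to keep them busy,
  // the c4 (channel-packed) path is chosen; otherwise the gemm-based path runs.
  bool use_c4 = (threads <= 2 && parallel_threads >= threads);
  std::printf("tiles=%d parallel_threads=%d use_c4=%d\n", tiles, parallel_threads, (int)use_c4);
  return 0;
}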
......@@ -40,6 +40,7 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
Tensor weights_;
DDim last_shape_;
int workspace_size_{0};
int last_kernel_is_c4_{-1};
};
} // namespace arm
......
......@@ -25,6 +25,38 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantParam;
inline DDimLite GetShape(const param_t& param) {
// 1. shape is a Tensor
if (param.shape_tensor != nullptr) {
auto* shape_tensor = param.shape_tensor;
auto* shape_data = shape_tensor->data<int>();
auto vec_shape =
std::vector<int64_t>(shape_data, shape_data + shape_tensor->numel());
return DDimLite(vec_shape);
}
// 2. shape is a list/tuple containing Tensor
auto shape_tensor_list = param.shape_tensor_list;
if (shape_tensor_list.size() > 0) {
std::vector<int64_t> vec_shape;
for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
auto tensor = shape_tensor_list[i];
vec_shape.push_back(*tensor->data<int>());
}
return DDimLite(vec_shape);
}
// 3. shape is a list/tuple without containing Tensor
auto vec_shape = param.shape;
return DDimLite(vec_shape);
}
void PrepareForRun() override {
auto& param = *param_.get_mutable<param_t>();
auto outdims = GetShape(param);
param.Out->Resize(outdims);
}
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
......@@ -107,6 +139,11 @@ REGISTER_LITE_KERNEL(fill_constant,
kNCHW,
paddle::lite::kernels::arm::FillConstantCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
......
......@@ -15,14 +15,15 @@ add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${li
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(concat_compute_cuda CUDA basic SRCS concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(elementwise_add_compute_cuda CUDA basic SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise)
add_kernel(elementwise_compute_cuda CUDA basic SRCS elementwise_compute.cu DEPS ${lite_kernel_deps} cuda_elementwise)
add_kernel(calib_compute_cuda CUDA basic SRCS calib_compute.cu DEPS ${lite_kernel_deps})
add_kernel(layout_compute_cuda CUDA basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} cuda_transpose)
add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale)
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps} cudnn_pool)
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
......@@ -47,12 +48,13 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(elementwise_compute_cuda_test SRCS elementwise_compute_test.cc DEPS elementwise_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
......
......@@ -40,6 +40,7 @@ __global__ void ker_attention_padding_mask(T* out_data,
const int attn_seq_len,
const int src_seq_num,
const int src_seq_len,
const T* pad_begin_data,
const T mask,
const int count) {
CUDA_KERNEL_LOOP(tid, count) {
......@@ -49,7 +50,12 @@ __global__ void ker_attention_padding_mask(T* out_data,
int attn_word_id = tmp_tid % attn_seq_len;
int src_seq_id = attn_seq_id % src_seq_num;
int cur_len = src_offset[src_seq_id + 1] - src_offset[src_seq_id];
if (src_word_id >= cur_len) {
int k = static_cast<int>(pad_begin_data[src_seq_id]);
if (k < cur_len &&
tid >= src_seq_len * (attn_seq_len * attn_seq_id + attn_word_id) + k &&
tid < src_seq_len * (attn_seq_len * attn_seq_id + attn_word_id) +
cur_len) {
out_data[tid] = mask;
} else {
out_data[tid] = attn_data[tid];
......@@ -79,6 +85,35 @@ void AttentionPaddingMaskCompute::Run() {
auto attn_data = attn->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
std::vector<float> src_cpu(src->numel(), 0);
TargetWrapperCuda::MemcpyAsync(src_cpu.data(),
src->data<float>(),
sizeof(float) * src->numel(),
IoDirection::DtoH,
stream);
cudaStreamSynchronize(stream);
std::vector<float> pad_begin(src_seq_num, 0);
auto src_len = static_cast<int64_t>(src->lod()[0][1]);
int _pad_id = param.pad_id;
for (int i = 0; i < src_seq_num; ++i) {
const auto* src_data = src_cpu.data() + src_len * i;
int index = src_len - 1;
for (; index >= 0 && _pad_id == static_cast<int>(src_data[index]);
--index) {
}
pad_begin[i] = static_cast<float>(index + 1);
}
param.pad_begin->Resize({static_cast<int64_t>(src_seq_num)});
auto pad_begin_cuda_data =
param.pad_begin->mutable_data<float>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(pad_begin_cuda_data,
pad_begin.data(),
sizeof(float) * src_seq_num,
IoDirection::HtoD,
stream);
std::vector<int> src_offset_cpu(src_offset.size(), 0);
for (int i = 0; i < src_offset.size(); i++) {
src_offset_cpu[i] = src_offset[i];
......@@ -101,11 +136,12 @@ void AttentionPaddingMaskCompute::Run() {
attn_seq_len,
src_seq_num,
src_seq_len,
pad_begin_cuda_data,
param.mask,
count);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
} // namespace cuda
......@@ -113,7 +149,7 @@ void AttentionPaddingMaskCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(attention_padding_mask,
REGISTER_LITE_KERNEL(search_attention_padding_mask,
kCUDA,
kFloat,
kNCHW,
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/calib_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
......@@ -58,12 +59,7 @@ void calib_ref(const operators::CalibParam& param, bool to_float = true) {
}
TEST(calib_cuda, int8_to_fp32) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
ASSERT_FALSE(kernels.empty());
auto calib = std::move(*std::next(kernels.begin(), 1));
LOG(INFO) << "get kernel: " << calib->doc();
CalibComputeInt8ToFp32 calib;
const int n = 64, c = 32, h = 18, w = 18;
Tensor x;
Tensor x_cpu;
......@@ -87,14 +83,14 @@ TEST(calib_cuda, int8_to_fp32) {
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
calib->SetContext(std::move(ctx));
calib.SetContext(std::move(ctx));
operators::CalibParam param;
param.scale = 0.013f;
param.input = &x;
param.output = &output;
calib->SetParam(param);
calib->Launch();
calib.SetParam(param);
calib.Launch();
cudaDeviceSynchronize();
// invoking ref implementation and compare results
param.input = &x_cpu;
......@@ -113,12 +109,7 @@ TEST(calib_cuda, int8_to_fp32) {
}
TEST(calib_cuda, fp32_to_int8) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"calib", TARGET(kCUDA), PRECISION(kInt8), DATALAYOUT(kNCHW));
ASSERT_FALSE(kernels.empty());
auto calib = std::move(kernels.front());
LOG(INFO) << "get kernel: " << calib->doc();
CalibComputeFp32ToInt8 calib;
const int n = 64, c = 32, h = 18, w = 18;
Tensor x;
Tensor x_cpu;
......@@ -142,14 +133,14 @@ TEST(calib_cuda, fp32_to_int8) {
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
calib->SetContext(std::move(ctx));
calib.SetContext(std::move(ctx));
operators::CalibParam param;
param.scale = 0.013f;
param.input = &x;
param.output = &output;
calib->SetParam(param);
calib->Launch();
calib.SetParam(param);
calib.Launch();
cudaDeviceSynchronize();
// invoking ref implementation and compare results
param.input = &x_cpu;
......
......@@ -42,7 +42,9 @@ TEST(conv_compute, fp32) {
operators::ConvParam param;
param.activation_param = act_param;
std::vector<int> pads = {1, 1, 1, 1};
std::vector<int> dilations = {1, 1, 1, 1};
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.groups = 1;
Tensor x, filter, bias, y, x_cpu, filter_cpu, bias_cpu, y_cpu;
......@@ -149,6 +151,10 @@ TEST(conv_compute, int8) {
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
filter_cpu.dims());
std::vector<int> pads = {0, 0, 0, 0};
std::vector<int> dilations = {1, 1, 1, 1};
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.x = &x;
param.filter = &filter;
param.output = &y;
......@@ -203,12 +209,10 @@ TEST(conv_compute, int8_int8_out) {
std::cout << "input" << std::endl;
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = static_cast<int8_t>(random(-36, 36));
std::cout << float(x_cpu_data[i]) << std::endl;
}
std::cout << "filter" << std::endl;
for (int i = 0; i < filter_cpu.numel(); i++) {
filter_cpu_data[i] = static_cast<int8_t>(random(-10, 10));
std::cout << float(filter_cpu_data[i]) << std::endl;
}
for (int i = 0; i < bias_cpu.numel(); i++) {
bias_cpu_data[i] = i + 1.0;
......@@ -221,6 +225,10 @@ TEST(conv_compute, int8_int8_out) {
bias.Assign<float, lite::DDim, TARGET(kCUDA)>(bias_cpu_data,
filter_cpu.dims());
std::vector<int> pads = {0, 0, 0, 0};
std::vector<int> dilations = {1, 1, 1, 1};
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilations);
param.x = &x;
param.filter = &filter;
param.output = &y;
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/elementwise_add_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void ElementwiseAddCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
lite::cuda::math::elementwise_add(
pixel_num, x_data, y_data, out_data, stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeNHWC::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims().production() == y->dims().production());
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
lite::cuda::math::elementwise_add(
pixel_num, x_data, y_data, out_data, stream);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeInt8::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims().production() == y->dims().production());
const int c = x->dims()[3];
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<int8_t>(TARGET(kCUDA));
int pixel_num = x->numel();
float output_scale = param.output_scale;
if (c % 4 == 0) {
lite::cuda::math::elementwise_add_nhwc4_int8(
pixel_num / 4,
static_cast<const void*>(x_data),
static_cast<const void*>(y_data),
1. / output_scale,
static_cast<void*>(out_data),
stream);
} else {
lite::cuda::math::elementwise_add_int8(
pixel_num, x_data, y_data, 1. / output_scale, out_data, stream);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <vector>
#include "lite/backends/cuda/math/elementwise.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/elementwise_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
inline DDim trim_trailing_singular_dims(const DDim& dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
std::vector<int64_t> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return DDim();
}
return DDim(trim_dims);
}
inline bool is_broadcast(const DDim& x_dims,
const DDim& y_dims,
int axis,
int* pre,
int* n,
int* post) {
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
if (x_dims.size() == y_dim_trim.size()) {
return false;
}
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dim_trim.size(); ++i) {
CHECK_EQ(x_dims[i + axis], y_dim_trim[i])
<< "Broadcast dimension mismatch.";
(*n) *= y_dim_trim[i];
}
for (int i = axis + y_dim_trim.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
return true;
}
#define ELEMENTWISE_COMPUTE(OP, WITH_RELU) \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
const lite::Tensor* x = param.X; \
const lite::Tensor* y = param.Y; \
lite::Tensor* out = param.Out; \
int axis = param.axis; \
auto* x_data = x->data<float>(); \
auto* y_data = y->data<float>(); \
auto out_data = out->mutable_data<float>(TARGET(kCUDA)); \
int pixel_num = x->numel(); \
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (WITH_RELU) { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
} else { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
}
#define ELEMENTWISE_COMPUTE_NHWC(OP, WITH_RELU) \
std::map<int, int> pos_map = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; \
auto& param = this->Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
const lite::Tensor* x = param.X; \
const lite::Tensor* y = param.Y; \
lite::Tensor* out = param.Out; \
int axis = param.axis; \
if (axis < 0) axis = x->dims().size() - y->dims().size(); \
CHECK(axis >= 0) << "invalid axis of elementwise op"; \
axis = pos_map[axis]; \
auto* x_data = x->data<float>(); \
auto* y_data = y->data<float>(); \
auto out_data = out->mutable_data<float>(TARGET(kCUDA)); \
int pixel_num = x->numel(); \
int pre = 1; \
int n = pixel_num; \
int post = 1; \
if (WITH_RELU) { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise_relu( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
} else { \
if (is_broadcast(x->dims(), y->dims(), axis, &pre, &n, &post)) { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, pre, n, post, OP, stream); \
} else { \
lite::cuda::math::elementwise( \
x_data, y_data, out_data, 1, pixel_num, 1, OP, stream); \
} \
}
void ElementwiseAddCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddReluCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kADD, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseAddReluComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kADD, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulReluCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulReluComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kMUL, true)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseMulCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseMulComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddReluCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_add_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseAddReluComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseMulReluCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseMulReluComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
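// Sketch of how the NHWC variants registered above would be picked at runtime; the
// set_valid_places() call is the public CxxConfig API, but this particular place
// list is an illustration, not part of this patch:
//   lite_api::CxxConfig config;
//   config.set_valid_places(
//       {lite_api::Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)},
//        lite_api::Place{TARGET(kCUDA), PRECISION(kFloat)}});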
......@@ -38,13 +38,58 @@ class ElementwiseAddComputeNHWC
virtual ~ElementwiseAddComputeNHWC() = default;
};
class ElementwiseAddComputeInt8
class ElementwiseMulCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseMulCompute() = default;
};
class ElementwiseMulComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseAddComputeInt8() = default;
virtual ~ElementwiseMulComputeNHWC() = default;
};
class ElementwiseAddReluCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseAddReluCompute() = default;
};
class ElementwiseAddReluComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseAddReluComputeNHWC() = default;
};
class ElementwiseMulReluCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulReluCompute() = default;
};
class ElementwiseMulReluComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::FusionElementwiseActivationParam;
void Run() override;
virtual ~ElementwiseMulReluComputeNHWC() = default;
};
} // namespace cuda
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/elementwise_add_compute.h"
#include "lite/kernels/cuda/elementwise_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
......@@ -31,6 +31,14 @@ static void ElementwiseAddRef(float* x, float* y, float* out, int num) {
}
}
static void ElementwiseBroadcastRef(
    const float* x, const float* y, float* out, int pre, int n, int post) {
for (int i = 0; i < pre * n * post; ++i) {
int idx = (i / post) % n;
out[i] = x[i] + y[idx];
}
}
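// In the bias tests below, y is a per-channel bias shaped {c, 1, 1}. The NCHW case
// calls this reference with (pre, n, post) = (batch, c, h * w) and the NHWC case
// with (batch * h * w, c, 1), which is presumably the same factoring the CUDA
// kernel derives via is_broadcast().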
TEST(elementwise_add, normal) {
ElementwiseAddCompute elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
......@@ -99,38 +107,117 @@ TEST(elementwise_add, normal) {
}
}
TEST(elementwise_add, int8_out) {
ElementwiseAddComputeInt8 elementwise_add_kernel;
TEST(elementwise_add, bias) {
ElementwiseAddCompute elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ElementwiseParam param;
Tensor x, y, out;
Tensor x_cpu, y_cpu, out_cpu;
Tensor x_ref, y_ref, out_ref;
const int n = 1;
const int c = 3;
const int h = 2000;
const int w = 2000;
x.Resize({n, c, h, w});
y.Resize({c, 1, 1});
out.Resize({n, c, h, w});
x_cpu.Resize({n, c, h, w});
y_cpu.Resize({c, 1, 1});
out_cpu.Resize({n, c, h, w});
x_ref.Resize({n, c, h, w});
y_ref.Resize({c, 1, 1});
out_ref.Resize({n, c, h, w});
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* y_ref_data = y_ref.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 5.0;
x_ref_data[i] = i + 5.0;
}
for (int i = 0; i < y_cpu.numel(); ++i) {
y_cpu_data[i] = i - 5.0;
y_ref_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
param.X = &x;
param.Y = &y;
param.Out = &out;
param.axis = -1;
elementwise_add_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
elementwise_add_kernel.SetContext(std::move(ctx));
elementwise_add_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
ElementwiseBroadcastRef(x_ref_data, y_ref_data, out_ref_data, n, c, h * w);
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
TEST(elementwise_add_nhwc, bias) {
ElementwiseAddComputeNHWC elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ElementwiseParam param;
Tensor x, y, out;
Tensor x_cpu, y_cpu, out_cpu;
Tensor x_ref, y_ref, out_ref;
const int n = 1;
const int h = 36;
const int w = 36;
const int c = 125;
const int c = 3;
const int h = 2000;
const int w = 2000;
x.Resize({n, h, w, c});
y.Resize({n, h, w, c});
y.Resize({c, 1, 1});
out.Resize({n, h, w, c});
x_cpu.Resize({n, h, w, c});
y_cpu.Resize({n, h, w, c});
y_cpu.Resize({c, 1, 1});
out_cpu.Resize({n, h, w, c});
x_ref.Resize({n, h, w, c});
y_ref.Resize({c, 1, 1});
out_ref.Resize({n, h, w, c});
auto* out_data = out.mutable_data<int8_t>(TARGET(kCUDA));
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<int8_t>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* y_ref_data = y_ref.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 5.0;
x_ref_data[i] = i + 5.0;
}
for (int i = 0; i < y_cpu.numel(); ++i) {
y_cpu_data[i] = i;
y_cpu_data[i] = i - 5.0;
y_ref_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
......@@ -139,7 +226,7 @@ TEST(elementwise_add, int8_out) {
param.X = &x;
param.Y = &y;
param.Out = &out;
param.output_scale = 50 / 127.;
param.axis = -1;
elementwise_add_kernel.SetParam(param);
cudaStream_t stream;
......@@ -147,16 +234,15 @@ TEST(elementwise_add, int8_out) {
context.SetExecStream(stream);
elementwise_add_kernel.SetContext(std::move(ctx));
auto start = GetCurrentUS();
for (int i = 0; i < 1000000; i++) {
elementwise_add_kernel.Launch();
}
LOG(INFO) << "time: " << (GetCurrentUS() - start) / 1000000.;
elementwise_add_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(int8_t) * out.numel(), IoDirection::DtoH);
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
ElementwiseBroadcastRef(
x_ref_data, y_ref_data, out_ref_data, n * h * w, c, 1);
for (int i = 0; i < out.numel(); i++) {
// LOG(INFO) << float(out_cpu_data[i]);
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
......
......@@ -20,7 +20,8 @@ namespace lite {
namespace kernels {
namespace cuda {
class FeedCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
template <typename T, PrecisionType Ptype>
class FeedCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::FeedParam;
using TargetW = TargetWrapper<TARGET(kCUDA)>;
......
......@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include "lite/backends/cuda/blas.h"
namespace paddle {
namespace lite {
......@@ -26,6 +27,7 @@ TEST(mul_compute, normal) {
MulCompute mul_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
context.InitOnce();
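// Assumption: InitOnce() creates this context's cuBLAS handle, which MulCompute
// relies on; that reading is consistent with the lite/backends/cuda/blas.h include
// added above.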
Tensor x, y, out, x_cpu, y_cpu, out_cpu;
int x_h = 2, x_w_y_h = 3, y_w = 4;
......