Unverified commit 8eeaa0ac authored by: H HappyAngel, committed by: GitHub

Merge pull request #138 from PaddlePaddle/develop

pull
......@@ -94,12 +94,10 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
message(STATUS "SRC_FBS_DIR: ${SRC_FBS_DIR}")
string(REGEX REPLACE "\\.fbs$" "_generated.h" GEN_HEADER ${SRC_FBS})
add_custom_command(
OUTPUT ${GEN_HEADER}
OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/${GEN_HEADER}"
COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
--cpp --gen-mutable --gen-object-api --reflect-names
--force-empty --force-empty-vectors
${OPT}
-I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
-o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}"
DEPENDS flatbuffers
......
......@@ -37,14 +37,25 @@ rm ./lite/api/paddle_use_kernels.h
rm ./lite/api/paddle_use_ops.h
# Set build options and start the build
# android-armv7:cpu+gpu+cv+extra
./lite/tools/build_android.sh \
--arch=armv7 \
--toolchain=clang \
--with_cv=OFF \
--with_log=OFF \
--with_extra=OFF \
--with_extra=ON \
--with_cv=ON \
--with_opencl=ON
# android-armv8:cpu+gpu+cv+extra
./lite/tools/build_android.sh \
--arch=armv8 \
--toolchain=clang \
--with_log=OFF \
--with_extra=ON \
--with_cv=ON \
--with_opencl=ON
# Note: for build help, run: ./lite/tools/build_android.sh help
```
......@@ -206,7 +217,7 @@ adb shell "export GLOG_v=4; \
## 3. How to use it in code
This refers to the code under the build output directory `demo/cxx/mobile_light`; for the online version, see [./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc) in the GitHub repository;
This refers to the code under the build output directory `demo/cxx/mobile_light`; for the online version, see [./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc) in the GitHub repository, which also shows how to check whether the current device supports OpenCL;
Note: the link above points to the latest develop branch online and may differ from your local code; it is recommended to refer to your local copy under `lite/demo/cxx/` to see how to use it.
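For a quick sanity check of the OpenCL path outside the demo, a minimal sketch is shown below (assuming the light-API header `paddle_api.h` and hypothetical model paths; it follows the commented snippet added to `mobilenetv1_light_api.cc` in this PR):

```cpp
#include <iostream>
#include <string>
#include "paddle_api.h"  // Paddle-Lite light API; adjust the include path to your package layout

int main() {
  // Hypothetical model paths: an OpenCL-targeted .nb model and a CPU fallback.
  const std::string opencl_model = "mobilenet_v1_opencl.nb";
  const std::string cpu_model = "mobilenet_v1_cpu.nb";

  paddle::lite_api::MobileConfig config;
  // Prefer the OpenCL model only when the device actually supports OpenCL.
  if (paddle::lite_api::IsOpenCLBackendValid()) {
    config.set_model_from_file(opencl_model);
  } else {
    config.set_model_from_file(cpu_model);  // fall back to the CPU model
  }

  auto predictor =
      paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::MobileConfig>(config);
  std::cout << "predictor created" << std::endl;
  return 0;
}
```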
......
......@@ -32,9 +32,22 @@
#include "lite/backends/mlu/target_wrapper.h"
#endif
#ifdef LITE_WITH_OPENCL
#include "lite/backends/opencl/cl_runtime.h"
#endif
namespace paddle {
namespace lite_api {
bool IsOpenCLBackendValid() {
bool opencl_valid = false;
#ifdef LITE_WITH_OPENCL
opencl_valid = paddle::lite::CLRuntime::Global()->OpenCLAvaliableForDevice();
#endif
LOG(INFO) << "opencl_valid:" << opencl_valid;
return opencl_valid;
}
Tensor::Tensor(void *raw) : raw_tensor_(raw) {}
// TODO(Superjomn) refine this by using another `const void* const_raw`;
......
......@@ -33,6 +33,9 @@ using lod_t = std::vector<std::vector<uint64_t>>;
enum class LiteModelType { kProtobuf = 0, kNaiveBuffer, UNK };
// return true if current device supports OpenCL model
LITE_API bool IsOpenCLBackendValid();
struct LITE_API Tensor {
explicit Tensor(void* raw);
explicit Tensor(const void* raw);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"
......@@ -484,6 +485,76 @@ template void relu(int, const half*, half*, float, cudaStream_t);
template void bias_relu(
int, const float*, const float* bias, float*, float, cudaStream_t);
// ------------- sigmoid -------------
template <typename T>
__global__ void sigmoid_kernel(const int num, const T* in, T* out) {
CUDA_KERNEL_LOOP(i, num) {
#if __CUDA_ARCH__ >= 350
out[i] = static_cast<T>(1.0f) /
(static_cast<T>(1.0f) + expf(-1 * __ldg(in + i)));
#else
out[i] = static_cast<T>(1.0f) / (static_cast<T>(1.0f) + expf(-in[i]));
#endif
}
}
template <>
__global__ void sigmoid_kernel(const int num, const half* in, half* out) {
CUDA_KERNEL_LOOP(i, num) {
half tmp = __float2half(1.0f);
#if __CUDA_ARCH__ >= 530
out[i] = __hdiv(
tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.0f), __ldg(in + i)))));
#else
out[i] = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i]))));
#endif
}
}
template <>
__global__ void sigmoid_kernel(const int num, const half2* in, half2* out) {
CUDA_KERNEL_LOOP(i, num) {
half2 tmp = __floats2half2_rn(1.0f, 1.0f);
#if __CUDA_ARCH__ >= 530
out[i] = __h2div(tmp,
__hadd2(tmp,
h2exp(__hmul2(__floats2half2_rn(-1.0f, -1.0f),
__ldg(in + i)))));
#else
out[i].x = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].x))));
out[i].y = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].y))));
#endif
}
}
template <typename T>
void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream) {
sigmoid_kernel<T><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
num, din, dout);
CUDA_POST_KERNEL_CHECK;
}
template <>
void sigmoid(const int num, const half* din, half* dout, cudaStream_t stream) {
if (num % 2 == 0) {
const half2* din2 = reinterpret_cast<const half2*>(din);
half2* dout2 = reinterpret_cast<half2*>(dout);
sigmoid_kernel<
half2><<<CUDA_GET_BLOCKS(num / 2), CUDA_NUM_THREADS, 0, stream>>>(
num / 2, din2, dout2);
} else {
sigmoid_kernel<half><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
num, din, dout);
}
CUDA_POST_KERNEL_CHECK;
}
template void sigmoid(const int num,
const float* din,
float* dout,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -83,6 +83,9 @@ void bias_int8_nhwc(int num,
const void* scale,
cudaStream_t stream);
template <typename T>
void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -31,6 +31,17 @@ __global__ void RowwiseAddKernel(
c[i] = a[i] + b[w];
}
}
template <>
__global__ void RowwiseAddKernel(
const half* a, const half* b, half* c, int width, int num) {
CUDA_KERNEL_LOOP(i, num) {
int h = i / width;
int w = i - h * width;
c[i] = __hadd(a[i], b[w]);
}
}
template <typename T>
void RowwiseAdd<T>::operator()(const T* input,
const T* bias,
......@@ -44,6 +55,7 @@ void RowwiseAdd<T>::operator()(const T* input,
}
template struct RowwiseAdd<float>;
template struct RowwiseAdd<half>;
} // namespace math
} // namespace cuda
......
......@@ -22,6 +22,10 @@ namespace lite {
namespace cuda {
namespace math {
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
......@@ -33,6 +37,7 @@ __global__ void GruForwardResetOutput(
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
......@@ -44,12 +49,14 @@ __global__ void GruForwardResetOutput(
T reset_out_val;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == lite::cuda::math::ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
......@@ -60,12 +67,71 @@ __global__ void GruForwardResetOutput(
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
reset_out_val = prev_out * reset_gate_value;
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
template <>
__global__ void GruForwardResetOutput(
half* gate_value,
half* reset_output_value,
half* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size;
}
half prev_out = 0;
half reset_out_val;
half update_gate_value = gate_value[frame_idx + frame_size * 0];
half reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
} else if (active_gate == ActivationType::kReLU) {
update_gate_value = ReLU(update_gate_value);
reset_gate_value = ReLU(reset_gate_value);
} else if (active_gate == ActivationType::kTanh) {
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
#if __CUDA_ARCH__ >= 530
reset_out_val = __hmul(prev_out, reset_gate_value);
#else
reset_out_val =
__float2half(__half2float(prev_out) * __half2float(reset_gate_value));
#endif
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
......@@ -87,14 +153,17 @@ __global__ void GruForwardFinalOutput(
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
T output;
T prev_out = 0;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
......@@ -102,6 +171,7 @@ __global__ void GruForwardFinalOutput(
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
output = update_gate_value * prev_out + state_frame_value -
update_gate_value * state_frame_value;
......@@ -109,6 +179,76 @@ __global__ void GruForwardFinalOutput(
output = prev_out - update_gate_value * prev_out +
update_gate_value * state_frame_value;
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
template <>
__global__ void GruForwardFinalOutput(
half* gate_value,
half* prev_output_value,
half* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) {
return;
}
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
half output;
half prev_out = 0;
half update_gate_value = gate_value[frame_idx + frame_size * 0];
half state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
state_frame_value = ReLU(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
#if __CUDA_ARCH__ >= 530
output =
__hsub(__hadd(__hmul(update_gate_value, prev_out), state_frame_value),
__hmul(update_gate_value, state_frame_value));
#else
output = __float2half(
__half2float(update_gate_value) * __half2float(prev_out) +
__half2float(state_frame_value) -
__half2float(update_gate_value) * __half2float(state_frame_value));
#endif
} else {
#if __CUDA_ARCH__ >= 530
output = __hadd(__hsub(prev_out, __hmul(update_gate_value, prev_out)),
__hmul(update_gate_value, state_frame_value));
#else
output = __float2half(
__half2float(prev_out) -
__half2float(update_gate_value) * __half2float(prev_out) +
__half2float(update_gate_value) * __half2float(state_frame_value));
#endif
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
......@@ -122,6 +262,7 @@ template __global__ void GruForwardFinalOutput<float>(
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
template __global__ void GruForwardResetOutput<float>(
float* gate_value,
float* reset_output_value,
......
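For reference, the half-precision specializations above implement the same GRU forward step as the float kernels. Writing $u_t$ for the update gate, $r_t$ for the reset gate, $\tilde h_t$ for the candidate state (`state_frame_value` after activation) and $\odot$ for the element-wise product, the two kernels compute:

```latex
\tilde r_t = r_t \odot h_{t-1}, \qquad
h_t =
\begin{cases}
u_t \odot h_{t-1} + (1 - u_t) \odot \tilde h_t, & \text{origin\_mode} \\
(1 - u_t) \odot h_{t-1} + u_t \odot \tilde h_t, & \text{otherwise.}
\end{cases}
```

The `__CUDA_ARCH__ >= 530` branches evaluate these products with native half intrinsics (`__hmul`, `__hadd`, `__hsub`); older architectures round-trip through float with `__half2float`/`__float2half`.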
......@@ -34,10 +34,32 @@ template <typename Dtype>
inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}
template <>
inline __device__ half Sigmoid(const half a) {
#if __CUDA_ARCH__ >= 530
const half tmp = __float2half(1.0f);
return __hdiv(tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.f), a))));
#else
return __float2half(1.0f / (expf(__half2float(a) * -1) + 1.0f));
#endif
}
template <typename Dtype>
inline __device__ Dtype ReLU(const Dtype a) {
return a > static_cast<Dtype>(0.f) ? a : static_cast<Dtype>(0.f);
}
template <>
inline __device__ half ReLU(const half a) {
const half tmp = __float2half(0.f);
#if __CUDA_ARCH__ >= 530
return __hgt(a, tmp) ? a : tmp;
#else
return __float2half(__half2float(a) > 0.f ? __half2float(a) : 0.f);
#endif
}
template <typename Dtype>
inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
......@@ -45,6 +67,18 @@ inline __device__ Dtype Tanh(const Dtype a) {
static_cast<Dtype>(1.0);
}
template <>
inline __device__ half Tanh(const half a) {
#if __CUDA_ARCH__ >= 530
half tmp = __float2half(1.0f);
half numerator = __hmul(__float2half(-2.0f), a);
return __hsub(__hdiv(__float2half(2.0f), __hadd(tmp, hexp(numerator))), tmp);
#else
float tmp = -2.0f * __half2float(a);
return __float2half(2.0f / (1.0f + expf(tmp)) - 1.0f);
#endif
}
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
......@@ -54,6 +88,7 @@ __global__ void GruForwardResetOutput(
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
......@@ -65,6 +100,134 @@ __global__ void GruForwardFinalOutput(
bool origin_mode,
bool is_batch);
/*
* threads(tile_size, 1)
* grids(frame_blocks, 1)
*/
template <class T, int TiledSize>
__global__ void FastCollectiveGruGate(T* gate_value,
T* prev_output_value,
T* gate_weight,
T* reset_output,
int frame_size,
ActivationType active_node) {
T xt_0 = 0.0f;
T a0 = 0.0f;
T c0 = 0.0f;
T b0[TiledSize];
int col = blockIdx.x * blockDim.x + threadIdx.x;
int tiled_mask = ((1 << TiledSize) - 1);
// Tiled matrix multiply using register shuffle (__shfl), faster than shared memory.
if (prev_output_value) {
for (int k = 0; k < (((frame_size - 1) / TiledSize) + 1); ++k) {
a0 = 0;
if ((threadIdx.x + k * TiledSize) < frame_size) {
a0 = prev_output_value[threadIdx.x + (k * TiledSize)];
}
for (int i = 0; i < TiledSize; ++i) {
if (col < frame_size * 2 && (i + k * TiledSize) < frame_size) {
b0[i] = gate_weight[(i + k * TiledSize) * frame_size * 2 + col];
}
}
for (int i = 0; i < TiledSize; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0 = c0 + __shfl_sync(tiled_mask, a0, i, TiledSize) * b0[i];
#else
c0 = c0 + __shfl(a0, i, TiledSize) * b0[i];
#endif
}
}
}
__syncthreads();
if (col < frame_size * 2) {
xt_0 = gate_value[col];
c0 += xt_0;
if (active_node == ActivationType::kSigmoid) {
c0 = Sigmoid(c0);
} else if (active_node == ActivationType::kReLU) {
c0 = ReLU(c0);
} else if (active_node == ActivationType::kTanh) {
c0 = Tanh(c0);
}
gate_value[col] = c0;
if (frame_size <= col && col < frame_size * 2) {
T htp_0 = 0.0;
if (prev_output_value) {
htp_0 = prev_output_value[col - frame_size];
}
reset_output[col - frame_size] = c0 * htp_0;
} else if (col < frame_size) {
gate_value[col] = c0;
}
}
}
template <class T, int TiledSize>
__global__ void FastCollectiveGruOut(T* gate_weight,
T* prev_out_value,
T* output_value,
T* gate_value,
T* reset_value,
int frame_size,
ActivationType active_node,
bool origin_mode) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
T a0 = 0.0f;
T b0[TiledSize];
T c0 = 0.0f;
int tiled_mask = ((1 << TiledSize) - 1);
if (prev_out_value) {
for (int k = 0; k < ((frame_size - 1) / TiledSize + 1); ++k) {
a0 = 0;
if ((threadIdx.x + k * TiledSize) < frame_size) {
a0 = reset_value[threadIdx.x + k * TiledSize];
}
for (int i = 0; i < TiledSize; ++i) {
if (col < frame_size && (i + k * TiledSize) < frame_size) {
b0[i] = gate_weight[(i + k * TiledSize) * frame_size + col];
}
}
for (int i = 0; i < TiledSize; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0 = c0 + __shfl_sync(tiled_mask, a0, i, TiledSize) * b0[i];
#else
c0 = c0 + __shfl(a0, i, TiledSize) * b0[i];
#endif
}
}
}
__syncthreads();
if (col < frame_size) {
T xt_0 = gate_value[col + 2 * frame_size];
T gta_0 = gate_value[col];
T htp_0 = 0;
if (prev_out_value) {
htp_0 = prev_out_value[col];
}
c0 += xt_0;
if (active_node == ActivationType::kSigmoid) {
c0 = Sigmoid(c0);
} else if (active_node == ActivationType::kReLU) {
c0 = ReLU(c0);
} else if (active_node == ActivationType::kTanh) {
c0 = Tanh(c0);
}
gate_value[col + 2 * frame_size] = c0;
if (origin_mode) {
output_value[col] = htp_0 * gta_0 + (1 - gta_0) * c0;
} else {
output_value[col] = c0 * gta_0 + (1 - gta_0) * htp_0;
}
}
}
} // namespace math
} // namespace cuda
} // namespace lite
......
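The `FastCollectiveGruGate`/`FastCollectiveGruOut` kernels above accumulate the small vector-matrix products with warp shuffles instead of shared memory: each thread owns one output column, and the previous-output values held in registers are broadcast across a `TiledSize`-wide group of lanes with `__shfl_sync`. A standalone sketch of that broadcast pattern (simplified, single block, float only; names are illustrative and not taken from the Paddle-Lite sources) is:

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// y[col] = sum_r x[r] * W[r * m + col], accumulated TILE rows at a time.
// Lane i of each TILE-wide segment loads x[k*TILE + i] into a register and
// broadcasts it to its segment with __shfl_sync, so no shared memory is used.
template <int TILE>
__global__ void tiled_gemv(const float* x, const float* W, float* y, int n, int m) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  float acc = 0.f;
  for (int k = 0; k < (n + TILE - 1) / TILE; ++k) {
    float a = 0.f;
    const int row = k * TILE + (threadIdx.x % TILE);
    if (row < n) a = x[row];
    for (int i = 0; i < TILE; ++i) {
      // Broadcast lane i's register within each TILE-wide segment of the warp.
      const float a_i = __shfl_sync(0xffffffffu, a, i, TILE);
      const int r = k * TILE + i;
      if (col < m && r < n) acc += a_i * W[r * m + col];
    }
  }
  if (col < m) y[col] = acc;
}

int main() {
  const int n = 20, m = 7;
  std::vector<float> hx(n), hW(n * m), hy(m), ref(m, 0.f);
  for (int i = 0; i < n; ++i) hx[i] = 0.1f * i;
  for (int i = 0; i < n * m; ++i) hW[i] = 0.01f * (i % 13);
  for (int c = 0; c < m; ++c)
    for (int r = 0; r < n; ++r) ref[c] += hx[r] * hW[r * m + c];

  float *dx, *dW, *dy;
  cudaMalloc(&dx, n * sizeof(float));
  cudaMalloc(&dW, n * m * sizeof(float));
  cudaMalloc(&dy, m * sizeof(float));
  cudaMemcpy(dx, hx.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dW, hW.data(), n * m * sizeof(float), cudaMemcpyHostToDevice);

  tiled_gemv<8><<<1, 32>>>(dx, dW, dy, n, m);  // whole warp active, no early return
  cudaMemcpy(hy.data(), dy, m * sizeof(float), cudaMemcpyDeviceToHost);
  for (int c = 0; c < m; ++c) printf("col %d: gpu=%f ref=%f\n", c, hy[c], ref[c]);

  cudaFree(dx); cudaFree(dW); cudaFree(dy);
  return 0;
}
```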
......@@ -77,8 +77,13 @@ void CopyMatrixRowsFunctor<T>::operator()(
}
template class CopyMatrixRowsFunctor<float>;
template class CopyMatrixRowsFunctor<half>;
template class LoDTensor2BatchFunctor<float>;
template class LoDTensor2BatchFunctor<half>;
template class Batch2LoDTensorFunctor<float>;
template class Batch2LoDTensorFunctor<half>;
} // namespace math
} // namespace cuda
......
......@@ -32,6 +32,9 @@ namespace math {
template <typename T>
class CopyMatrixRowsFunctor {
public:
// If is_src_index is true, copy the indexed rows of the input src to the
// output dst. If is_src_index is false, copy the input src to the indexed
// rows of the output dst. The row indexes come from the input index.
void operator()(const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
......@@ -44,6 +47,11 @@ class CopyMatrixRowsFunctor {
template <typename T>
class LoDTensor2BatchFunctor {
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct SeqInfo {
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start_(start), length_(length), seq_idx_(seq_idx) {}
......@@ -60,21 +68,49 @@ class LoDTensor2BatchFunctor {
auto lods = lod_tensor.lod();
CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now.";
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (int seq_id = 0; seq_id < static_cast<int>(lod.size()) - 1; ++seq_id) {
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length_ > b.length_;
});
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// max_seqlen = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = 0
// batch_start_positions[1] = len(b0)
// batch_start_positions[2] = len(b0) + len(b1)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
LoD batch_lods;
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
size_t max_seqlen = seq_info[0].length_;
batch_lods[0].resize(max_seqlen + 1);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods[2].resize(seq_info.size());
auto* batch_starts = batch_lods[0].data();
......@@ -101,6 +137,7 @@ class LoDTensor2BatchFunctor {
}
batch_tensor->set_lod(batch_lods);
lite::cuda::math::CopyMatrixRowsFunctor<T> to_batch;
to_batch(lod_tensor, batch_tensor, batch_lods[1], true, stream);
CUDA_POST_KERNEL_CHECK;
......
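To make the comment above concrete, here is a small host-side sketch (plain C++, illustrative only, not part of the Paddle-Lite sources) that reproduces `batch_start_positions` and `seq2batch_idx` for the example sequences s0 (length 4), s1 (length 5), s2 (length 3):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // lod offsets for s0 (len 4), s1 (len 5), s2 (len 3) in the flattened tensor.
  std::vector<uint64_t> lod = {0, 4, 9, 12};

  // (start, length, seq_idx), sorted by length descending -> {(4,5,1),(0,4,0),(9,3,2)}
  struct SeqInfo { size_t start, length, seq_idx; };
  std::vector<SeqInfo> seq_info;
  for (size_t i = 0; i + 1 < lod.size(); ++i)
    seq_info.push_back({lod[i], lod[i + 1] - lod[i], i});
  std::sort(seq_info.begin(), seq_info.end(),
            [](const SeqInfo& a, const SeqInfo& b) { return a.length > b.length; });

  const size_t max_seqlen = seq_info[0].length;
  std::vector<uint64_t> batch_starts(max_seqlen + 1, 0);
  std::vector<uint64_t> seq2batch_idx(lod.back());

  // Time step t of every sequence that is long enough forms batch t.
  size_t pos = 0;
  for (size_t t = 0; t < max_seqlen; ++t) {
    for (const auto& s : seq_info)
      if (t < s.length) seq2batch_idx[pos++] = s.start + t;
    batch_starts[t + 1] = pos;
  }

  // Expected: batch_starts  = {0, 3, 6, 9, 11, 12}
  //           seq2batch_idx = {4,0,9, 5,1,10, 6,2,11, 7,3, 8}
  for (auto v : batch_starts) std::cout << v << ' ';
  std::cout << '\n';
  for (auto v : seq2batch_idx) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```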
......@@ -38,17 +38,20 @@ CLRuntime::~CLRuntime() {
}
bool CLRuntime::Init() {
if (initialized_) {
if (is_cl_runtime_initialized_) {
return true;
}
bool is_platform_init = InitializePlatform();
bool is_device_init = InitializeDevice();
is_init_success_ = is_platform_init && is_device_init;
initialized_ = true;
context_ = CreateContext();
command_queue_ = CreateCommandQueue(context());
return initialized_;
LOG(INFO) << "is_platform_init:" << is_platform_init;
LOG(INFO) << "is_device_init:" << is_device_init;
if ((is_platform_init == true) && (is_device_init == true)) {
is_platform_device_init_success_ = true;
context_ = CreateContext();
command_queue_ = CreateCommandQueue(context());
is_cl_runtime_initialized_ = true;
}
return is_cl_runtime_initialized_;
}
cl::Platform& CLRuntime::platform() {
......@@ -64,7 +67,9 @@ cl::Context& CLRuntime::context() {
}
cl::Device& CLRuntime::device() {
CHECK(device_ != nullptr) << "device_ is not initialized!";
if (device_ == nullptr) {
LOG(ERROR) << "device_ is not initialized!";
}
return *device_;
}
......@@ -150,6 +155,14 @@ GpuType CLRuntime::ParseGpuTypeFromDeviceName(std::string device_name) {
}
bool CLRuntime::InitializeDevice() {
VLOG(3) << "device_info_.size():" << device_info_.size();
for (auto i : device_info_) {
VLOG(3) << ">>> " << i.first << " " << i.second;
}
if (device_info_.size() > 0 && device_info_.size() <= 2) {
return false;
}
device_info_["PLACEHOLDER"] = 1;
// ===================== BASIC =====================
// CL_DEVICE_TYPE_GPU
// CL_DEVICE_NAME
......@@ -160,7 +173,7 @@ bool CLRuntime::InitializeDevice() {
status_ = platform_->getDevices(CL_DEVICE_TYPE_GPU, &all_devices);
CL_CHECK_ERROR(status_);
if (all_devices.empty()) {
LOG(FATAL) << "No OpenCL GPU device found!";
LOG(ERROR) << "No available OpenCL GPU device found!";
return false;
}
device_ = std::make_shared<cl::Device>();
......@@ -313,9 +326,6 @@ bool CLRuntime::InitializeDevice() {
}
std::map<std::string, size_t>& CLRuntime::GetDeviceInfo() {
if (0 != device_info_.size()) {
return device_info_;
}
InitializeDevice();
return device_info_;
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/backends/opencl/cl_utility.h"
#include "lite/backends/opencl/cl_wrapper.h"
typedef enum {
UNKNOWN = 0,
......@@ -68,6 +69,28 @@ class CLRuntime {
public:
static CLRuntime* Global();
bool OpenCLAvaliableForDevice() {
bool opencl_lib_found = paddle::lite::CLWrapper::Global()->OpenclLibFound();
LOG(INFO) << "opencl_lib_found:" << opencl_lib_found;
if (opencl_lib_found == false) return false;
bool dlsym_success = paddle::lite::CLWrapper::Global()->DlsymSuccess();
LOG(INFO) << "dlsym_success:" << dlsym_success;
if (dlsym_success == false) return false;
InitializeDevice();
bool support_fp16 =
static_cast<bool>(device_info_["CL_DEVICE_EXTENSIONS_FP16"]);
LOG(INFO) << "support_fp16:" << support_fp16;
if (support_fp16 == false) return false;
is_device_avaliable_for_opencl_ =
dlsym_success && opencl_lib_found && support_fp16;
LOG(INFO) << "is_device_avaliable_for_opencl_:"
<< is_device_avaliable_for_opencl_;
return is_device_avaliable_for_opencl_;
}
bool Init();
cl::Platform& platform();
......@@ -85,7 +108,7 @@ class CLRuntime {
bool BuildProgram(cl::Program* program, const std::string& options = "");
bool IsInitSuccess() { return is_init_success_; }
bool IsInitSuccess() { return is_platform_device_init_success_; }
std::string cl_path() { return cl_path_; }
......@@ -167,9 +190,11 @@ class CLRuntime {
cl_int status_{CL_SUCCESS};
bool initialized_{false};
bool is_device_avaliable_for_opencl_{false};
bool is_cl_runtime_initialized_{false};
bool is_init_success_{false};
bool is_platform_device_init_success_{false};
};
} // namespace lite
......
......@@ -19,14 +19,16 @@ limitations under the License. */
namespace paddle {
namespace lite {
CLWrapper *CLWrapper::Global() {
static CLWrapper wrapper;
return &wrapper;
}
CLWrapper::CLWrapper() {
CHECK(InitHandle()) << "Fail to initialize the OpenCL library!";
InitFunctions();
opencl_lib_found_ = InitHandle();
CHECK(opencl_lib_found_) << "Fail to initialize the OpenCL library!";
dlsym_success_ = InitFunctions();
}
bool CLWrapper::InitHandle() {
......@@ -68,15 +70,17 @@ bool CLWrapper::InitHandle() {
}
}
void CLWrapper::InitFunctions() {
bool CLWrapper::InitFunctions() {
CHECK(handle_ != nullptr) << "The library handle can't be null!";
bool dlsym_success = true;
#define PADDLE_DLSYM(cl_func) \
do { \
cl_func##_ = (cl_func##Type)dlsym(handle_, #cl_func); \
if (cl_func##_ == nullptr) { \
LOG(FATAL) << "Cannot find the " << #cl_func \
LOG(ERROR) << "Cannot find the " << #cl_func \
<< " symbol in libOpenCL.so!"; \
dlsym_success = false; \
break; \
} \
VLOG(4) << "Loaded the " << #cl_func << " symbol successfully."; \
......@@ -137,6 +141,7 @@ void CLWrapper::InitFunctions() {
PADDLE_DLSYM(clEnqueueCopyImage);
#undef PADDLE_DLSYM
return dlsym_success;
}
} // namespace lite
......
......@@ -508,13 +508,20 @@ class CLWrapper final {
return clEnqueueCopyImage_;
}
bool OpenclLibFound() { return opencl_lib_found_; }
bool DlsymSuccess() { return dlsym_success_; }
private:
CLWrapper();
CLWrapper(const CLWrapper &) = delete;
CLWrapper &operator=(const CLWrapper &) = delete;
bool InitHandle();
void InitFunctions();
bool InitFunctions();
bool opencl_lib_found_{true};
bool dlsym_success_{true};
void *handle_{nullptr};
clGetPlatformIDsType clGetPlatformIDs_{nullptr};
clGetPlatformInfoType clGetPlatformInfo_{nullptr};
clBuildProgramType clBuildProgram_{nullptr};
......
......@@ -175,7 +175,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
op_desc.SetAttr("enable_int8", true);
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true);
}
op_desc.SetInputScale(weight_name, weight_scale);
// change the weight from the float type to int8 type.
......@@ -280,6 +284,7 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
}
// Arm CPU does not support conv2d_transpose
if (quantized_op_type_ != "conv2d_transpose") {
op_desc.SetAttr("enable_int8", true);
}
......
......@@ -78,6 +78,28 @@ void RunModel(std::string model_dir,
// 1. Set MobileConfig
MobileConfig config;
config.set_model_from_file(model_dir);
// NOTE: Use android gpu with opencl, you should ensure:
// first, [compile **cpu+opencl** paddlelite
// lib](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/demo_guides/opencl.md);
// second, [convert and use opencl nb
// model](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/user_guides/opt/opt_bin.md).
//
/* Uncomment code below to enable OpenCL
bool is_opencl_backend_valid = ::IsOpenCLBackendValid();
std::cout << "is_opencl_backend_valid:" << is_opencl_backend_valid <<
std::endl;
if (is_opencl_backend_valid) {
// give opencl nb model dir
config.set_model_from_file(model_dir);
} else {
std::cout << "Unsupport opencl nb model." << std::endl;
exit(1);
// you can give backup cpu nb model instead
// config.set_model_from_file(cpu_nb_model_dir);
}
*/
// NOTE: To load model transformed by model_optimize_tool before
// release/v2.3.0, please use `set_model_dir` API as listed below.
// config.set_model_dir(model_dir);
......
......@@ -15,6 +15,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps})
add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sigmoid_compute_cuda CUDA basic SRCS sigmoid_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS ${lite_kernel_deps})
......@@ -61,6 +62,7 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda)
nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(sigmoid_compute_cuda_test SRCS sigmoid_compute_test.cc DEPS sigmoid_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
......
......@@ -69,7 +69,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
std::vector<int> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1;
for (int didx = 0; didx < input[i]->dims().size(); ++didx) {
for (size_t didx = 0; didx < input[i]->dims().size(); ++didx) {
input_i_numel *= input[i]->dims()[didx];
}
int t_cols = input_i_numel / rows;
......
......@@ -48,10 +48,69 @@ struct GRUUnitFunctor {
CUDAContext* context) {
dim3 threads, grids;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
if (lite::TargetWrapperCuda::GetComputeCapability() >= 70) {
if (frame_size < 16) {
constexpr int tiled_size = 8;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruGate<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruOut<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
} else {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruGate<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruOut<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
}
return;
} else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
}
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
......@@ -121,6 +180,90 @@ struct GRUUnitFunctor {
template struct GRUUnitFunctor<float>;
template <>
struct GRUUnitFunctor<half> {
static void compute(GRUMetaValue<half> value,
int frame_size,
int batch_size,
const lite::cuda::math::ActivationType& active_node,
const lite::cuda::math::ActivationType& active_gate,
bool origin_mode,
lite::cuda::math::Gemm<half, half>* blas,
CUDAContext* context) {
dim3 threads, grids;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size * 2,
frame_size,
frame_size,
frame_size * 2,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.prev_out_value,
value.gate_weight,
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size,
frame_size,
frame_size,
frame_size,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.reset_output_value,
value.state_weight,
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
}
};
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
......@@ -141,18 +284,17 @@ void GRUCompute<T, PType>::Run() {
if (param.bias) {
bias = const_cast<lite::Tensor*>(param.bias);
}
auto* weight = param.weight;
auto* weight_data = const_cast<T*>(weight->template data<T>());
auto* batch_gate = param.batch_gate;
auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
auto* batch_hidden = param.batch_hidden;
auto* hidden = param.hidden;
auto* batch_reset_hidden_prev_data =
const lite::Tensor* weight = param.weight;
T* weight_data = const_cast<T*>(weight->template data<T>());
lite::Tensor* batch_gate = param.batch_gate;
lite::Tensor* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
lite::Tensor* batch_hidden = param.batch_hidden;
lite::Tensor* hidden = param.hidden;
T* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
auto* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
auto* batch_hidden_data =
batch_hidden->template mutable_data<T>(TARGET(kCUDA));
T* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
T* batch_hidden_data = batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
auto active_node = lite::cuda::math::GetActiveType(param.activation);
auto active_gate = lite::cuda::math::GetActiveType(param.gate_activation);
......@@ -224,6 +366,8 @@ void GRUCompute<T, PType>::Run() {
using GRUFp32 =
paddle::lite::kernels::cuda::GRUCompute<float, PRECISION(kFloat)>;
using GRUFp16 = paddle::lite::kernels::cuda::GRUCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA))})
......@@ -234,3 +378,20 @@ REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(gru, kCUDA, kFP16, kNCHW, GRUFp16, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Weight",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchGate",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchResetHiddenPrev",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchHidden",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Hidden",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
......@@ -45,10 +45,13 @@ class GRUTest : public ::testing::Test {
x_ref_.Resize(lite::DDim(x_shape_));
x_gpu_.Resize(lite::DDim(x_shape_));
x_ref_.set_lod(lod_);
w_ref_.Resize(lite::DDim(w_shape_));
w_gpu_.Resize(lite::DDim(w_shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
auto w_ref_data = w_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
......@@ -63,6 +66,7 @@ class GRUTest : public ::testing::Test {
batch_hidden_gpu_.Resize(lite::DDim(out_shape_));
batch_reset_hidden_gpu_.Resize(lite::DDim(out_shape_));
RunBaseLine();
InitParamAndContext();
}
......@@ -91,6 +95,22 @@ class GRUTest : public ::testing::Test {
w_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(lite::DDim(x_shape_));
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
x_gpu_.set_lod(x_ref_.lod());
w_half_.Resize(w_ref_.dims());
auto w_half_data = w_half_.mutable_data<half>();
for (int64_t i = 0; i < w_half_.numel(); i++) {
w_half_data[i] = half(lite::float16(w_ref_.data<float>()[i]));
}
w_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(w_half_data, w_gpu_.dims());
}
void RunBaseLine() {}
int batch_, frame_size_;
......@@ -134,6 +154,29 @@ TEST_F(GRUTest, TestFP32) {
<< duration / FLAGS_repeats << " ms in average.";
}
TEST_F(GRUTest, TestFP16) {
InitHalfInput();
GRUCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
}
} // namespace cuda
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sigmoid_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
void SigmoidCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->template data<T>();
auto output = param.Out->template mutable_data<T>(TARGET(kCUDA));
lite::cuda::math::sigmoid<T>(num, input, output, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SigmoidFp32 =
paddle::lite::kernels::cuda::SigmoidCompute<float, PRECISION(kFloat)>;
using SigmoidFp16 =
paddle::lite::kernels::cuda::SigmoidCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFloat, kNCHW, SigmoidFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFP16, kNCHW, SigmoidFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SigmoidCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~SigmoidCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sigmoid_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SigmoidTest : public ::testing::Test {
protected:
SigmoidTest() : m_(8), n_(64), shape_({m_, n_}) {
x_ref_.Resize(lite::DDim(shape_));
x_gpu_.Resize(lite::DDim(shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
out_ref_.Resize(lite::DDim(shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.X = &x_gpu_;
param_.Out = &out_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(lite::DDim(shape_));
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
}
void RunBaseLine() {
for (int64_t i = 0; i < x_ref_.numel(); ++i) {
out_ref_.mutable_data<float>()[i] =
1.f / (1.f + expf(-1 * x_ref_.data<float>()[i]));
}
}
int m_, n_;
std::vector<int64_t> shape_;
lite::Tensor x_ref_, out_ref_;
lite::Tensor x_gpu_;
lite::Tensor x_half_;
lite::Tensor out_cpu_, out_gpu_;
operators::ActivationParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(SigmoidTest, TestFP32) {
InitFloatInput();
SigmoidCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
out_gpu_.data<float>(),
sizeof(float) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = out_cpu_.data<float>()[i];
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
}
}
TEST_F(SigmoidTest, TestFP16) {
InitHalfInput();
SigmoidCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
const half* out_gpu_data = out_gpu_.data<half>();
half* out_cpu_data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_gpu_data,
sizeof(half) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = static_cast<float>(lite::float16(out_cpu_data[i]));
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 2e-2);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -39,8 +39,8 @@ readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/t
readonly workspace=$PWD
# if operating in mac env, we should expand the maximum file num
os_nmae=`uname -s`
if [ ${os_nmae} == "Darwin" ]; then
os_name=`uname -s`
if [ ${os_name} == "Darwin" ]; then
ulimit -n 1024
fi
......
......@@ -21,8 +21,8 @@ USE_ADB_EMULATOR=ON
LITE_WITH_COVERAGE=OFF
# if operating in mac env, we should expand the maximum file num
os_nmae=`uname -s`
if [ ${os_nmae} == "Darwin" ]; then
os_name=`uname -s`
if [ ${os_name} == "Darwin" ]; then
ulimit -n 1024
fi
......