Commit 41f2b344 authored by: J jiweibo

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into stream_manage. test=develop
......@@ -94,12 +94,10 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
message(STATUS "SRC_FBS_DIR: ${SRC_FBS_DIR}")
string(REGEX REPLACE "\\.fbs$" "_generated.h" GEN_HEADER ${SRC_FBS})
add_custom_command(
OUTPUT ${GEN_HEADER}
OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/${GEN_HEADER}"
COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
--cpp --gen-mutable --gen-object-api --reflect-names
--cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs
${OPT}
-I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
-o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}"
DEPENDS flatbuffers
......
......@@ -37,14 +37,25 @@ rm ./lite/api/paddle_use_kernels.h
rm ./lite/api/paddle_use_ops.h
# Set the build options and start the build
# android-armv7:cpu+gpu+cv+extra
./lite/tools/build_android.sh \
--arch=armv7 \
--toolchain=clang \
--with_cv=OFF \
--with_log=OFF \
--with_extra=OFF \
--with_extra=ON \
--with_cv=ON \
--with_opencl=ON
# android-armv8:cpu+gpu+cv+extra
./lite/tools/build_android.sh \
--arch=armv8 \
--toolchain=clang \
--with_log=OFF \
--with_extra=ON \
--with_cv=ON \
--with_opencl=ON
# Note: for build help, run: ./lite/tools/build_android.sh help
```
......@@ -206,7 +217,7 @@ adb shell "export GLOG_v=4; \
## 3. How to use it in code
That is, the code under the build output directory `demo/cxx/mobile_light`; the online version is in the GitHub repo at [./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc);
That is, the code under the build output directory `demo/cxx/mobile_light`; the online version is in the GitHub repo at [./lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/lite/demo/cxx/mobile_light/mobilenetv1_light_api.cc), which also shows how to check whether the current device supports OpenCL;
Note: this link points to the latest develop branch online and may differ from your local code; it is recommended to refer to your local copy under `lite/demo/cxx/` to see how to use it.
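A minimal sketch of this flow in C++ (the function and model-path names below are placeholders; the real pre/post-processing is in the demo source linked above):

```cpp
#include <string>

#include "paddle_api.h"  // NOLINT

// Pick an OpenCL model when the GPU backend is usable, otherwise a CPU model.
void RunWithOpenCLFallback(const std::string& opencl_nb_model,
                           const std::string& cpu_nb_model) {
  paddle::lite_api::MobileConfig config;
  if (paddle::lite_api::IsOpenCLBackendValid()) {
    // .nb model converted by opt for the OpenCL backend.
    config.set_model_from_file(opencl_nb_model);
  } else {
    // Fall back to an ARM CPU .nb model.
    config.set_model_from_file(cpu_nb_model);
  }
  auto predictor = paddle::lite_api::CreatePaddlePredictor<
      paddle::lite_api::MobileConfig>(config);
  // Fill the input tensors and call predictor->Run(), as shown in
  // mobilenetv1_light_api.cc.
}
```

The demo itself follows the same pattern but exits when OpenCL is unavailable instead of falling back to a CPU model.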
......
......@@ -32,9 +32,22 @@
#include "lite/backends/mlu/target_wrapper.h"
#endif
#ifdef LITE_WITH_OPENCL
#include "lite/backends/opencl/cl_runtime.h"
#endif
namespace paddle {
namespace lite_api {
bool IsOpenCLBackendValid() {
bool opencl_valid = false;
#ifdef LITE_WITH_OPENCL
opencl_valid = paddle::lite::CLRuntime::Global()->OpenCLAvaliableForDevice();
#endif
LOG(INFO) << "opencl_valid:" << opencl_valid;
return opencl_valid;
}
Tensor::Tensor(void *raw) : raw_tensor_(raw) {}
// TODO(Superjomn) refine this by using another `const void* const_raw`;
......
......@@ -38,6 +38,9 @@ using lod_t = std::vector<std::vector<uint64_t>>;
enum class LiteModelType { kProtobuf = 0, kNaiveBuffer, UNK };
// return true if current device supports OpenCL model
LITE_API bool IsOpenCLBackendValid();
struct LITE_API Tensor {
explicit Tensor(void* raw);
explicit Tensor(const void* raw);
......
......@@ -139,6 +139,86 @@ static bool conv_trans_weights_numc(const dtype* din,
}
return true;
}
template <typename Dtype>
void transpose(const Dtype* din, Dtype* dout, int m, int n) {
// nxm == mxn
// 4x4
int cnt_n = n >> 2;
int remain_n = n & 3;
int cnt_m = m >> 2;
int remain_m = m & 3;
int nn_num = n << 2; // n * 4
int mm_num = m << 2; // m * 4
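  // Tiled 4x4 transpose: din is an n x m matrix and dout the m x n result,
  // i.e. dout[j * n + i] = din[i * m + j]. Each outer iteration transposes a
  // strip of four din rows with vtrnq_f32 plus lane swaps; the scalar loops
  // below handle the m % 4 and n % 4 tails.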
for (int x = 0; x < cnt_n; x++) {
const Dtype* din_ptr0 = din + x * mm_num;
const Dtype* din_ptr1 = din_ptr0 + m;
const Dtype* din_ptr2 = din_ptr1 + m;
const Dtype* din_ptr3 = din_ptr2 + m;
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
// a00 b00 a02 b02 a01 b01 a03 b03
float32x4x2_t tmp0 = vtrnq_f32(din0, din1);
// c00 d00 c02 d02 c01 d01 c03 d03
float32x4x2_t tmp2 = vtrnq_f32(din2, din3);
din_ptr0 += 4;
din_ptr1 += 4;
// a00 b00 c00 d00 a02 b02 c02 d02
// a01 b01 c01 d01 a03 b03 c03 d03
float tmp_val1 = tmp0.val[0][2];
float tmp_val2 = tmp0.val[0][3];
tmp0.val[0][2] = tmp2.val[0][0];
tmp0.val[0][3] = tmp2.val[0][1];
float tmp_val3 = tmp0.val[1][2];
float tmp_val4 = tmp0.val[1][3];
tmp2.val[0][0] = tmp_val1;
tmp2.val[0][1] = tmp_val2;
tmp0.val[1][2] = tmp2.val[1][0];
tmp0.val[1][3] = tmp2.val[1][1];
tmp2.val[1][0] = tmp_val3;
tmp2.val[1][1] = tmp_val4;
din_ptr2 += 4;
din_ptr3 += 4;
vst1q_f32(dout_ptr0, tmp0.val[0]);
vst1q_f32(dout_ptr1, tmp0.val[1]);
dout_ptr0 += nn_num;
vst1q_f32(dout_ptr2, tmp2.val[0]);
vst1q_f32(dout_ptr3, tmp2.val[1]);
}
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
*dout_ptr0++ = *din_ptr1++;
*dout_ptr0++ = *din_ptr2++;
*dout_ptr0++ = *din_ptr3++;
}
}
const Dtype* din_ptr0 = din + cnt_n * mm_num;
dout = dout + cnt_n * 4;
for (int x = 0; x < remain_n; x++) {
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
din_ptr0 += 4;
*dout_ptr0 = din0[0];
*dout_ptr1 = din0[1];
dout_ptr0 += nn_num;
*dout_ptr2 = din0[2];
*dout_ptr3 = din0[3];
}
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
}
}
}
/*preprocessing inputs
* input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws]
* n = he - hs
......
......@@ -2044,7 +2044,7 @@ void pooling3x3s1p0_avg(const float* din,
} else {
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom = 1) {
} else if (pad_bottom == 1) {
coef_h = 0.5f;
} else {
coef_h = 1.f;
......
......@@ -46,11 +46,60 @@ void seq_pool_sum<float>(const float* din,
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
height = height - 1;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; ++w) {
dout_ptr[w] += din_ptr[w];
int cnt_w = width >> 2;
int remain_w = width & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
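    // Sum the remaining rows of the sequence into dout: four output columns
    // are handled per outer iteration, and the inner loop accumulates four
    // rows at a time (stride == 4 * width advances a pointer by four rows).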
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vaddq_f32(din0, dout_val);
float32x4_t tmp = vaddq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vaddq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vaddq_f32(tmp, dout_val);
}
din_ptr += width;
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vaddq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
float tmp = din_ptr1[0] + din_ptr2[0];
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr += din_ptr3[0];
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr += tmp;
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr += din_ptr0[0];
din_ptr0 += width;
}
dout_ptr++;
}
}
}
......@@ -144,12 +193,62 @@ void seq_pool_max<float>(const float* din,
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
int remain_h = height - 1;
for (int h = 0; h < remain_h; h++) {
for (int w = 0; w < width; w++) {
dout_ptr[w] = std::max(dout_ptr[w], din_ptr[w]);
height = height - 1;
int cnt_w = width >> 2;
int remain_w = width & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
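    // Max-pool counterpart of the blocking used in seq_pool_sum above:
    // vmaxq_f32 replaces vaddq_f32 with the same 4-column by 4-row tiling.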
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vmaxq_f32(din0, dout_val);
float32x4_t tmp = vmaxq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vmaxq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vmaxq_f32(tmp, dout_val);
}
din_ptr += width;
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vmaxq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
        *dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
float tmp = std::max(din_ptr1[0], din_ptr2[0]);
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr = std::max(*dout_ptr, tmp);
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
din_ptr0 += width;
}
dout_ptr++;
}
}
}
......
......@@ -11,10 +11,13 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation ${cuda_static_deps})
nv_library(cuda_sequence2batch SRCS sequence2batch.cu DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
set (
math_cuda
......@@ -25,10 +28,13 @@ set (
cuda_transpose
cuda_elementwise
cudnn_pool
cuda_gru_forward
cuda_sequence2batch
cuda_gemm
cuda_batched_gemm
cuda_strided_gemm
cuda_sequence_padding
cuda_bias
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"
......@@ -21,6 +22,20 @@ namespace lite {
namespace cuda {
namespace math {
ActivationType GetActiveType(const std::string& act) {
if (act == "sigmoid") {
return kSigmoid;
} else if (act == "relu") {
return kReLU;
} else if (act == "tanh") {
return kTanh;
} else if (act == "identify") {
return kIdentity;
} else {
LOG(FATAL) << "not supported activation: " << act;
}
}
template <typename T>
__global__ void relu_kernel(const int num,
const float alpha,
......@@ -470,6 +485,76 @@ template void relu(int, const half*, half*, float, cudaStream_t);
template void bias_relu(
int, const float*, const float* bias, float*, float, cudaStream_t);
// ------------- sigmoid -------------
template <typename T>
__global__ void sigmoid_kernel(const int num, const T* in, T* out) {
CUDA_KERNEL_LOOP(i, num) {
#if __CUDA_ARCH__ >= 350
out[i] = static_cast<T>(1.0f) /
(static_cast<T>(1.0f) + expf(-1 * __ldg(in + i)));
#else
out[i] = static_cast<T>(1.0f) / (static_cast<T>(1.0f) + expf(-in[i]));
#endif
}
}
template <>
__global__ void sigmoid_kernel(const int num, const half* in, half* out) {
CUDA_KERNEL_LOOP(i, num) {
half tmp = __float2half(1.0f);
#if __CUDA_ARCH__ >= 530
out[i] = __hdiv(
tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.0f), __ldg(in + i)))));
#else
out[i] = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i]))));
#endif
}
}
template <>
__global__ void sigmoid_kernel(const int num, const half2* in, half2* out) {
CUDA_KERNEL_LOOP(i, num) {
half2 tmp = __floats2half2_rn(1.0f, 1.0f);
#if __CUDA_ARCH__ >= 530
out[i] = __h2div(tmp,
__hadd2(tmp,
h2exp(__hmul2(__floats2half2_rn(-1.0f, -1.0f),
__ldg(in + i)))));
#else
out[i].x = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].x))));
out[i].y = __float2half(1.0f / (1.0f + expf(-1 * __half2float(in[i].y))));
#endif
}
}
template <typename T>
void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream) {
sigmoid_kernel<T><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
num, din, dout);
CUDA_POST_KERNEL_CHECK;
}
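// For fp16, pack pairs of values into half2 and process two elements per
// thread when the element count is even; odd counts use the element-wise
// half kernel.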
template <>
void sigmoid(const int num, const half* din, half* dout, cudaStream_t stream) {
if (num % 2 == 0) {
const half2* din2 = reinterpret_cast<const half2*>(din);
half2* dout2 = reinterpret_cast<half2*>(dout);
sigmoid_kernel<
half2><<<CUDA_GET_BLOCKS(num / 2), CUDA_NUM_THREADS, 0, stream>>>(
num / 2, din2, dout2);
} else {
sigmoid_kernel<half><<<CUDA_GET_BLOCKS(num), CUDA_NUM_THREADS, 0, stream>>>(
num, din, dout);
}
CUDA_POST_KERNEL_CHECK;
}
template void sigmoid(const int num,
const float* din,
float* dout,
cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
......@@ -17,11 +17,22 @@
#include <cuda_runtime.h>
#include <string>
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
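// Maps an activation name string (e.g. "sigmoid", "relu", "tanh") to the enum
// above; unrecognized names abort via LOG(FATAL).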
ActivationType GetActiveType(const std::string& act);
// fp32 and half
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
......@@ -72,6 +83,9 @@ void bias_int8_nhwc(int num,
const void* scale,
cudaStream_t stream);
template <typename T>
void sigmoid(const int num, const T* din, T* dout, cudaStream_t stream);
} // namespace math
} // namespace cuda
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/bias.h"
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void RowwiseAddKernel(
const T* a, const T* b, T* c, int width, int num) {
CUDA_KERNEL_LOOP(i, num) {
int h = i / width;
int w = i - h * width;
c[i] = a[i] + b[w];
}
}
template <>
__global__ void RowwiseAddKernel(
const half* a, const half* b, half* c, int width, int num) {
CUDA_KERNEL_LOOP(i, num) {
int h = i / width;
int w = i - h * width;
c[i] = __hadd(a[i], b[w]);
}
}
template <typename T>
void RowwiseAdd<T>::operator()(const T* input,
const T* bias,
T* output,
const int width,
const int count,
const cudaStream_t& stream) {
RowwiseAddKernel<T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
input, bias, output, width, count);
CUDA_POST_KERNEL_CHECK;
}
template struct RowwiseAdd<float>;
template struct RowwiseAdd<half>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
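// Adds a bias vector b of length `width` to every row of a row-major matrix a
// with `count` elements in total (count / width rows): c[i] = a[i] + b[i % width].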
template <typename T>
struct RowwiseAdd {
void operator()(const T* input,
const T* bias,
T* output,
const int width,
const int count,
const cudaStream_t& stream);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/math/gru_forward.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
T* reset_output_value,
T* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size;
}
T prev_out = 0;
T reset_out_val;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == lite::cuda::math::ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
} else if (active_gate == lite::cuda::math::ActivationType::kReLU) {
update_gate_value = ReLU(update_gate_value);
reset_gate_value = ReLU(reset_gate_value);
} else if (active_gate == lite::cuda::math::ActivationType::kTanh) {
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
reset_out_val = prev_out * reset_gate_value;
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
template <>
__global__ void GruForwardResetOutput(
half* gate_value,
half* reset_output_value,
half* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size;
}
half prev_out = 0;
half reset_out_val;
half update_gate_value = gate_value[frame_idx + frame_size * 0];
half reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
} else if (active_gate == ActivationType::kReLU) {
update_gate_value = ReLU(update_gate_value);
reset_gate_value = ReLU(reset_gate_value);
} else if (active_gate == ActivationType::kTanh) {
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
#if __CUDA_ARCH__ >= 530
reset_out_val = __hmul(prev_out, reset_gate_value);
#else
reset_out_val =
__float2half(__half2float(prev_out) * __half2float(reset_gate_value));
#endif
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
/*
* threads(frame_per_block, batch_per_block)
* grid(frame_blocks, batch_blocks)
*/
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
T* prev_output_value,
T* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) {
return;
}
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
T output;
T prev_out = 0;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
state_frame_value = ReLU(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
output = update_gate_value * prev_out + state_frame_value -
update_gate_value * state_frame_value;
} else {
output = prev_out - update_gate_value * prev_out +
update_gate_value * state_frame_value;
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
template <>
__global__ void GruForwardFinalOutput(
half* gate_value,
half* prev_output_value,
half* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) {
return;
}
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
half output;
half prev_out = 0;
half update_gate_value = gate_value[frame_idx + frame_size * 0];
half state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
state_frame_value = ReLU(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
#if __CUDA_ARCH__ >= 530
output =
__hsub(__hadd(__hmul(update_gate_value, prev_out), state_frame_value),
__hmul(update_gate_value, state_frame_value));
#else
output = __float2half(
__half2float(update_gate_value) * __half2float(prev_out) +
__half2float(state_frame_value) -
__half2float(update_gate_value) * __half2float(state_frame_value));
#endif
} else {
#if __CUDA_ARCH__ >= 530
output = prev_out - update_gate_value * prev_out +
update_gate_value * state_frame_value;
output = __hadd(__hsub(prev_out, __hmul(update_gate_value, prev_out)),
__hmul(update_gate_value, state_frame_value));
#else
output = __float2half(
__half2float(prev_out) -
__half2float(update_gate_value) * __half2float(prev_out) +
__half2float(update_gate_value) * __half2float(state_frame_value));
#endif
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
template __global__ void GruForwardFinalOutput<float>(
float* gate_value,
float* prev_output_value,
float* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
template __global__ void GruForwardResetOutput<float>(
float* gate_value,
float* reset_output_value,
float* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}
template <>
inline __device__ half Sigmoid(const half a) {
#if __CUDA_ARCH__ >= 530
const half tmp = __float2half(1.0f);
return __hdiv(tmp, __hadd(tmp, hexp(__hmul(__float2half(-1.f), a))));
#else
return __float2half(1.0f / (expf(__half2float(a) * -1) + 1.0f));
#endif
}
template <typename Dtype>
inline __device__ Dtype ReLU(const Dtype a) {
return a > static_cast<Dtype>(0.f) ? a : static_cast<Dtype>(0.f);
}
template <>
inline __device__ half ReLU(const half a) {
const half tmp = __float2half(0.f);
#if __CUDA_ARCH__ >= 530
return __hgt(a, tmp) ? a : tmp;
#else
return __float2half(__half2float(a) > 0.f ? __half2float(a) : 0.f);
#endif
}
template <typename Dtype>
inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
static_cast<Dtype>(1.0);
}
template <>
inline __device__ half Tanh(const half a) {
#if __CUDA_ARCH__ >= 530
half tmp = __float2half(1.0f);
half numerator = __hmul(__float2half(-2.0f), a);
return __hsub(__hdiv(__float2half(2.0f), __hadd(tmp, hexp(numerator))), tmp);
#else
float tmp = -2.0f * __half2float(a);
return __float2half(2.0f / (1.0f + expf(tmp)) - 1.0f);
#endif
}
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
T* reset_output_value,
T* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
T* prev_output_value,
T* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
/*
* threads(tile_size, 1)
* grids(frame_blocks, 1)
*/
template <class T, int TiledSize>
__global__ void FastCollectiveGruGate(T* gate_value,
T* prev_output_value,
T* gate_weight,
T* reset_output,
int frame_size,
ActivationType active_node) {
T xt_0 = 0.0f;
T a0 = 0.0f;
T c0 = 0.0f;
T b0[TiledSize];
int col = blockIdx.x * blockDim.x + threadIdx.x;
int tiled_mask = ((1 << TiledSize) - 1);
  // Tiled matrix multiply kept in registers and exchanged with warp shuffles
  // (__shfl / __shfl_sync), which is faster than staging tiles in shared memory.
if (prev_output_value) {
for (int k = 0; k < (((frame_size - 1) / TiledSize) + 1); ++k) {
a0 = 0;
if ((threadIdx.x + k * TiledSize) < frame_size) {
a0 = prev_output_value[threadIdx.x + (k * TiledSize)];
}
for (int i = 0; i < TiledSize; ++i) {
if (col < frame_size * 2 && (i + k * TiledSize) < frame_size) {
b0[i] = gate_weight[(i + k * TiledSize) * frame_size * 2 + col];
}
}
for (int i = 0; i < TiledSize; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0 = c0 + __shfl_sync(tiled_mask, a0, i, TiledSize) * b0[i];
#else
c0 = c0 + __shfl(a0, i, TiledSize) * b0[i];
#endif
}
}
}
__syncthreads();
if (col < frame_size * 2) {
xt_0 = gate_value[col];
c0 += xt_0;
if (active_node == ActivationType::kSigmoid) {
c0 = Sigmoid(c0);
} else if (active_node == ActivationType::kReLU) {
c0 = ReLU(c0);
} else if (active_node == ActivationType::kTanh) {
c0 = Tanh(c0);
}
gate_value[col] = c0;
if (frame_size <= col && col < frame_size * 2) {
T htp_0 = 0.0;
if (prev_output_value) {
htp_0 = prev_output_value[col - frame_size];
}
reset_output[col - frame_size] = c0 * htp_0;
} else if (col < frame_size) {
gate_value[col] = c0;
}
}
}
template <class T, int TiledSize>
__global__ void FastCollectiveGruOut(T* gate_weight,
T* prev_out_value,
T* output_value,
T* gate_value,
T* reset_value,
int frame_size,
ActivationType active_node,
bool origin_mode) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
T a0 = 0.0f;
T b0[TiledSize];
T c0 = 0.0f;
int tiled_mask = ((1 << TiledSize) - 1);
if (prev_out_value) {
for (int k = 0; k < ((frame_size - 1) / TiledSize + 1); ++k) {
a0 = 0;
if ((threadIdx.x + k * TiledSize) < frame_size) {
a0 = reset_value[threadIdx.x + k * TiledSize];
}
for (int i = 0; i < TiledSize; ++i) {
if (col < frame_size && (i + k * TiledSize) < frame_size) {
b0[i] = gate_weight[(i + k * TiledSize) * frame_size + col];
}
}
for (int i = 0; i < TiledSize; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
c0 = c0 + __shfl_sync(tiled_mask, a0, i, TiledSize) * b0[i];
#else
c0 = c0 + __shfl(a0, i, TiledSize) * b0[i];
#endif
}
}
}
__syncthreads();
if (col < frame_size) {
T xt_0 = gate_value[col + 2 * frame_size];
T gta_0 = gate_value[col];
T htp_0 = 0;
if (prev_out_value) {
htp_0 = prev_out_value[col];
}
c0 += xt_0;
if (active_node == ActivationType::kSigmoid) {
c0 = Sigmoid(c0);
} else if (active_node == ActivationType::kReLU) {
c0 = ReLU(c0);
} else if (active_node == ActivationType::kTanh) {
c0 = Tanh(c0);
}
gate_value[col + 2 * frame_size] = c0;
if (origin_mode) {
output_value[col] = htp_0 * gta_0 + (1 - gta_0) * c0;
} else {
output_value[col] = c0 * gta_0 + (1 - gta_0) * htp_0;
}
}
}
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void CopyMatrixRowsKernel(const T* src,
T* dst,
const uint64_t* index,
int height,
int width,
bool is_src_index) {
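  // One row is copied per (blockIdx.x, threadIdx.y) pair; threads along x
  // stride across the row's `width` elements.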
int idx = threadIdx.x;
int idy = threadIdx.y;
  int row_id = blockIdx.x * blockDim.y + idy;
if (row_id < height) {
int src_idx = is_src_index ? index[row_id] : row_id;
int dst_idx = is_src_index ? row_id : index[row_id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += blockDim.x) {
dst_data[i] = src_data[i];
}
}
}
template <typename T>
void CopyMatrixRowsFunctor<T>::operator()(
const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
bool is_src_index,
const cudaStream_t& stream) {
auto src_dims = src.dims();
auto dst_dims = dst->dims();
CHECK_EQ(src_dims.size(), 2) << "The src must be matrix with rank 2.";
CHECK_EQ(dst_dims.size(), 2) << "The dst must be matrix with rank 2.";
CHECK_EQ(src_dims[1], dst_dims[1])
<< "The width of src and dst must be same.";
int height = dst_dims[0];
int width = dst_dims[1];
const auto* src_data = src.data<T>();
auto* dst_data = dst->template mutable_data<T>(TARGET(kCUDA));
index_tensor_.Resize({static_cast<int64_t>(index_lod.size())});
auto* index_tensor_data = index_tensor_.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(index_tensor_data,
index_lod.data(),
sizeof(uint64_t) * index_lod.size(),
IoDirection::HtoD,
stream);
dim3 threads(128, 8);
dim3 grids((height + threads.y - 1) / threads.y);
CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
      src_data, dst_data, index_tensor_data, height, width, is_src_index);
CUDA_POST_KERNEL_CHECK;
}
template class CopyMatrixRowsFunctor<float>;
template class CopyMatrixRowsFunctor<half>;
template class LoDTensor2BatchFunctor<float>;
template class LoDTensor2BatchFunctor<half>;
template class Batch2LoDTensorFunctor<float>;
template class Batch2LoDTensorFunctor<half>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <string>
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor {
public:
  // If is_src_index is true, copy the indexed rows of the input src to the
  // output dst. If is_src_index is false, copy the input src to the indexed
  // rows of the output dst. The indexed rows are given by the input index.
void operator()(const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
bool is_src_index,
const cudaStream_t& stream);
private:
lite::Tensor index_tensor_;
};
template <typename T>
class LoDTensor2BatchFunctor {
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct SeqInfo {
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start_(start), length_(length), seq_idx_(seq_idx) {}
size_t start_;
size_t length_;
size_t seq_idx_;
};
public:
void operator()(const lite::Tensor& lod_tensor,
lite::Tensor* batch_tensor,
bool is_reverse,
const cudaStream_t& stream) const {
auto lods = lod_tensor.lod();
CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now.";
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (int seq_id = 0; seq_id < static_cast<int>(lod.size()) - 1; ++seq_id) {
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length_ > b.length_;
});
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// max_seqlen = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = 0
// batch_start_positions[1] = len(b0)
// batch_start_positions[2] = len(b0) + len(b1)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
LoD batch_lods;
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
size_t max_seqlen = seq_info[0].length_;
batch_lods[0].resize(max_seqlen + 1);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods[2].resize(seq_info.size());
auto* batch_starts = batch_lods[0].data();
auto* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < max_seqlen; ++n) {
size_t batch_id = batch_starts[n];
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length_;
size_t start = seq_info[i].start_;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
++batch_id;
} else {
break;
}
}
batch_starts[n + 1] = batch_id;
}
auto* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx_;
}
batch_tensor->set_lod(batch_lods);
lite::cuda::math::CopyMatrixRowsFunctor<T> to_batch;
to_batch(lod_tensor, batch_tensor, batch_lods[1], true, stream);
CUDA_POST_KERNEL_CHECK;
}
};
template <typename T>
class Batch2LoDTensorFunctor {
public:
void operator()(const lite::Tensor& batch_tensor,
lite::Tensor* lod_tensor,
const cudaStream_t& stream) {
auto in_lod = batch_tensor.lod();
CHECK_GT(in_lod.size(), 2UL) << "The LoD of LoDTensor should include at "
"least 2-level sequence infomation.";
CHECK_EQ(in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]))
<< "The LoD information should be consistent with the dims.";
lite::cuda::math::CopyMatrixRowsFunctor<T> to_seq;
to_seq(batch_tensor, lod_tensor, in_lod[1], false, stream);
CUDA_POST_KERNEL_CHECK;
}
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -32,6 +32,16 @@ class TargetWrapper<TARGET(kCUDA)> {
static size_t num_devices();
static size_t maximum_stream() { return 0; }
static int GetComputeCapability() {
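    // Returns major * 10 + minor, e.g. 70 for sm_70 (Volta).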
int dev_id = GetCurDevice();
int major, minor;
CUDA_CALL(cudaDeviceGetAttribute(
&major, cudaDevAttrComputeCapabilityMajor, dev_id));
CUDA_CALL(cudaDeviceGetAttribute(
&minor, cudaDevAttrComputeCapabilityMinor, dev_id));
return major * 10 + minor;
}
static int GetCurDevice() {
int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id));
......
......@@ -38,17 +38,20 @@ CLRuntime::~CLRuntime() {
}
bool CLRuntime::Init() {
if (initialized_) {
if (is_cl_runtime_initialized_) {
return true;
}
bool is_platform_init = InitializePlatform();
bool is_device_init = InitializeDevice();
is_init_success_ = is_platform_init && is_device_init;
initialized_ = true;
context_ = CreateContext();
command_queue_ = CreateCommandQueue(context());
return initialized_;
LOG(INFO) << "is_platform_init:" << is_platform_init;
LOG(INFO) << "is_device_init:" << is_device_init;
if ((is_platform_init == true) && (is_device_init == true)) {
is_platform_device_init_success_ = true;
context_ = CreateContext();
command_queue_ = CreateCommandQueue(context());
is_cl_runtime_initialized_ = true;
}
return is_cl_runtime_initialized_;
}
cl::Platform& CLRuntime::platform() {
......@@ -64,7 +67,9 @@ cl::Context& CLRuntime::context() {
}
cl::Device& CLRuntime::device() {
CHECK(device_ != nullptr) << "device_ is not initialized!";
if (device_ == nullptr) {
LOG(ERROR) << "device_ is not initialized!";
}
return *device_;
}
......@@ -150,6 +155,14 @@ GpuType CLRuntime::ParseGpuTypeFromDeviceName(std::string device_name) {
}
bool CLRuntime::InitializeDevice() {
VLOG(3) << "device_info_.size():" << device_info_.size();
for (auto i : device_info_) {
VLOG(3) << ">>> " << i.first << " " << i.second;
}
if (device_info_.size() > 0 && device_info_.size() <= 2) {
return false;
}
device_info_["PLACEHOLDER"] = 1;
// ===================== BASIC =====================
// CL_DEVICE_TYPE_GPU
// CL_DEVICE_NAME
......@@ -160,7 +173,7 @@ bool CLRuntime::InitializeDevice() {
status_ = platform_->getDevices(CL_DEVICE_TYPE_GPU, &all_devices);
CL_CHECK_ERROR(status_);
if (all_devices.empty()) {
LOG(FATAL) << "No OpenCL GPU device found!";
LOG(ERROR) << "No available OpenCL GPU device found!";
return false;
}
device_ = std::make_shared<cl::Device>();
......@@ -313,9 +326,6 @@ bool CLRuntime::InitializeDevice() {
}
std::map<std::string, size_t>& CLRuntime::GetDeviceInfo() {
if (0 != device_info_.size()) {
return device_info_;
}
InitializeDevice();
return device_info_;
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "lite/backends/opencl/cl_include.h"
#include "lite/backends/opencl/cl_utility.h"
#include "lite/backends/opencl/cl_wrapper.h"
typedef enum {
UNKNOWN = 0,
......@@ -68,6 +69,28 @@ class CLRuntime {
public:
static CLRuntime* Global();
bool OpenCLAvaliableForDevice() {
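    // The device is considered usable for OpenCL only if libOpenCL.so was
    // found, all required symbols were resolved via dlsym, and the device
    // reports fp16 extension support.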
bool opencl_lib_found = paddle::lite::CLWrapper::Global()->OpenclLibFound();
LOG(INFO) << "opencl_lib_found:" << opencl_lib_found;
if (opencl_lib_found == false) return false;
bool dlsym_success = paddle::lite::CLWrapper::Global()->DlsymSuccess();
LOG(INFO) << "dlsym_success:" << dlsym_success;
    if (dlsym_success == false) return false;
InitializeDevice();
bool support_fp16 =
static_cast<bool>(device_info_["CL_DEVICE_EXTENSIONS_FP16"]);
LOG(INFO) << "support_fp16:" << support_fp16;
if (support_fp16 == false) return false;
is_device_avaliable_for_opencl_ =
dlsym_success && opencl_lib_found && support_fp16;
LOG(INFO) << "is_device_avaliable_for_opencl_:"
<< is_device_avaliable_for_opencl_;
return is_device_avaliable_for_opencl_;
}
bool Init();
cl::Platform& platform();
......@@ -85,7 +108,7 @@ class CLRuntime {
bool BuildProgram(cl::Program* program, const std::string& options = "");
bool IsInitSuccess() { return is_init_success_; }
bool IsInitSuccess() { return is_platform_device_init_success_; }
std::string cl_path() { return cl_path_; }
......@@ -167,9 +190,11 @@ class CLRuntime {
cl_int status_{CL_SUCCESS};
bool initialized_{false};
bool is_device_avaliable_for_opencl_{false};
bool is_cl_runtime_initialized_{false};
bool is_init_success_{false};
bool is_platform_device_init_success_{false};
};
} // namespace lite
......
......@@ -19,14 +19,16 @@ limitations under the License. */
namespace paddle {
namespace lite {
CLWrapper *CLWrapper::Global() {
static CLWrapper wrapper;
return &wrapper;
}
CLWrapper::CLWrapper() {
CHECK(InitHandle()) << "Fail to initialize the OpenCL library!";
InitFunctions();
opencl_lib_found_ = InitHandle();
CHECK(opencl_lib_found_) << "Fail to initialize the OpenCL library!";
dlsym_success_ = InitFunctions();
}
bool CLWrapper::InitHandle() {
......@@ -68,15 +70,17 @@ bool CLWrapper::InitHandle() {
}
}
void CLWrapper::InitFunctions() {
bool CLWrapper::InitFunctions() {
CHECK(handle_ != nullptr) << "The library handle can't be null!";
bool dlsym_success = true;
#define PADDLE_DLSYM(cl_func) \
do { \
cl_func##_ = (cl_func##Type)dlsym(handle_, #cl_func); \
if (cl_func##_ == nullptr) { \
LOG(FATAL) << "Cannot find the " << #cl_func \
LOG(ERROR) << "Cannot find the " << #cl_func \
<< " symbol in libOpenCL.so!"; \
dlsym_success = false; \
break; \
} \
VLOG(4) << "Loaded the " << #cl_func << " symbol successfully."; \
......@@ -137,6 +141,7 @@ void CLWrapper::InitFunctions() {
PADDLE_DLSYM(clEnqueueCopyImage);
#undef PADDLE_DLSYM
return dlsym_success;
}
} // namespace lite
......
......@@ -508,13 +508,20 @@ class CLWrapper final {
return clEnqueueCopyImage_;
}
bool OpenclLibFound() { return opencl_lib_found_; }
bool DlsymSuccess() { return dlsym_success_; }
private:
CLWrapper();
CLWrapper(const CLWrapper &) = delete;
CLWrapper &operator=(const CLWrapper &) = delete;
bool InitHandle();
void InitFunctions();
bool InitFunctions();
bool opencl_lib_found_{true};
bool dlsym_success_{true};
void *handle_{nullptr};
clGetPlatformIDsType clGetPlatformIDs_{nullptr};
clGetPlatformInfoType clGetPlatformInfo_{nullptr};
clBuildProgramType clBuildProgram_{nullptr};
......
......@@ -192,7 +192,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else if (is_weight_quantization) {
std::string scale_name = conv_weight_name + "_quant_scale";
if (conv_op_desc->HasAttr(scale_name)) {
auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name);
std::vector<float> scale =
conv_op_desc->GetAttr<std::vector<float>>(scale_name);
CHECK_EQ(scale.size(), alpha_tensor.numel());
for (size_t i = 0; i < scale.size(); i++) {
scale[i] *= alpha_data[i];
......
......@@ -84,11 +84,12 @@ cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc(
op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("axis",
matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.back());
*(matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.end() -
1));
return op_desc;
}
......
......@@ -62,15 +62,17 @@ std::string Visualize(mir::SSAGraph* graph) {
<< string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
break;
case AttrType::FLOATS: {
auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
std::vector<float> vals =
op_info->GetAttr<std::vector<float>>(attr_name);
os << ":floats: {" + Join(vals, ",") << "}";
} break;
case AttrType::INTS: {
auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
std::vector<int> vals = op_info->GetAttr<std::vector<int>>(attr_name);
os << ":ints: {" + Join(vals, ",") + "}";
} break;
case AttrType::STRINGS: {
auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
std::vector<std::string> vals =
op_info->GetAttr<std::vector<std::string>>(attr_name);
os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
} break;
default:
......
......@@ -204,7 +204,7 @@ void Program::Build(const cpp::ProgramDesc& prog) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
auto program = prog;
auto& program = prog;
CHECK(program.BlocksSize());
auto& main_block = *program.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block.OpsSize(); ++i) {
......@@ -272,7 +272,7 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
}
};
auto program = prog;
auto& program = prog;
CHECK(program.BlocksSize());
for (size_t b = 0; b < program.BlocksSize(); ++b) {
auto& main_block = *program.GetBlock<cpp::BlockDesc>(b);
......
......@@ -46,7 +46,8 @@ struct Program {
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {})
: scope_(root), valid_places_(valid_places), desc_(desc) {
: scope_(root), valid_places_(valid_places) {
desc_.CopyFrom(desc);
CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work";
PrepareWorkspace(desc, var_names);
......
......@@ -78,6 +78,28 @@ void RunModel(std::string model_dir,
// 1. Set MobileConfig
MobileConfig config;
config.set_model_from_file(model_dir);
// NOTE: Use android gpu with opencl, you should ensure:
// first, [compile **cpu+opencl** paddlelite
// lib](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/demo_guides/opencl.md);
// second, [convert and use opencl nb
// model](https://github.com/PaddlePaddle/Paddle-Lite/blob/develop/docs/user_guides/opt/opt_bin.md).
//
/* Uncomment code below to enable OpenCL
bool is_opencl_backend_valid = ::IsOpenCLBackendValid();
std::cout << "is_opencl_backend_valid:" << is_opencl_backend_valid <<
std::endl;
if (is_opencl_backend_valid) {
// give opencl nb model dir
config.set_model_from_file(model_dir);
} else {
std::cout << "Unsupport opencl nb model." << std::endl;
exit(1);
// you can give backup cpu nb model instead
// config.set_model_from_file(cpu_nb_model_dir);
}
*/
// NOTE: To load model transformed by model_optimize_tool before
  // release/v2.3.0, please use `set_model_dir` API as listed below.
// config.set_model_dir(model_dir);
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cstddef>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
......
......@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
# basic kernels
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(gru_compute_cuda CUDA basic SRCS gru_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(matmul_compute_cuda CUDA basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
......@@ -14,6 +15,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps})
add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sigmoid_compute_cuda CUDA basic SRCS sigmoid_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS ${lite_kernel_deps})
......@@ -60,6 +62,7 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda)
nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(sigmoid_compute_cuda_test SRCS sigmoid_compute_test.cc DEPS sigmoid_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
......@@ -69,6 +72,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
nv_test(gru_compute_cuda_test SRCS gru_compute_test.cc DEPS gru_compute_cuda)
nv_test(matmul_compute_cuda_test SRCS matmul_compute_test.cc DEPS matmul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
......
......@@ -69,7 +69,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
std::vector<int> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1;
for (int didx = 0; didx < input[i]->dims().size(); ++didx) {
for (size_t didx = 0; didx < input[i]->dims().size(); ++didx) {
input_i_numel *= input[i]->dims()[didx];
}
int t_cols = input_i_numel / rows;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/bias.h"
#include "lite/backends/cuda/math/gru_forward.h"
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/gru_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
struct GRUMetaValue {
T* gate_weight;
T* state_weight;
T* gate_value;
T* reset_output_value;
T* output_value;
T* prev_out_value;
};
template <typename T>
struct GRUUnitFunctor {
static void compute(GRUMetaValue<T> value,
int frame_size,
int batch_size,
const lite::cuda::math::ActivationType& active_node,
const lite::cuda::math::ActivationType& active_gate,
bool origin_mode,
lite::cuda::math::Gemm<T, T>* blas,
CUDAContext* context) {
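    // General (non-fused) GRU step: (1) a GEMM adds W_{u,r} * h_{t-1} into the
    // update/reset gate pre-activations, (2) GruForwardResetOutput activates
    // both gates and forms r_t * h_{t-1}, (3) a second GEMM adds
    // W_c * (r_t * h_{t-1}) to the candidate pre-activation, and (4)
    // GruForwardFinalOutput activates the candidate and produces h_t.
    // For batch_size == 1 on sm_70+, the FastCollectiveGru* kernels below fuse
    // the GEMMs with these element-wise steps.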
dim3 threads, grids;
if (batch_size == 1) {
if (lite::TargetWrapperCuda::GetComputeCapability() >= 70) {
if (frame_size < 16) {
constexpr int tiled_size = 8;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruGate<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruOut<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
} else {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruGate<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.gate_weight,
value.reset_output_value,
frame_size,
active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grids = dim3(frame_blocks, 1);
lite::cuda::math::FastCollectiveGruOut<
T,
tiled_size><<<grids, threads, 0, context->exec_stream()>>>(
value.state_weight,
value.prev_out_value,
value.output_value,
value.gate_value,
value.reset_output_value,
frame_size,
active_node,
origin_mode);
}
return;
} else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
}
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size * 2,
frame_size,
frame_size,
frame_size * 2,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.prev_out_value,
value.gate_weight,
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size,
frame_size,
frame_size,
frame_size,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.reset_output_value,
value.state_weight,
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
}
};
template struct GRUUnitFunctor<float>;
template <>
struct GRUUnitFunctor<half> {
static void compute(GRUMetaValue<half> value,
int frame_size,
int batch_size,
const lite::cuda::math::ActivationType& active_node,
const lite::cuda::math::ActivationType& active_gate,
bool origin_mode,
lite::cuda::math::Gemm<half, half>* blas,
CUDAContext* context) {
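    // The half-precision specialization only uses the generic launch configuration; the tiled
    // FastCollectiveGru* fast path is taken for the float instantiation only.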
dim3 threads, grids;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size * 2,
frame_size,
frame_size,
frame_size * 2,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.prev_out_value,
value.gate_weight,
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size,
frame_size,
frame_size,
frame_size,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.reset_output_value,
value.state_weight,
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
}
};
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
}
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::Run() {
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto& param = this->template Param<param_t>();
auto* input = param.input;
lite::Tensor* h0{nullptr};
if (param.h0) {
h0 = const_cast<lite::Tensor*>(param.h0);
}
lite::Tensor* bias{nullptr};
if (param.bias) {
bias = const_cast<lite::Tensor*>(param.bias);
}
const lite::Tensor* weight = param.weight;
T* weight_data = const_cast<T*>(weight->template data<T>());
lite::Tensor* batch_gate = param.batch_gate;
lite::Tensor* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
lite::Tensor* batch_hidden = param.batch_hidden;
lite::Tensor* hidden = param.hidden;
T* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
T* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
T* batch_hidden_data = batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
auto active_node = lite::cuda::math::GetActiveType(param.activation);
auto active_gate = lite::cuda::math::GetActiveType(param.gate_activation);
bool origin_mode = param.origin_mode;
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
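  // Reorder the LoD (sequence) input into batch-major layout; the LoD written to batch_gate
  // records the batch boundaries and the original row order used below.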
lite::cuda::math::LoDTensor2BatchFunctor<T> batch_func;
batch_func(*input, batch_gate, is_reverse, stream);
if (bias) {
lite::cuda::math::RowwiseAdd<T> add_bias;
add_bias(batch_gate_data,
bias->template data<T>(),
batch_gate_data,
frame_size,
batch_gate->numel(),
stream);
}
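  // The weight tensor is laid out as [frame_size, 3 * frame_size]: the first
  // 2 * frame_size * frame_size elements hold the update/reset gate weights, the rest the
  // candidate (state) weights.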
GRUMetaValue<T> gru_value;
gru_value.gate_weight = weight_data;
gru_value.state_weight = weight_data + 2 * frame_size * frame_size;
if (h0) {
    // Since the batch computation of GRU reorders the input sequences by their
    // length, the initial hidden state also needs to be reordered to match.
ordered_h0_.Resize(h0->dims());
lite::cuda::math::CopyMatrixRowsFunctor<T> row_shuffle;
row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream);
gru_value.prev_out_value = ordered_h0_.mutable_data<T>(TARGET(kCUDA));
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
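  // Process the reordered batches sequentially; each step's output becomes the previous
  // hidden state of the next step.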
for (size_t n = 0; n < num_batch; ++n) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
gru_value.output_value = batch_hidden_data + bstart * frame_size;
gru_value.gate_value = batch_gate_data + bstart * frame_size * 3;
gru_value.reset_output_value =
batch_reset_hidden_prev_data + bstart * frame_size;
GRUUnitFunctor<T>::compute(gru_value,
frame_size,
cur_batch_size,
active_node,
active_gate,
origin_mode,
gemm_impl_.get(),
&context);
gru_value.prev_out_value = gru_value.output_value;
}
lite::cuda::math::Batch2LoDTensorFunctor<T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(*batch_hidden, hidden, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using GRUFp32 =
paddle::lite::kernels::cuda::GRUCompute<float, PRECISION(kFloat)>;
using GRUFp16 = paddle::lite::kernels::cuda::GRUCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(gru, kCUDA, kFP16, kNCHW, GRUFp16, def)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Weight",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchGate",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchResetHiddenPrev",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("BatchHidden",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Hidden",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType PType>
class GRUCompute : public KernelLite<TARGET(kCUDA), PType> {
public:
using param_t = operators::GRUParam;
void PrepareForRun() override;
void Run() override;
virtual ~GRUCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
lite::Tensor ordered_h0_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/gru_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class GRUTest : public ::testing::Test {
protected:
GRUTest()
: batch_(12),
frame_size_(128),
activation_("tanh"),
gate_activation_("sigmoid"),
is_reverse_(false),
origin_mode_(false),
x_shape_({batch_, frame_size_ * 3}),
w_shape_({frame_size_, frame_size_ * 3}),
out_shape_({batch_, frame_size_}),
lod_({{0, 4, 9, 12}}) {
x_ref_.Resize(lite::DDim(x_shape_));
x_gpu_.Resize(lite::DDim(x_shape_));
x_ref_.set_lod(lod_);
w_ref_.Resize(lite::DDim(w_shape_));
w_gpu_.Resize(lite::DDim(w_shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
auto w_ref_data = w_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
for (int64_t i = 0; i < w_ref_.numel(); i++) {
w_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
out_ref_.Resize(lite::DDim(out_shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
batch_gate_gpu_.Resize(lite::DDim(x_shape_));
batch_hidden_gpu_.Resize(lite::DDim(out_shape_));
batch_reset_hidden_gpu_.Resize(lite::DDim(out_shape_));
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.input = &x_gpu_;
param_.weight = &w_gpu_;
param_.gate_activation = gate_activation_;
param_.activation = activation_;
param_.is_reverse = is_reverse_;
param_.origin_mode = origin_mode_;
param_.hidden = &out_gpu_;
param_.batch_gate = &batch_gate_gpu_;
param_.batch_reset_hidden_prev = &batch_reset_hidden_gpu_;
param_.batch_hidden = &batch_hidden_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
x_gpu_.set_lod(x_ref_.lod());
w_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(w_ref_.data<float>(),
w_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(lite::DDim(x_shape_));
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
x_gpu_.set_lod(x_ref_.lod());
w_half_.Resize(w_ref_.dims());
auto w_half_data = w_half_.mutable_data<half>();
for (int64_t i = 0; i < w_half_.numel(); i++) {
w_half_data[i] = half(lite::float16(w_ref_.data<float>()[i]));
}
w_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(w_half_data, w_gpu_.dims());
}
void RunBaseLine() {}
int batch_, frame_size_;
std::string activation_, gate_activation_;
bool is_reverse_, origin_mode_;
std::vector<int64_t> x_shape_, w_shape_, out_shape_;
LoD lod_;
lite::Tensor x_ref_, w_ref_, out_ref_;
lite::Tensor x_gpu_, w_gpu_;
lite::Tensor x_half_, w_half_;
lite::Tensor batch_gate_gpu_;
lite::Tensor batch_hidden_gpu_;
lite::Tensor batch_reset_hidden_gpu_;
lite::Tensor out_cpu_, out_gpu_;
operators::GRUParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(GRUTest, TestFP32) {
InitFloatInput();
GRUCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
}
TEST_F(GRUTest, TestFP16) {
InitHalfInput();
GRUCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sigmoid_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
void SigmoidCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->template data<T>();
auto output = param.Out->template mutable_data<T>(TARGET(kCUDA));
lite::cuda::math::sigmoid<T>(num, input, output, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SigmoidFp32 =
paddle::lite::kernels::cuda::SigmoidCompute<float, PRECISION(kFloat)>;
using SigmoidFp16 =
paddle::lite::kernels::cuda::SigmoidCompute<half, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFloat, kNCHW, SigmoidFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(sigmoid, kCUDA, kFP16, kNCHW, SigmoidFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SigmoidCompute : public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~SigmoidCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sigmoid_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SigmoidTest : public ::testing::Test {
protected:
SigmoidTest() : m_(8), n_(64), shape_({m_, n_}) {
x_ref_.Resize(lite::DDim(shape_));
x_gpu_.Resize(lite::DDim(shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
out_ref_.Resize(lite::DDim(shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.X = &x_gpu_;
param_.Out = &out_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
}
void InitHalfInput() {
x_half_.Resize(lite::DDim(shape_));
auto x_half_data = x_half_.mutable_data<half>();
for (int64_t i = 0; i < x_half_.numel(); i++) {
x_half_data[i] = half(lite::float16(x_ref_.data<float>()[i]));
}
x_gpu_.Assign<half, lite::DDim, TARGET(kCUDA)>(x_half_data, x_gpu_.dims());
}
void RunBaseLine() {
for (int64_t i = 0; i < x_ref_.numel(); ++i) {
out_ref_.mutable_data<float>()[i] =
1.f / (1.f + expf(-1 * x_ref_.data<float>()[i]));
}
}
int m_, n_;
std::vector<int64_t> shape_;
lite::Tensor x_ref_, out_ref_;
lite::Tensor x_gpu_;
lite::Tensor x_half_;
lite::Tensor out_cpu_, out_gpu_;
operators::ActivationParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(SigmoidTest, TestFP32) {
InitFloatInput();
SigmoidCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
CopySync<TARGET(kCUDA)>(out_cpu_.mutable_data<float>(),
out_gpu_.data<float>(),
sizeof(float) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = out_cpu_.data<float>()[i];
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5);
}
}
TEST_F(SigmoidTest, TestFP16) {
InitHalfInput();
SigmoidCompute<half, PRECISION(kFP16)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp16, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
const half* out_gpu_data = out_gpu_.data<half>();
half* out_cpu_data = out_cpu_.mutable_data<half>();
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_gpu_data,
sizeof(half) * out_gpu_.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_gpu_.numel(); ++i) {
float res = static_cast<float>(lite::float16(out_cpu_data[i]));
float ref = out_ref_.data<float>()[i];
EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 2e-2);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -107,8 +107,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CNML_FILTER,
CNML_NCHW,
graph->FPType());
const auto weight_scale =
op_info->GetAttr<std::vector<float>>("weight_scale");
const auto weight_scale = op_info->GetInputScale(filter_var_name);
if (filter->precision() == PrecisionType::kUnk ||
filter->precision() == PrecisionType::kInt8) {
......@@ -162,7 +161,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
graph->BindConstData(bias_var_name, bias);
}
const auto input_scale = op_info->GetAttr<float>("input_scale");
const auto input_scale = op_info->GetInputScale(input_var_name)[0];
bool use_first_conv = false;
if (lite::TargetWrapperMlu::UseFirstConv() && input_dims[1] == 3) {
......
......@@ -224,8 +224,10 @@ void test_conv(int bs,
opdesc_mlu.SetAttr("groups", groups);
opdesc_mlu.SetAttr("fuse_relu", static_cast<bool>(fuse_relu));
opdesc_mlu.SetAttr("weight_scale", std::vector<float>(oc, filter_scale));
opdesc_mlu.SetAttr("input_scale", input_scale);
OpInfo op_info(opdesc_mlu);
op_info.SetInputScale(filter_int_var_name,
std::vector<float>(oc, filter_scale));
op_info.SetInputScale(input_var_name, {input_scale});
if (has_bias) {
if (is_channel_bias) {
......@@ -234,7 +236,7 @@ void test_conv(int bs,
bias->Resize({output_shape});
}
FillTensor<float>(bias);
opdesc_mlu.SetInput("Bias", {bias_var_name});
op_info.SetInput("Bias", {bias_var_name});
}
for (int i = 0; i < bs; i++) {
......@@ -248,7 +250,7 @@ void test_conv(int bs,
}
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::ConvOpLite>(opdesc_mlu, &scope);
auto op = CreateOp<operators::ConvOpLite>(op_info, &scope);
LaunchOp(op, {input_var_name}, {output_var_name});
// compare results
auto* output_data = output->mutable_data<float>();
......
......@@ -68,7 +68,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto w_tensor = graph->AddNode(
w_var_name, cnml_w_shape, CNML_FILTER, CNML_NCHW, graph->FPType());
auto input_scale = op_info->GetAttr<float>("input_scale");
auto input_scale = op_info->GetInputScale(x_var_name)[0];
auto output_tensor = graph->AddNode(output_var_name,
output->dims().Vectorize(),
......@@ -101,7 +101,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
graph->SetComputingDataType(
fc_op, graph->GetNode(x_var_name)->mlu_tensor(), 1 / input_scale);
auto weight_scale = op_info->GetAttr<std::vector<float>>("weight_scale");
auto weight_scale = op_info->GetInputScale(w_var_name);
// LOG(INFO) << "W precision " << int(w->precision());
if (w->precision() == PrecisionType::kUnk ||
......
......@@ -131,14 +131,15 @@ void test_fc(const std::vector<int64_t>& input_shape,
fc_op_desc_mlu.SetOutput("Out", {out_var_name});
fc_op_desc_mlu.SetAttr("in_num_col_dims", static_cast<int>(in_num_col_dims));
fc_op_desc_mlu.SetAttr("weight_scale",
std::vector<float>(w_shape[1], w_scale));
fc_op_desc_mlu.SetAttr("input_scale", input_scale);
OpInfo op_info(fc_op_desc_mlu);
op_info.SetInputScale(w_int_var_name,
std::vector<float>(w_shape[1], w_scale));
op_info.SetInputScale(input_var_name, {input_scale});
if (has_bias) {
fc_op_desc_mlu.SetInput("Bias", {bias_var_name});
op_info.SetInput("Bias", {bias_var_name});
}
auto fc_op_mlu = CreateOp<operators::FcOpLite>(fc_op_desc_mlu, &scope);
auto fc_op_mlu = CreateOp<operators::FcOpLite>(op_info, &scope);
Tensor input_tmp, out_tmp;
input_tmp.Resize(input_shape);
......
......@@ -49,8 +49,7 @@ int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< "Unsuport WithinChannel";
}
auto local_size = op_info->GetAttr<int>("n");
CHECK(op_info->HasAttr("input_scale"));
auto input_scale = op_info->GetAttr<float>("input_scale");
auto input_scale = op_info->GetInputScale(x_var_name)[0];
VLOG(5) << "lrn input scale: " << input_scale;
cnmlLrnOpParam_t param;
......
......@@ -178,9 +178,10 @@ void test_lrn(float alpha,
opdesc.SetAttr("k", k);
opdesc.SetAttr("n", local_size);
opdesc.SetAttr("norm_region", norm_region);
opdesc.SetAttr<float>("input_scale", (*dmax - *dmin) / 255.f);
OpInfo op_info(opdesc);
op_info.SetInputScale(x_var_name, {(*dmax - *dmin) / 255.f});
auto op = CreateOp<operators::LrnOpLite>(opdesc, &scope);
auto op = CreateOp<operators::LrnOpLite>(op_info, &scope);
// baseline
lrn_compute_ref(op);
......@@ -213,7 +214,7 @@ void test_lrn(float alpha,
auto output_data = output_trans.mutable_data<float>();
auto* output_ref_data = out_ref->mutable_data<float>();
for (size_t i = 0; i < out->data_size(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
EXPECT_NEAR(output_data[i], output_ref_data[i], 5e-4);
}
}
......
......@@ -54,10 +54,11 @@ class SubgraphEngine : public subgraph::Engine {
VLOG(4) << "[MLU] PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL is "
<< GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL");
VLOG(4) << "[MLU] PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE is "
<< GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE");
<< GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE",
true);
VLOG(4) << "[MLU] LITE_DISABLE_MLU_CAST is "
<< GetBoolFromEnv("LITE_DISABLE_MLU_CAST");
if (GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE")) {
if (GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE", true)) {
disable_batch_size_changeable_ = true;
}
}
......
......@@ -54,10 +54,16 @@ class BlockDescWriteAPI {
virtual void SetForwardBlockIdx(int32_t idx) { NotImplemented(); }
template <typename T>
T* AddVar();
T* AddVar() {
NotImplemented();
return nullptr;
}
template <typename T>
T* AddOp();
T* AddOp() {
NotImplemented();
return nullptr;
}
virtual ~BlockDescWriteAPI() = default;
......
......@@ -73,7 +73,9 @@ class OpDescWriteAPI {
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
void SetAttr(const std::string& name, const T& v) {
NotImplemented();
}
virtual ~OpDescWriteAPI() = default;
......
......@@ -40,7 +40,10 @@ class ProgramDescWriteAPI {
virtual void SetVersion(int64_t version) { NotImplemented(); }
template <typename T>
T* AddBlock();
T* AddBlock() {
NotImplemented();
return nullptr;
}
virtual ~ProgramDescWriteAPI() = default;
......
......@@ -57,6 +57,7 @@ class VectorView {
public:
typedef vector_view::VectorTraits<T, U> Traits;
explicit VectorView(typename Traits::vector_type const* cvec) {
CHECK(cvec);
cvec_ = cvec;
}
typename Traits::subscript_return_type operator[](size_t i) const {
......
......@@ -277,7 +277,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
template <> \
void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \
NT::T *any_desc) { \
auto desc = cpp_desc; \
auto &desc = cpp_desc; \
if (desc.HasVersion()) { \
any_desc->SetVersion(desc.Version()); \
} \
......
......@@ -8,9 +8,6 @@ endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header)
lite_cc_test(test_vector_view SRCS vector_view_test.cc)
if (TARGET test_vector_view)
add_dependencies(test_vector_view framework_fbs_header)
endif()
lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc)
lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
......@@ -19,15 +19,27 @@ namespace lite {
namespace fbs {
template <>
proto::VarDesc* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) {
proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return const_cast<proto::VarDesc*>(desc_->vars()->Get(idx));
return desc_->vars()->Get(idx);
}
template <>
proto::OpDesc* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) {
proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return const_cast<proto::OpDesc*>(desc_->ops()->Get(idx));
return desc_->ops()->Get(idx);
}
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx];
}
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx];
}
} // namespace fbs
......
......@@ -14,8 +14,11 @@
#pragma once
#include <vector>
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/op_desc.h"
#include "lite/model_parser/flatbuffers/var_desc.h"
#include "lite/utils/all.h"
namespace paddle {
......@@ -24,7 +27,17 @@ namespace fbs {
class BlockDesc : public BlockDescAPI {
public:
explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); }
explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_);
vars_.reserve(VarsSize());
ops_.reserve(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDesc(desc_->vars()->Get(idx)));
}
for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDesc(desc_->ops()->Get(idx)));
}
}
int32_t Idx() const override { return desc_->idx(); }
......@@ -33,11 +46,12 @@ class BlockDesc : public BlockDescAPI {
size_t VarsSize() const override { return desc_->vars()->size(); }
template <typename T>
T* GetVar(int32_t idx);
T const* GetVar(int32_t idx) const;
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
T* GetVar(int32_t idx) {
NotImplemented();
return nullptr;
}
size_t OpsSize() const override {
......@@ -47,21 +61,32 @@ class BlockDesc : public BlockDescAPI {
}
template <typename T>
T* GetOp(int32_t idx);
T const* GetOp(int32_t idx) const;
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
T* GetOp(int32_t idx) {
NotImplemented();
return nullptr;
}
const std::vector<VarDesc>& GetVars() const { return vars_; }
int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx();
}
BlockDesc() = delete;
BlockDesc() { NotImplemented(); }
private:
proto::BlockDesc* desc_; // not_own
proto::BlockDesc const* desc_; // not_own
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of BlockDesc is temporarily "
"unavailable in read-only mode.";
}
};
} // namespace fbs
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog) {
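  // Read the whole flatbuffer model file into a heap buffer and hand its ownership
  // to the ProgramDesc.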
FILE* file = fopen(path.c_str(), "rb");
fseek(file, 0, SEEK_END);
int64_t size = ftell(file);
rewind(file);
char* data = new char[size];
size = fread(data, 1, size, file);
fclose(file);
std::unique_ptr<char[]> buf(data);
prog->Init(std::move(buf));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog);
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -30,7 +30,7 @@ namespace fbs {
class OpDesc : public OpDescAPI {
public:
explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); }
explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); }
......@@ -95,7 +95,7 @@ class OpDesc : public OpDescAPI {
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
CHECK(attr);
CHECK(attr) << "Can not find attr: " << name;
return static_cast<OpDescAPI::AttrType>(attr->type());
}
......@@ -124,10 +124,8 @@ class OpDesc : public OpDescAPI {
template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const;
OpDesc() = delete;
private:
proto::OpDesc* desc_;
proto::OpDesc const* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
......@@ -138,6 +136,7 @@ class OpDesc : public OpDescAPI {
// caused by different building options.
public:
OpDesc() { NotImplemented(); }
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
......
......@@ -19,9 +19,16 @@ namespace lite {
namespace fbs {
template <>
proto::BlockDesc* ProgramDesc::GetBlock<proto::BlockDesc>(int32_t idx) {
proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return const_cast<proto::BlockDesc*>(desc_->blocks()->Get(idx));
return desc_->blocks()->Get(idx);
}
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx];
}
} // namespace fbs
......
......@@ -15,7 +15,10 @@
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/flatbuffers/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
......@@ -26,18 +29,40 @@ namespace fbs {
class ProgramDesc : public ProgramDescAPI {
public:
ProgramDesc() = default;
explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); }
explicit ProgramDesc(std::unique_ptr<const char[]> buf) {
Init(std::move(buf));
}
size_t BlocksSize() const override { return desc_->blocks()->size(); }
void Init(std::unique_ptr<const char[]> buf) {
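    // Take ownership of the serialized buffer, parse the root ProgramDesc and cache
    // read-only BlockDesc views of all blocks.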
CHECK(buf.get() != nullptr);
buf_ = std::move(buf);
desc_ = proto::GetProgramDesc(buf_.get());
blocks_.reserve(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
}
}
void CopyFrom(const ProgramDesc& other) {
size_t length = strlen(static_cast<const char*>(other.raw_buf()));
std::unique_ptr<char[]> buf(new char[length]);
memcpy(buf.get(), other.raw_buf(), length);
Init(std::move(buf));
}
template <typename T>
T *GetBlock(int32_t idx);
T const* GetBlock(int32_t idx) const;
template <typename T>
T const *GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
T* GetBlock(int32_t idx) {
NotImplemented();
return nullptr;
}
const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }
bool HasVersion() const override { return desc_->version() != nullptr; }
int64_t Version() const override {
......@@ -45,8 +70,22 @@ class ProgramDesc : public ProgramDescAPI {
return desc_->version()->version();
}
proto::ProgramDesc const* raw_desc() const { return desc_; }
const void* raw_buf() const { return buf_.get(); }
private:
proto::ProgramDesc *desc_; // not_own
proto::ProgramDesc const* desc_;
std::unique_ptr<const char[]> buf_;
std::vector<BlockDesc> blocks_;
private:
ProgramDesc& operator=(const ProgramDesc&) = delete;
ProgramDesc(const ProgramDesc&) = delete;
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily "
"unavailable in read-only mode.";
}
};
} // namespace fbs
......
......@@ -27,7 +27,7 @@ namespace fbs {
class VarDesc : public VarDescAPI {
public:
explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {}
explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); }
......@@ -48,10 +48,14 @@ class VarDesc : public VarDescAPI {
return dims_vec;
}
VarDesc() = delete;
VarDescAPI::Type GetDataType() const {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return static_cast<VarDescAPI::Type>(
desc_->type()->lod_tensor()->tensor()->data_type());
}
private:
proto::VarDesc* desc_;
proto::VarDesc const* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
......@@ -62,10 +66,7 @@ class VarDesc : public VarDescAPI {
// caused by different building options.
public:
VarDescAPI::Type GetDataType() const {
NotImplemented();
return data_type_;
}
VarDesc() { NotImplemented(); }
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
......@@ -74,7 +75,6 @@ class VarDesc : public VarDescAPI {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
"unavailable in read-only mode.";
}
Type data_type_;
std::vector<int64_t> shape_;
};
......
......@@ -104,20 +104,32 @@ class VectorView<std::string, Flatbuffers> {
explicit VectorView(typename Traits::vector_type const* cvec) {
cvec_ = cvec;
}
std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); }
std::string operator[](size_t i) const {
CHECK(cvec_);
return cvec_->operator[](i)->str();
}
vector_view::FBSStrIterator begin() const {
CHECK(cvec_);
return vector_view::FBSStrIterator(cvec_->begin());
}
vector_view::FBSStrIterator end() const {
CHECK(cvec_);
return vector_view::FBSStrIterator(cvec_->end());
}
size_t size() const { return cvec_->size(); }
size_t size() const {
if (cvec_ == nullptr) {
return 0;
}
return cvec_->size();
}
operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp;
tmp.reserve(cvec_->size());
for (auto val : *cvec_) {
tmp.push_back(val->str());
tmp.reserve(size());
if (cvec_ != nullptr) {
for (auto val : *cvec_) {
tmp.push_back(val->str());
}
}
return tmp;
}
......
......@@ -24,6 +24,12 @@ VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) {
return &vars_[idx];
}
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx];
}
template <>
VarDesc* BlockDesc::AddVar<VarDesc>() {
vars_.emplace_back();
......@@ -36,6 +42,12 @@ OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) {
return &ops_[idx];
}
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx];
}
template <>
OpDesc* BlockDesc::AddOp<OpDesc>() {
ops_.emplace_back();
......
......@@ -46,12 +46,10 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetVar(int32_t idx);
std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
}
T const* GetVar(int32_t idx) const;
std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T>
T* AddVar();
......@@ -64,9 +62,7 @@ class BlockDesc : public BlockDescAPI {
T* GetOp(int32_t idx);
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
}
T const* GetOp(int32_t idx) const;
template <typename T>
T* AddOp();
......
......@@ -24,6 +24,12 @@ BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) {
return &blocks_[idx];
}
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx];
}
template <>
BlockDesc* ProgramDesc::AddBlock<BlockDesc>() {
blocks_.emplace_back();
......
......@@ -30,6 +30,13 @@ class ProgramDesc : public ProgramDescAPI {
public:
ProgramDesc() = default;
void CopyFrom(const ProgramDesc& other) {
version_ = other.Version();
blocks_ = other.blocks();
}
const std::vector<BlockDesc>& blocks() const { return blocks_; }
size_t BlocksSize() const override { return blocks_.size(); }
void ClearBlocks() override { blocks_.clear(); }
......@@ -37,12 +44,10 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T>
T* GetBlock(int32_t idx);
std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T>
T const* GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
}
T const* GetBlock(int32_t idx) const;
std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T>
T* AddBlock();
......
......@@ -176,7 +176,7 @@ void LoadCombinedParamsPb(const std::string &path,
const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) {
CHECK(scope);
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars
......@@ -310,7 +310,7 @@ void SaveModelPb(const std::string &model_dir,
void SaveCombinedParamsPb(const std::string &path,
const lite::Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars
......@@ -526,7 +526,7 @@ void SaveCombinedParamsNaive(const std::string &path,
naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
naive_buffer::CombinedParamsDesc desc(&pt_desc);
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly
std::set<std::string> unique_var_names;
......@@ -681,7 +681,7 @@ void LoadCombinedParamsNaive(const std::string &path,
}
// Check all params loaded
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
......
......@@ -55,11 +55,6 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
}
template <typename T>
T* AddVar();
......@@ -70,11 +65,6 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
}
template <typename T>
T* AddOp();
......
......@@ -45,11 +45,6 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T>
T *GetBlock(int32_t idx);
template <typename T>
T const *GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
}
template <typename T>
T *AddBlock();
......
......@@ -83,7 +83,7 @@ class DeformableConvOpLite : public OpLite {
param_.conv_param.filter =
scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.conv_param.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
param_.conv_param.groups = op_desc.GetAttr<int>("groups");
param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations);
......
......@@ -54,7 +54,7 @@ class MaxPoolWithIndexOpLite : public OpLite {
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
if (op_desc.HasAttr("adaptive")) {
param_.adaptive = op_desc.GetAttr<bool>("adaptive");
}
......
......@@ -39,8 +39,8 @@ readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/t
readonly workspace=$PWD
# if operating in mac env, we should expand the maximum file num
os_nmae=`uname -s`
if [ ${os_nmae} == "Darwin" ]; then
os_name=`uname -s`
if [ ${os_name} == "Darwin" ]; then
ulimit -n 1024
fi
......
......@@ -21,8 +21,8 @@ USE_ADB_EMULATOR=ON
LITE_WITH_COVERAGE=OFF
# if operating in mac env, we should expand the maximum file num
os_nmae=`uname -s`
if [ ${os_nmae} == "Darwin" ]; then
os_name=`uname -s`
if [ ${os_name} == "Darwin" ]; then
ulimit -n 1024
fi
......