Commit 3a6906c8 authored by chenjiaoAngel

pull new code

Merge branch 'int8' of https://github.com/chenjiaoAngel/Paddle-Lite into int8
......@@ -39,7 +39,7 @@ else()
endif()
find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt
PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib
PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib ${XPU_SDK_ROOT}/XTDK/shlib # libxpurt.so may have been moved to XTDK/runtime/shlib
NO_DEFAULT_PATH)
if(NOT XPU_SDK_XPU_RT_FILE)
......
......@@ -97,7 +97,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
OUTPUT ${GEN_HEADER}
COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
--cpp --gen-mutable --gen-object-api --reflect-names
--cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs
--force-empty --force-empty-vectors
${OPT}
-I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
-o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
......
......@@ -86,19 +86,28 @@ config.set_model_from_file(/YOU_MODEL_PATH/mobilenet_v1_opt.nb)
predictor = create_paddle_predictor(config)
```
(3) Set the input data
(3) Read the input data from an image
```python
image = Image.open('./example.jpg')
resized_image = image.resize((224, 224), Image.BILINEAR)
image_data = np.array(resized_image).flatten().tolist()
```
(4) Set the input data
```python
input_tensor = predictor.get_input(0)
input_tensor.resize([1, 3, 224, 224])
input_tensor.set_float_data([1.] * 3 * 224 * 224)
input_tensor.set_float_data(image_data)
```
(4) Run the prediction
(5) Run the prediction
```python
predictor.run()
```
(5) Get the output data
(6) Get the output data
```python
output_tensor = predictor.get_output(0)
print(output_tensor.shape())
......
......@@ -24,6 +24,9 @@
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/target_wrapper.h"
#endif
#ifdef LITE_WITH_XPU
#include "lite/backends/xpu/target_wrapper.h"
#endif
#ifdef LITE_WITH_MLU
#include "lite/backends/mlu/target_wrapper.h"
......@@ -272,7 +275,7 @@ CxxConfig::mlu_firstconv_param() const {
void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::SetWorkspaceL3Size(l3_size);
lite::TargetWrapperXPU::workspace_l3_size_per_thread = l3_size;
#else
LOG(WARNING) << "The invoking of the function "
"'set_xpu_workspace_l3_size_per_thread' is ignored, please "
......@@ -282,7 +285,7 @@ void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::SetDev(dev_no);
lite::TargetWrapperXPU::SetDev(dev_no);
#else
LOG(WARNING) << "The invoking of the function 'set_xpu_dev_per_thread' is "
"ignored, please rebuild it with LITE_WITH_XPU=ON.";
......@@ -291,7 +294,7 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
void CxxConfig::set_xpu_multi_encoder_precision(const std::string &precision) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::_multi_encoder_precision = precision;
lite::TargetWrapperXPU::multi_encoder_precision = precision;
#else
LOG(WARNING) << "The invoking of the function "
"'set_xpu_multi_encoder_precision' is "
......
......@@ -55,6 +55,8 @@ USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
......@@ -59,9 +59,9 @@ void TestModel(const std::vector<Place>& valid_places) {
}
auto* image_tensor = predictor.GetInput(1);
image_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 2})));
data = image_tensor->mutable_data<float>();
data[0] = FLAGS_im_height;
data[1] = FLAGS_im_width;
auto* data_1 = image_tensor->mutable_data<int>();
data_1[0] = FLAGS_im_height;
data_1[1] = FLAGS_im_width;
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
......
......@@ -763,24 +763,6 @@ void act_thresholded_relu<float>(
}
}
#ifdef LITE_WITH_TRAIN
template <>
void act_square_grad(const float* din,
const float* dout_grad,
float* din_grad,
int size,
int threads) {
const float* ptr_out_grad = dout_grad;
float* ptr_in_grad = din_grad;
for (int i = 0; i < size; ++i) {
ptr_in_grad[0] = ptr_out_grad[0] * 2.0 * din[0];
ptr_out_grad++;
ptr_in_grad++;
din++;
}
}
#endif
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -90,12 +90,6 @@ template <typename T>
void act_thresholded_relu(
const T* din, T* dout, int size, float threshold, int threads);
#ifdef LITE_WITH_TRAIN
template <typename T>
void act_square_grad(
const T* din, const T* dout_grad, T* din_grad, int size, int threads);
#endif
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -139,6 +139,91 @@ static bool conv_trans_weights_numc(const dtype* din,
}
return true;
}
// for example: m = 4, n = 4
// din = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9 , 10 ,11], [12, 13, 14, 15]]
// dout = [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
/*
m = 4 n = 4: 0 1 2 3 0 4 8 12
4 5 6 7 1 5 9 13
8 9 10 11 2 6 10 14
12 13 14 15 3 7 11 15
m = 8 n = 4: 0 1 2 3 4 5 6 7 0 4 8 12 16 20 24 28
8 9 10 11 12 13 14 15 1 5 9 13 17 21 25 29
16 17 18 19 20 21 22 23 2 6 10 14 18 22 26 30
24 25 26 27 28 29 30 31 3 7 11 15 19 23 27 31
m = 4 n = 8: 0 1 2 3 0 8 16 24
4 5 6 7 1 9 17 25
8 9 10 11 2 10 18 26
12 13 14 15 3 11 19 27
16 17 18 19 4 12 20 28
...
m = 8 n = 8: 0 1 2 3 4 5 6 7 0 8 16 24 32 40 48 56
8 9 10 11 12 13 14 15 1 9 17 25 33 41 49 57
16 17 18 19 20 21 22 23 2 10 18 26 34 42 50 58
24 25 26 27 28 29 30 31 3 11 19 27 35 43 51 59
32 33 34 35 36 37 38 39 4 12 20 28 36 44 52 60
...
int k = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
dout[k++] = din[j * n + i];
}
}
*/
template <typename Dtype>
void local_transpose(const Dtype* din, Dtype* dout, int m, int n) {
// requires m % 4 == 0 and n % 4 == 0
// m x n ==> n x m data transpose (see the reference loop above)
int offset_m = m << 2;
const Dtype* din_ptr = din;
Dtype* dout_ptr = dout;
for (int i = 0; i < n; i += 4) {
Dtype* out_ptr0 = dout_ptr;
Dtype* out_ptr1 = dout_ptr + m;
Dtype* out_ptr2 = out_ptr1 + m;
Dtype* out_ptr3 = out_ptr2 + m;
const Dtype* in_ptr0 = din_ptr;
const Dtype* in_ptr1 = din_ptr + m;
const Dtype* in_ptr2 = in_ptr1 + m;
const Dtype* in_ptr3 = in_ptr2 + m;
for (int j = 0; j < m; j += 4) {
float32x4_t vin0 = vld1q_f32(in_ptr0);
float32x4_t vin1 = vld1q_f32(in_ptr1);
float32x4_t vin2 = vld1q_f32(in_ptr2);
float32x4_t vin3 = vld1q_f32(in_ptr3);
// a00 b00 a02 b02 a01 b01 a03 b03
float32x4x2_t tmp0 = vtrnq_f32(vin0, vin1);
// c00 d00 c02 d02 c01 d01 c03 d03
float32x4x2_t tmp2 = vtrnq_f32(vin2, vin3);
in_ptr0 = in_ptr3 + m;
in_ptr1 = in_ptr3 + 2 * m;
float tmp_val1 = tmp0.val[0][2];
float tmp_val2 = tmp0.val[0][3];
tmp0.val[0][2] = tmp2.val[0][0];
tmp0.val[0][3] = tmp2.val[0][1];
float tmp_val3 = tmp0.val[1][2];
float tmp_val4 = tmp0.val[1][3];
tmp2.val[0][0] = tmp_val1;
tmp2.val[0][1] = tmp_val2;
tmp0.val[1][2] = tmp2.val[1][0];
tmp0.val[1][3] = tmp2.val[1][1];
tmp2.val[1][0] = tmp_val3;
tmp2.val[1][1] = tmp_val4;
in_ptr2 = in_ptr1 + m;
in_ptr3 = in_ptr1 + 2 * m;
vst1q_f32(out_ptr0, tmp0.val[0]);
vst1q_f32(out_ptr1, tmp0.val[1]);
out_ptr0 += 4;
out_ptr1 += 4;
vst1q_f32(out_ptr2, tmp2.val[0]);
vst1q_f32(out_ptr3, tmp2.val[1]);
out_ptr2 += 4;
out_ptr3 += 4;
}
dout_ptr += offset_m;
din_ptr += 4;
}
}
template <typename Dtype>
void transpose(const Dtype* din, Dtype* dout, int m, int n) {
// nxm == mxn
......
......@@ -11,10 +11,13 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
nv_library(cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation ${cuda_static_deps})
nv_library(cuda_sequence2batch SRCS sequence2batch.cu DEPS ${cuda_static_deps})
nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
set (
math_cuda
......@@ -25,10 +28,13 @@ set (
cuda_transpose
cuda_elementwise
cudnn_pool
cuda_gru_forward
cuda_sequence2batch
cuda_gemm
cuda_batched_gemm
cuda_strided_gemm
cuda_sequence_padding
cuda_bias
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
......@@ -21,6 +21,20 @@ namespace lite {
namespace cuda {
namespace math {
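// Map the activation-name string attribute of an op to the ActivationType
// enum consumed by the GRU kernels; an unrecognized name aborts via
// LOG(FATAL).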
ActivationType GetActiveType(const std::string& act) {
if (act == "sigmoid") {
return kSigmoid;
} else if (act == "relu") {
return kReLU;
} else if (act == "tanh") {
return kTanh;
} else if (act == "identify") {
return kIdentity;
} else {
LOG(FATAL) << "not supported activation: " << act;
}
}
template <typename T>
__global__ void relu_kernel(const int num,
const float alpha,
......
......@@ -17,11 +17,22 @@
#include <cuda_runtime.h>
#include <string>
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
ActivationType GetActiveType(const std::string& act);
// fp32 and half
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/bias.h"
#include <iostream>
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void RowwiseAddKernel(
const T* a, const T* b, T* c, int width, int num) {
CUDA_KERNEL_LOOP(i, num) {
int h = i / width;
int w = i - h * width;
c[i] = a[i] + b[w];
}
}
template <typename T>
void RowwiseAdd<T>::operator()(const T* input,
const T* bias,
T* output,
const int width,
const int count,
const cudaStream_t& stream) {
RowwiseAddKernel<T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
input, bias, output, width, count);
CUDA_POST_KERNEL_CHECK;
}
template struct RowwiseAdd<float>;
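// Illustrative usage (a sketch, not taken from the callers in this commit):
// broadcast a bias vector of length `width` over every row of a row-major
// [rows x width] matrix, with `count` being the total element count:
//   RowwiseAdd<float> rowwise_add;
//   rowwise_add(d_in, d_bias, d_out, width, rows * width, stream);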
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/cuda_utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
struct RowwiseAdd {
void operator()(const T* input,
const T* bias,
T* output,
const int width,
const int count,
const cudaStream_t& stream);
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "lite/backends/cuda/math/gru_forward.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
T* reset_output_value,
T* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) return;
gate_value += batch_idx * 3 * frame_size;
reset_output_value += batch_idx * frame_size;
}
T prev_out = 0;
T reset_out_val;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T reset_gate_value = gate_value[frame_idx + frame_size * 1];
if (prev_output_value) {
if (is_batch) {
prev_output_value += batch_idx * frame_size;
}
prev_out = prev_output_value[frame_idx];
}
if (active_gate == lite::cuda::math::ActivationType::kSigmoid) {
update_gate_value = Sigmoid(update_gate_value);
reset_gate_value = Sigmoid(reset_gate_value);
} else if (active_gate == lite::cuda::math::ActivationType::kReLU) {
update_gate_value = ReLU(update_gate_value);
reset_gate_value = ReLU(reset_gate_value);
} else if (active_gate == lite::cuda::math::ActivationType::kTanh) {
update_gate_value = Tanh(update_gate_value);
reset_gate_value = Tanh(reset_gate_value);
}
reset_out_val = prev_out * reset_gate_value;
gate_value[frame_idx + frame_size * 0] = update_gate_value;
gate_value[frame_idx + frame_size * 1] = reset_gate_value;
reset_output_value[frame_idx] = reset_out_val;
}
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
T* prev_output_value,
T* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
if (is_batch) {
batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
if (batch_idx >= batch_size) {
return;
}
gate_value += batch_idx * 3 * frame_size;
output_value += batch_idx * frame_size;
}
T output;
T prev_out = 0;
T update_gate_value = gate_value[frame_idx + frame_size * 0];
T state_frame_value = gate_value[frame_idx + frame_size * 2];
if (prev_output_value) {
if (is_batch) prev_output_value += batch_idx * frame_size;
prev_out = prev_output_value[frame_idx];
}
if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
state_frame_value = Sigmoid(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kReLU) {
state_frame_value = ReLU(state_frame_value);
} else if (active_node == lite::cuda::math::ActivationType::kTanh) {
state_frame_value = Tanh(state_frame_value);
}
if (origin_mode) {
output = update_gate_value * prev_out + state_frame_value -
update_gate_value * state_frame_value;
} else {
output = prev_out - update_gate_value * prev_out +
update_gate_value * state_frame_value;
}
gate_value[frame_idx + frame_size * 2] = state_frame_value;
output_value[frame_idx] = output;
}
template __global__ void GruForwardFinalOutput<float>(
float* gate_value,
float* prev_output_value,
float* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
template __global__ void GruForwardResetOutput<float>(
float* gate_value,
float* reset_output_value,
float* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
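// Illustrative launch sketch (an assumption, not taken from this commit):
// frames map to the x dimension and batch rows to the y dimension when
// is_batch is true, matching the index math in the kernels above, e.g.
//   dim3 threads(32, 32);
//   dim3 grids((frame_size + 31) / 32, (batch_size + 31) / 32);
//   GruForwardResetOutput<float><<<grids, threads, 0, stream>>>(
//       gate_value, reset_output_value, prev_output_value, frame_size,
//       batch_size, ActivationType::kSigmoid, /*is_batch=*/true);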
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/core/context.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename Dtype>
inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
}
template <typename Dtype>
inline __device__ Dtype ReLU(const Dtype a) {
return a > static_cast<Dtype>(0.f) ? a : static_cast<Dtype>(0.f);
}
template <typename Dtype>
inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
static_cast<Dtype>(1.0);
}
template <typename T>
__global__ void GruForwardResetOutput(
T* gate_value,
T* reset_output_value,
T* prev_output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_gate,
bool is_batch);
template <typename T>
__global__ void GruForwardFinalOutput(
T* gate_value,
T* prev_output_value,
T* output_value,
int frame_size,
int batch_size,
lite::cuda::math::ActivationType active_node,
bool origin_mode,
bool is_batch);
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
__global__ void CopyMatrixRowsKernel(const T* src,
T* dst,
const uint64_t* index,
int height,
int width,
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int row_id = blockDim.y * blockIdx.x + idy;
if (row_id < height) {
int src_idx = is_src_index ? index[row_id] : row_id;
int dst_idx = is_src_index ? row_id : index[row_id];
const T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += blockDim.x) {
dst_data[i] = src_data[i];
}
}
}
template <typename T>
void CopyMatrixRowsFunctor<T>::operator()(
const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
bool is_src_index,
const cudaStream_t& stream) {
auto src_dims = src.dims();
auto dst_dims = dst->dims();
CHECK_EQ(src_dims.size(), 2) << "The src must be matrix with rank 2.";
CHECK_EQ(dst_dims.size(), 2) << "The dst must be matrix with rank 2.";
CHECK_EQ(src_dims[1], dst_dims[1])
<< "The width of src and dst must be same.";
int height = dst_dims[0];
int width = dst_dims[1];
const auto* src_data = src.data<T>();
auto* dst_data = dst->template mutable_data<T>(TARGET(kCUDA));
index_tensor_.Resize({static_cast<int64_t>(index_lod.size())});
auto* index_tensor_data = index_tensor_.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(index_tensor_data,
index_lod.data(),
sizeof(uint64_t) * index_lod.size(),
IoDirection::HtoD,
stream);
dim3 threads(128, 8);
dim3 grids((height + threads.y - 1) / threads.y);
CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
src_data, dst_data, index_tensor_data, height, width, true);
CUDA_POST_KERNEL_CHECK;
}
template class CopyMatrixRowsFunctor<float>;
template class LoDTensor2BatchFunctor<float>;
template class Batch2LoDTensorFunctor<float>;
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <string>
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor {
public:
void operator()(const lite::Tensor& src,
lite::Tensor* dst,
const std::vector<uint64_t>& index_lod,
bool is_src_index,
const cudaStream_t& stream);
private:
lite::Tensor index_tensor_;
};
template <typename T>
class LoDTensor2BatchFunctor {
struct SeqInfo {
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start_(start), length_(length), seq_idx_(seq_idx) {}
size_t start_;
size_t length_;
size_t seq_idx_;
};
public:
void operator()(const lite::Tensor& lod_tensor,
lite::Tensor* batch_tensor,
bool is_reverse,
const cudaStream_t& stream) const {
auto lods = lod_tensor.lod();
CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now.";
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (int seq_id = 0; seq_id < static_cast<int>(lod.size()) - 1; ++seq_id) {
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length_ > b.length_;
});
LoD batch_lods;
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
batch_lods.emplace_back(std::vector<uint64_t>{0});
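// batch_lods[0]: start offset of each time step in the batched tensor;
// batch_lods[1]: for each batched row, its source row index in the LoDTensor;
// batch_lods[2]: the original sequence index of each length-sorted sequence.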
size_t max_seqlen = seq_info[0].length_;
batch_lods[0].resize(max_seqlen + 1);
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
batch_lods[2].resize(seq_info.size());
auto* batch_starts = batch_lods[0].data();
auto* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < max_seqlen; ++n) {
size_t batch_id = batch_starts[n];
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length_;
size_t start = seq_info[i].start_;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
++batch_id;
} else {
break;
}
}
batch_starts[n + 1] = batch_id;
}
auto* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx_;
}
batch_tensor->set_lod(batch_lods);
lite::cuda::math::CopyMatrixRowsFunctor<T> to_batch;
to_batch(lod_tensor, batch_tensor, batch_lods[1], true, stream);
CUDA_POST_KERNEL_CHECK;
}
};
template <typename T>
class Batch2LoDTensorFunctor {
public:
void operator()(const lite::Tensor& batch_tensor,
lite::Tensor* lod_tensor,
const cudaStream_t& stream) {
auto in_lod = batch_tensor.lod();
CHECK_GT(in_lod.size(), 2UL) << "The LoD of LoDTensor should include at "
"least 2-level sequence infomation.";
CHECK_EQ(in_lod[1].size(), static_cast<size_t>(lod_tensor->dims()[0]))
<< "The LoD information should be consistent with the dims.";
lite::cuda::math::CopyMatrixRowsFunctor<T> to_seq;
to_seq(batch_tensor, lod_tensor, in_lod[1], false, stream);
CUDA_POST_KERNEL_CHECK;
}
};
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
......@@ -15,6 +15,7 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
......@@ -31,6 +32,16 @@ class TargetWrapper<TARGET(kCUDA)> {
static size_t num_devices();
static size_t maximum_stream() { return 0; }
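// Returns the SM version encoded as major * 10 + minor, e.g. 75 for a
// compute capability 7.5 device.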
static int GetComputeCapability() {
int dev_id = GetCurDevice();
int major, minor;
CUDA_CALL(cudaDeviceGetAttribute(
&major, cudaDevAttrComputeCapabilityMajor, dev_id));
CUDA_CALL(cudaDeviceGetAttribute(
&minor, cudaDevAttrComputeCapabilityMinor, dev_id));
return major * 10 + minor;
}
static size_t GetCurDevice() {
int dev_id;
cudaGetDevice(&dev_id);
......
......@@ -119,7 +119,7 @@ cl::NDRange CLContext::DefaultWorkSize(const CLImage &image) {
}
}
cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size,
cl::NDRange CLContext::LocalWorkSizeTune(cl::NDRange global_work_size,
size_t max_work_size,
int divisor) {
int preferred_lws = 0;
......@@ -157,7 +157,7 @@ cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size,
static_cast<size_t>(gws0)};
#endif
}
cl::NDRange CLContext::LocalWorkSizeTurnReverse(cl::NDRange global_work_size,
cl::NDRange CLContext::LocalWorkSizeTuneReverse(cl::NDRange global_work_size,
size_t max_work_size,
int divisor) {
int preferred_lws = 0;
......
......@@ -62,10 +62,10 @@ class CLContext {
cl::NDRange LocalWorkSize(cl::NDRange global_work_size, size_t max_work_size);
cl::NDRange LocalWorkSizeTurn(cl::NDRange global_work_size,
cl::NDRange LocalWorkSizeTune(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
cl::NDRange LocalWorkSizeTurnReverse(cl::NDRange global_work_size,
cl::NDRange LocalWorkSizeTuneReverse(cl::NDRange global_work_size,
size_t max_work_size,
int divitor = 2);
bool IsArmMali();
......
......@@ -6,9 +6,7 @@ __kernel void conv2d_1x1_opt(
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......@@ -284,9 +282,7 @@ __kernel void conv2d_1x1_simple(
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -19,9 +19,7 @@ __kernel void conv2d_3x3(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
......
......@@ -19,9 +19,7 @@ __kernel void conv2d_3x3_opt(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -264,9 +262,7 @@ __kernel void conv2d_3x3_multi_batch(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......
......@@ -5,9 +5,7 @@ __kernel void conv2d_5x5(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -20,9 +20,7 @@ __kernel void conv2d_5x5_opt(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -268,9 +266,7 @@ __kernel void conv2d_5x5_multi_batch(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -513,4 +509,4 @@ __kernel void conv2d_5x5_multi_batch(__private const int item_ch,
(int2)(out_w_base_id + out_w_id4, item_h_id),
output[4]);
}
}
\ No newline at end of file
}
......@@ -5,9 +5,7 @@ __kernel void conv2d_7x7(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -20,9 +20,7 @@ __kernel void conv2d_7x7_opt(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -268,9 +266,7 @@ __kernel void conv2d_7x7_multi_batch(__private const int item_ch,
__private const int item_h,
__read_only image2d_t input_image,
__read_only image2d_t filter_image,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......@@ -513,4 +509,4 @@ __kernel void conv2d_7x7_multi_batch(__private const int item_ch,
(int2)(out_w_base_id + out_w_id4, item_h_id),
output[4]);
}
}
\ No newline at end of file
}
......@@ -19,9 +19,7 @@ __kernel void depth_conv2d(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
__read_only image2d_t new_scale,
__read_only image2d_t new_biase,
......
......@@ -20,9 +20,7 @@ __kernel void depth_conv2d_3x3(
__private const int global_size_dim2,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int offset,
......@@ -249,9 +247,7 @@ __kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk,
__private const int ou_nh,
__read_only image2d_t input,
__read_only image2d_t filter,
#if defined(BIASE_CH) || defined(BIASE_ELE)
__read_only image2d_t bias,
#endif
__write_only image2d_t output_image,
__private const int stride,
__private const int pad,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>
#include <type_traits>
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
namespace xpu {
template <typename T>
void DumpCPUMem(const T* ptr,
size_t len,
const std::string& comment = "",
size_t stride = 1,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = ptr[i * stride];
}
double sum = 0;
for (size_t i = 0; i < len; ++i) {
sum += ptr[i];
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
size_t nline = (after_stride_len + item_per_line - 1) / item_per_line;
for (size_t i = 0; i < nline; ++i) {
size_t line_begin = i * item_per_line;
size_t line_end = line_begin + item_per_line;
printf("line[%04zd] -- ", i);
for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len);
++ii) {
if (std::is_same<T, float>::value) {
printf("%.6f, ", static_cast<float>(after_stride[ii]));
} else if (std::is_same<T, int16_t>::value) {
printf("%d ", static_cast<int>(after_stride[ii]));
} else {
// CHECK(false) << "unknown type";
}
}
printf("\n");
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f END "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
}
template <typename T>
void DumpXPUMem(const T* ptr,
size_t len,
const std::string& comment = "",
size_t stride = 1,
size_t item_per_line = 30) {
size_t after_stride_len = (len + stride - 1) / stride;
std::unique_ptr<T[]> cpu_mem(new T[len]);
xpu_memcpy(
cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST);
std::unique_ptr<T[]> after_stride(new T[after_stride_len]);
for (size_t i = 0; i < after_stride_len; ++i) {
after_stride[i] = cpu_mem[i * stride];
}
double sum = 0;
for (size_t i = 0; i < len; ++i) {
sum += cpu_mem[i];
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
size_t nline = (after_stride_len + item_per_line - 1) / item_per_line;
for (size_t i = 0; i < nline; ++i) {
size_t line_begin = i * item_per_line;
size_t line_end = line_begin + item_per_line;
printf("line[%04zd] -- ", i);
for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len);
++ii) {
if (std::is_same<T, float>::value) {
printf("%.6f, ", static_cast<float>(after_stride[ii]));
} else if (std::is_same<T, int16_t>::value) {
printf("%d ", static_cast<int>(after_stride[ii]));
} else {
// CHECK(false) << "unknown type";
}
}
printf("\n");
}
printf(
"------------------------------ [%s] len=%zd stride=%zd sum=%f END "
"------------------------------\n",
comment.c_str(),
len,
stride,
sum);
}
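// Illustrative usage (a sketch, not taken from this commit): dump the first
// 256 fp32 values of a device buffer while debugging, e.g.
//   DumpXPUMem(reinterpret_cast<const float*>(dev_ptr), 256, "fc_out");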
} // namespace xpu
} // namespace lite
} // namespace paddle
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "lite/backends/xpu/target_wrapper.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
......@@ -42,5 +41,21 @@ void TargetWrapperXPU::MemcpySync(void* dst,
}
}
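// Allocate a scratch buffer either from the on-chip L3 workspace
// (use_l3 == true) or from ordinary XPU global memory; XPUScratchPadDeleter
// in target_wrapper.h frees only the latter.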
XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size,
bool use_l3) {
void* ptr{nullptr};
if (use_l3) {
ptr = xdnn::alloc_workspace(GetRawContext(), size);
} else {
ptr = TargetWrapperXPU::Malloc(size);
}
CHECK(ptr != nullptr);
return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3));
}
std::string TargetWrapperXPU::multi_encoder_precision; // NOLINT
int TargetWrapperXPU::workspace_l3_size_per_thread{0};
thread_local xdnn::Context* TargetWrapperXPU::tls_raw_ctx_{nullptr};
} // namespace lite
} // namespace paddle
......@@ -14,6 +14,8 @@
#pragma once
#include <memory> // std::unique_ptr
#include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free
#include "lite/core/target_wrapper.h"
namespace paddle {
......@@ -21,6 +23,24 @@ namespace lite {
using TargetWrapperXPU = TargetWrapper<TARGET(kXPU)>;
struct XPUScratchPad {
XPUScratchPad(void* addr, bool is_l3) : addr_(addr), is_l3_(is_l3) {}
void* addr_{nullptr};
bool is_l3_{false};
};
struct XPUScratchPadDeleter {
void operator()(XPUScratchPad* sp) const {
if (!sp->is_l3_) {
xpu_free(sp->addr_);
}
delete sp;
}
};
using XPUScratchPadGuard = std::unique_ptr<XPUScratchPad, XPUScratchPadDeleter>;
template <>
class TargetWrapper<TARGET(kXPU)> {
public:
......@@ -34,6 +54,41 @@ class TargetWrapper<TARGET(kXPU)> {
const void* src,
size_t size,
IoDirection dir);
static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true);
static xdnn::Context* GetRawContext() {
if (tls_raw_ctx_ == nullptr) {
tls_raw_ctx_ = xdnn::create_context();
CHECK(tls_raw_ctx_);
int r = xdnn::set_workspace_l3_size(tls_raw_ctx_,
workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", workspace_l3_size_per_thread = "
<< workspace_l3_size_per_thread;
}
}
return tls_raw_ctx_;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
xpu_set_device(atoi(dev_env));
return;
}
xpu_set_device(dev_no);
}
static std::string multi_encoder_precision; // NOLINT
static int workspace_l3_size_per_thread;
private:
static thread_local xdnn::Context* tls_raw_ctx_;
};
} // namespace lite
......
......@@ -21,12 +21,6 @@ namespace lite {
std::string Context<TargetType::kNPU>::subgraph_model_cache_dir_{""}; // NOLINT
#endif
#ifdef LITE_WITH_XPU
std::string Context<TargetType::kXPU>::_multi_encoder_precision; // NOLINT
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
#endif
#ifdef LITE_WITH_MLU
int Context<TargetType::kMLU>::next_queue_id_{0};
std::map<int, int> Context<TargetType::kMLU>::queue_id_map_;
......
......@@ -144,45 +144,12 @@ class Context<TargetType::kXPU> {
void CopySharedTo(XPUContext* ctx) {}
// TODO(miaotianxiang): remove this
static xdnn::Context* GetRawContext() {
if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
}
return _tls_raw_ctx;
}
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
_workspace_l3_size_per_thread = l3_size;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
xpu_set_device(atoi(dev_env));
return;
}
xpu_set_device(dev_no);
return TargetWrapperXPU::GetRawContext();
}
std::string name() const { return "XPUContext"; }
public:
static std::string _multi_encoder_precision; // NOLINT
private:
static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
};
#endif
......
......@@ -23,9 +23,11 @@ lite_cc_library(mir_passes
fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__resnet_cbam_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
......
This diff is collapsed.
......@@ -639,20 +639,21 @@ class XPUMultiEncoderFusePass : public ProgramPass {
std::set<int> fc_int31_ids;
#ifdef LITE_WITH_XPU
// TODO(miaotianxiang): core/mir/*_pass.cc are compiled anyway and need to
// access Context<kXPU>::_multi_encoder_precision, but this static member
// variable in class specialization defined in lite/core/context.cc
// is only compiled iff LITE_WITH_XPU==ON. To suppress linkage error, we use
// access TargetWrapperXPU::multi_encoder_precision, but this static member
// variable in class specialization defined in
// lite/backends/xpu/target_wrapper.cc is only compiled iff
// LITE_WITH_XPU==ON. To suppress linkage error, we use
// #ifdef here. Any better idea?
if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" ||
lite::Context<TargetType::kXPU>::_multi_encoder_precision == "int31") {
lite::TargetWrapperXPU::multi_encoder_precision == "int31") {
fc_int31_ids = {0, 1, 2, 3, 4, 5};
VLOG(3) << "Use int31 in XPUMultiEncoderOp, "
<< "lite::Context<>::_multi_encoder_precision="
<< lite::Context<TargetType::kXPU>::_multi_encoder_precision;
<< "lite::TargetWrapperXPU::multi_encoder_precision="
<< lite::TargetWrapperXPU::multi_encoder_precision;
} else {
VLOG(3) << "Use int16 in XPUMultiEncoderOp, "
<< "lite::Context<>::_multi_encoder_precision="
<< lite::Context<TargetType::kXPU>::_multi_encoder_precision;
<< "lite::TargetWrapperXPU::multi_encoder_precision="
<< lite::TargetWrapperXPU::multi_encoder_precision;
}
#endif
......
This diff is collapsed.
......@@ -192,7 +192,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else if (is_weight_quantization) {
std::string scale_name = conv_weight_name + "_quant_scale";
if (conv_op_desc->HasAttr(scale_name)) {
auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name);
std::vector<float> scale =
conv_op_desc->GetAttr<std::vector<float>>(scale_name);
CHECK_EQ(scale.size(), alpha_tensor.numel());
for (size_t i = 0; i < scale.size(); i++) {
scale[i] *= alpha_data[i];
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/conv_conv_fuser.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialize fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{
"conv2d", "depthwise_conv2d"};
bool has_arm = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm = true;
break;
}
}
if (!has_arm) {
return;
}
// only support fp32 fusion
for (auto conv_has_bias0 : conv_has_bias_cases) {
for (auto conv_has_bias1 : conv_has_bias_cases) {
for (auto conv_type0 : conv_type_cases) {
for (auto conv_type1 : conv_type_cases) {
VLOG(4) << "conv_has_bias0:" << conv_has_bias0
<< " conv_type0:" << conv_type0;
VLOG(4) << "conv_has_bias1:" << conv_has_bias1
<< " conv_type1:" << conv_type1;
fusion::ConvConvFuser fuser(conv_type0, conv_type1, conv_has_bias0, conv_has_bias1);
fuser(graph.get());
}
}
}
}
}

} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_conv_fuse_pass, paddle::lite::mir::ConvConvFusePass)
.BindTargets({TARGET(kARM)});
......@@ -84,11 +84,12 @@ cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc(
op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("axis",
matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.back());
*(matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.end() -
1));
return op_desc;
}
......
......@@ -62,15 +62,17 @@ std::string Visualize(mir::SSAGraph* graph) {
<< string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
break;
case AttrType::FLOATS: {
auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
std::vector<float> vals =
op_info->GetAttr<std::vector<float>>(attr_name);
os << ":floats: {" + Join(vals, ",") << "}";
} break;
case AttrType::INTS: {
auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
std::vector<int> vals = op_info->GetAttr<std::vector<int>>(attr_name);
os << ":ints: {" + Join(vals, ",") + "}";
} break;
case AttrType::STRINGS: {
auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
std::vector<std::string> vals =
op_info->GetAttr<std::vector<std::string>>(attr_name);
os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
} break;
default:
......
......@@ -94,6 +94,8 @@ class Optimizer {
#endif
"identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass",
"__xpu__resnet_cbam_fuse_pass",
"__xpu__mmdnn_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
......
......@@ -195,7 +195,7 @@ void Program::Build(const cpp::ProgramDesc& prog) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
auto program = prog;
auto& program = prog;
CHECK(program.BlocksSize());
auto& main_block = *program.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block.OpsSize(); ++i) {
......@@ -262,7 +262,7 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
}
};
auto program = prog;
auto& program = prog;
CHECK(program.BlocksSize());
for (size_t b = 0; b < program.BlocksSize(); ++b) {
auto& main_block = *program.GetBlock<cpp::BlockDesc>(b);
......
......@@ -46,7 +46,8 @@ struct Program {
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {})
: scope_(root), valid_places_(valid_places), desc_(desc) {
: scope_(root), valid_places_(valid_places) {
desc_.CopyFrom(desc);
CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work";
PrepareWorkspace(desc, var_names);
......
......@@ -28,7 +28,7 @@ namespace lite {
namespace kernels {
namespace apu {
int SubgraphEngine::BuildDeviceProgram() {
bool SubgraphEngine::BuildDeviceProgram() {
unsigned int version;
Neuron_getVersion(&version);
VLOG(3) << "Neuron Adapter version: " << version;
......@@ -38,7 +38,7 @@ int SubgraphEngine::BuildDeviceProgram() {
int neuron_errCode = NeuronModel_create(&model_);
if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to create model";
return subgraph::FAILED;
return false;
}
graph.set_model(model_);
graph.set_input_names(input_names_);
......@@ -46,6 +46,9 @@ int SubgraphEngine::BuildDeviceProgram() {
// Convert all of ops and their input vars and weights and added into the APU
// NIR graph
if (origin_program_.empty()) {
BuildOriginProgram();
}
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) {
auto op = const_cast<OpLite*>(inst.op());
......@@ -54,7 +57,7 @@ int SubgraphEngine::BuildDeviceProgram() {
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists(op_type, TARGET(kAPU))) {
return subgraph::FAILED;
return false;
}
auto kernel = inst.kernel();
......@@ -63,7 +66,7 @@ int SubgraphEngine::BuildDeviceProgram() {
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
return false;
}
}
......@@ -84,7 +87,7 @@ int SubgraphEngine::BuildDeviceProgram() {
VLOG(3) << "input idx: " << graph.Get(input_names_[i])->index();
} else {
LOG(WARNING) << "Fail to find input: " << input_names_[i];
return subgraph::FAILED;
return false;
}
}
......@@ -105,7 +108,7 @@ int SubgraphEngine::BuildDeviceProgram() {
VLOG(3) << "output idx: " << graph.Get(output_names_[i])->index();
} else {
LOG(WARNING) << "Fail to find output: " << output_names_[i];
return subgraph::FAILED;
return false;
}
}
......@@ -116,7 +119,7 @@ int SubgraphEngine::BuildDeviceProgram() {
neuron_errCode = NeuronModel_finish(model_);
if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to create NIR model:" << neuron_errCode;
return subgraph::FAILED;
return false;
}
VLOG(3) << "[APU] APU NIR model created!";
......@@ -129,15 +132,14 @@ int SubgraphEngine::BuildDeviceProgram() {
compilation_ = lite::apu::Device::Global().Build(model_);
if (compilation_ == nullptr) {
LOG(WARNING) << "[APU] Build APU DLA model failed!";
return subgraph::FAILED;
return false;
}
VLOG(3) << "[APU] APU DLA model created, Build cost "
<< GetCurrentUS() - start_time << " us";
return status;
return true;
}
int SubgraphEngine::LaunchDeviceProgram() {
bool SubgraphEngine::LaunchDeviceProgram() {
auto GetCurrentUS = []() -> double {
struct timeval time;
gettimeofday(&time, NULL);
......@@ -149,7 +151,7 @@ int SubgraphEngine::LaunchDeviceProgram() {
int neuron_errCode = NeuronExecution_create(compilation_, &run);
if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "[APU] Build APU runtime failed!";
return subgraph::FAILED;
return false;
}
// Set input buffer
......@@ -177,7 +179,7 @@ int SubgraphEngine::LaunchDeviceProgram() {
neuron_errCode = NeuronExecution_compute(run);
if (NEURON_NO_ERROR != neuron_errCode) {
LOG(WARNING) << "Fail to run execution!" << neuron_errCode;
return subgraph::FAILED;
return false;
}
for (size_t i = 0; i < origin_otensors_.size(); i++) {
......@@ -190,7 +192,7 @@ int SubgraphEngine::LaunchDeviceProgram() {
}
NeuronExecution_free(run);
VLOG(3) << "[APU] Process cost " << GetCurrentUS() - start_time << " us";
return 0;
return true;
}
SubgraphEngine::~SubgraphEngine() {
......@@ -211,12 +213,11 @@ void SubgraphCompute::PrepareForRun() {
param.output_data_names,
param.scope));
CHECK(engine_);
engine_->Build();
}
void SubgraphCompute::Run() {
CHECK(engine_);
engine_->Launch();
engine_->Run();
}
} // namespace apu
......
......@@ -41,8 +41,8 @@ class SubgraphEngine : public subgraph::Engine {
~SubgraphEngine();
protected:
int BuildDeviceProgram() override;
int LaunchDeviceProgram() override;
bool BuildDeviceProgram() override;
bool LaunchDeviceProgram() override;
NeuronModel *model_;
NeuronCompilation *compilation_;
......
......@@ -103,7 +103,6 @@ add_kernel(deformable_conv_compute_arm ARM extra SRCS deformable_conv_compute.cc
add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -103,10 +103,19 @@ void SequenceConvCompute::Run() {
1,
1, // stride_h, stride_w, dilation_h, dilation_w
tmp_data);
local_naive_transpose(tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
int cols = kernel_size * hidden_dim;
int rows = input_row_end - input_row_begin;
if (cols % 4 == 0 && rows % 4 == 0) {
paddle::lite::arm::math::local_transpose(tmp_data,
sub_col_data,
cols,
rows);
} else {
local_naive_transpose(tmp_data,
sub_col_data,
cols,
rows);
}
}
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <bmcompiler_if.h>
#include <math.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
......@@ -64,10 +65,16 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto* bias_data = bias->mutable_data<float>();
auto* mean_data = mean->mutable_data<float>();
auto* variance_data = variance->mutable_data<float>();
float* new_bias = static_cast<float*>(malloc(bias->memory_size()));
float* new_scale = static_cast<float*>(malloc(scale->memory_size()));
CHECK(new_bias != nullptr);
CHECK(new_scale != nullptr);
for (int c = 0; c < channel_size; c++) {
float inv_scale = 1.f / (std::sqrt(variance_data[c] + epsilon));
bias_data[c] = bias_data[c] - inv_scale * scale_data[c] * mean_data[c];
scale_data[c] = inv_scale * scale_data[c];
new_bias[c] = bias_data[c] - inv_scale * scale_data[c] * mean_data[c];
new_scale[c] = inv_scale * scale_data[c];
}
const int input_num = 1;
......@@ -86,11 +93,13 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
static_cast<const float*>(scale->mutable_data<float>()),
static_cast<const float*>(bias->mutable_data<float>()),
static_cast<const float*>(new_scale),
static_cast<const float*>(new_bias),
1,
1,
1);
free(new_scale);
free(new_bias);
delete[] shape;
delete[] name;
delete[] dim;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <bmcompiler_if.h>
#include <math.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
......
......@@ -76,6 +76,8 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
static_cast<const char*>(output_var_name.c_str()),
0,
0,
0,
0,
type);
}
graph->AddNode(output_var_name);
......
......@@ -28,12 +28,35 @@ namespace lite {
namespace kernels {
namespace bm {
int SubgraphEngine::BuildDeviceProgram() {
bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() {
// Obtain the origin input tensors, and create the origin output
// tensors(Don't try to access them before launch the device program or the
// origin program)
PrepareWorkspaceForOriginProgram();
// Create the device input and output tensors, but don't initialize them
// with the dimensions
device_inputs_.resize(input_names_.size());
for (int i = 0; i < input_names_.size(); i++) {
device_inputs_[i].reset(new hiai::AiTensor);
CHECK(device_inputs_[i]);
}
device_outputs_.resize(output_names_.size());
for (int i = 0; i < output_names_.size(); i++) {
device_outputs_[i].reset(new hiai::AiTensor);
CHECK(device_outputs_[i]);
}
return true;
}
bool SubgraphEngine::BuildDeviceProgram() {
int status = 0;
subgraph::bm::Graph graph;
const auto& bridges = subgraph::Registry::Instance();
graph.CreateCompilerHandle();
auto& ctx = this->ctx_->template As<BMContext>();
if (origin_program_.empty()) {
BuildOriginProgram();
}
for (auto& inst : origin_program_) {
auto op = const_cast<OpLite*>(inst.op());
CHECK(op);
......@@ -42,7 +65,7 @@ int SubgraphEngine::BuildDeviceProgram() {
std::string op_type = op->op_info()->Type();
LOG(INFO) << op_type;
if (!bridges.Exists(op_type, TARGET(kBM))) {
return subgraph::FAILED;
return false;
}
auto kernel = inst.kernel();
status |=
......@@ -50,12 +73,13 @@ int SubgraphEngine::BuildDeviceProgram() {
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
return false;
}
}
std::string net_name = "bmnetc_f32umodel";
std::string net_name = "bmnet_f32bmodel";
auto unique_net_name = lite::subgraph::bm::UniqueName(net_name);
__bmcompile_opt(
graph.GetCompilerHandle(), const_cast<char*>(net_name.c_str()), 1);
graph.GetCompilerHandle(), const_cast<char*>(unique_net_name.c_str()), 2);
void* bmodel_data = nullptr;
unsigned int data_size = 0;
bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
......@@ -63,7 +87,7 @@ int SubgraphEngine::BuildDeviceProgram() {
graph.UnlockCompilerMutex();
bmrt_hd_ = bmrt_create(bm_hd_);
if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) {
return subgraph::FAILED;
return false;
}
bmrt_get_network_names(bmrt_hd_, &net_names_);
net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]);
......@@ -116,10 +140,10 @@ int SubgraphEngine::BuildDeviceProgram() {
net_info_->output_dtypes[i],
stage.output_shapes[i]);
}
return status;
return true;
}
int SubgraphEngine::LaunchDeviceProgram() {
bool SubgraphEngine::LaunchDeviceProgram() {
for (size_t i = 0; i < device_inputs_.size(); i++) {
bm_memcpy_s2d(bm_hd_,
device_inputs_[i].device_mem,
......@@ -143,7 +167,7 @@ int SubgraphEngine::LaunchDeviceProgram() {
out_index++;
}
}
return 0;
return true;
}
void SubgraphCompute::PrepareForRun() {
......@@ -155,12 +179,11 @@ void SubgraphCompute::PrepareForRun() {
param.output_data_names,
param.scope));
CHECK(engine_);
engine_->Build();
}
void SubgraphCompute::Run() {
CHECK(engine_);
engine_->Launch();
engine_->Run();
}
} // namespace bm
......
......@@ -44,8 +44,9 @@ class SubgraphEngine : public subgraph::Engine {
ctx, block_idx, block_desc, input_names, output_names, scope) {}
protected:
int BuildDeviceProgram() override;
int LaunchDeviceProgram() override;
bool PrepareWorkspaceForDeviceProgram() override;
bool BuildDeviceProgram() override;
bool LaunchDeviceProgram() override;
private:
void *bmrt_hd_;
......
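With this change the BM subgraph engine follows the bool-returning hook contract: workspace preparation, device-program build, and launch each report success or failure instead of a subgraph status code, and the build is no longer triggered from `PrepareForRun`. The sketch below illustrates that control flow under these assumptions; it is not the actual `subgraph::Engine` base class.

```cpp
#include <iostream>

// Hypothetical illustration of the engine contract; the names and the
// lazy-build policy are assumptions based on the diff above.
class EngineSketch {
 public:
  virtual ~EngineSketch() = default;

  bool Run() {
    if (!device_program_ready_) {
      if (!PrepareWorkspaceForDeviceProgram()) return false;
      device_program_ready_ = BuildDeviceProgram();
      if (!device_program_ready_) return false;
    }
    return LaunchDeviceProgram();
  }

 protected:
  virtual bool PrepareWorkspaceForDeviceProgram() = 0;
  virtual bool BuildDeviceProgram() = 0;
  virtual bool LaunchDeviceProgram() = 0;

 private:
  bool device_program_ready_{false};
};

class NoopEngine : public EngineSketch {
 protected:
  bool PrepareWorkspaceForDeviceProgram() override { return true; }
  bool BuildDeviceProgram() override { return true; }
  bool LaunchDeviceProgram() override { return true; }
};

int main() {
  NoopEngine engine;
  std::cout << (engine.Run() ? "launched" : "failed") << std::endl;
  return 0;
}
```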
......@@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels")
# basic kernels
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(gru_compute_cuda CUDA basic SRCS gru_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(matmul_compute_cuda CUDA basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
......@@ -69,6 +70,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
nv_test(gru_compute_cuda_test SRCS gru_compute_test.cc DEPS gru_compute_cuda)
nv_test(matmul_compute_cuda_test SRCS matmul_compute_test.cc DEPS matmul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/bias.h"
#include "lite/backends/cuda/math/gru_forward.h"
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/gru_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
struct GRUMetaValue {
T* gate_weight;
T* state_weight;
T* gate_value;
T* reset_output_value;
T* output_value;
T* prev_out_value;
};
template <typename T>
struct GRUUnitFunctor {
static void compute(GRUMetaValue<T> value,
int frame_size,
int batch_size,
const lite::cuda::math::ActivationType& active_node,
const lite::cuda::math::ActivationType& active_gate,
bool origin_mode,
lite::cuda::math::Gemm<T, T>* blas,
CUDAContext* context) {
dim3 threads, grids;
if (batch_size == 1) {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grids = dim3(frame_blocks, 1);
} else {
threads = dim3(32, 32);
grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
}
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size * 2,
frame_size,
frame_size,
frame_size * 2,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.prev_out_value,
value.gate_weight,
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(
value.gate_value,
value.reset_output_value,
value.prev_out_value,
frame_size,
batch_size,
active_gate,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
if (value.prev_out_value) {
CHECK(blas->init(false,
false,
batch_size,
frame_size,
frame_size,
frame_size,
frame_size,
frame_size * 3,
context));
blas->run(1.0f,
1.0f,
value.reset_output_value,
value.state_weight,
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(value.gate_value,
value.prev_out_value,
value.output_value,
frame_size,
batch_size,
active_node,
origin_mode,
batch_size == 1);
CUDA_POST_KERNEL_CHECK;
}
};
template struct GRUUnitFunctor<float>;
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<T, T>);
}
template <typename T, PrecisionType PType>
void GRUCompute<T, PType>::Run() {
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto& param = this->template Param<param_t>();
auto* input = param.input;
lite::Tensor* h0{nullptr};
if (param.h0) {
h0 = const_cast<lite::Tensor*>(param.h0);
}
lite::Tensor* bias{nullptr};
if (param.bias) {
bias = const_cast<lite::Tensor*>(param.bias);
}
auto* weight = param.weight;
auto* weight_data = const_cast<T*>(weight->template data<T>());
auto* batch_gate = param.batch_gate;
auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev;
auto* batch_hidden = param.batch_hidden;
auto* hidden = param.hidden;
auto* batch_reset_hidden_prev_data =
batch_reset_hidden_prev->template mutable_data<T>(TARGET(kCUDA));
hidden->template mutable_data<T>(TARGET(kCUDA));
auto* batch_gate_data = batch_gate->template mutable_data<T>(TARGET(kCUDA));
auto* batch_hidden_data =
batch_hidden->template mutable_data<T>(TARGET(kCUDA));
bool is_reverse = param.is_reverse;
auto active_node = lite::cuda::math::GetActiveType(param.activation);
auto active_gate = lite::cuda::math::GetActiveType(param.gate_activation);
bool origin_mode = param.origin_mode;
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
lite::cuda::math::LoDTensor2BatchFunctor<T> batch_func;
batch_func(*input, batch_gate, is_reverse, stream);
if (bias) {
lite::cuda::math::RowwiseAdd<T> add_bias;
add_bias(batch_gate_data,
bias->template data<T>(),
batch_gate_data,
frame_size,
batch_gate->numel(),
stream);
}
GRUMetaValue<T> gru_value;
gru_value.gate_weight = weight_data;
gru_value.state_weight = weight_data + 2 * frame_size * frame_size;
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length, the initial hidden state also needs to be
// reordered.
ordered_h0_.Resize(h0->dims());
lite::cuda::math::CopyMatrixRowsFunctor<T> row_shuffle;
row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream);
gru_value.prev_out_value = ordered_h0_.mutable_data<T>(TARGET(kCUDA));
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; ++n) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
gru_value.output_value = batch_hidden_data + bstart * frame_size;
gru_value.gate_value = batch_gate_data + bstart * frame_size * 3;
gru_value.reset_output_value =
batch_reset_hidden_prev_data + bstart * frame_size;
GRUUnitFunctor<T>::compute(gru_value,
frame_size,
cur_batch_size,
active_node,
active_gate,
origin_mode,
gemm_impl_.get(),
&context);
gru_value.prev_out_value = gru_value.output_value;
}
lite::cuda::math::Batch2LoDTensorFunctor<T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(*batch_hidden, hidden, stream);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using GRUFp32 =
paddle::lite::kernels::cuda::GRUCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
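For reference, the two GEMMs plus the `GruForwardResetOutput`/`GruForwardFinalOutput` kernels registered above realize the standard GRU cell per frame element. The scalar sketch below shows only the final blend, with the candidate pre-activation (which already includes the reset-gated recurrent term) taken as given; the sigmoid/tanh choice and the role of `origin_mode` are assumptions for illustration.

```cpp
#include <cmath>
#include <cstdio>

// One GRU output element: u is the update gate, c the candidate state.
float gru_element(float update_preact, float candidate_preact, float prev_h) {
  float u = 1.f / (1.f + std::exp(-update_preact));  // assumed sigmoid gate
  float c = std::tanh(candidate_preact);             // assumed tanh activation
  // Classic GRU blend; the kernel's origin_mode flag merely swaps the roles
  // of u and (1 - u) in this combination.
  return (1.f - u) * prev_h + u * c;
}

int main() {
  std::printf("h = %f\n", gru_element(0.1f, 0.4f, 0.5f));
  return 0;
}
```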
......@@ -12,41 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/activation_grad_compute.h"
#include "lite/backends/arm/math/funcs.h"
#pragma once
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SquareGradCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto out_grad_dims = param.Out_grad->dims();
auto out_grad_data = param.Out_grad->data<float>();
auto x_data = param.X->data<float>();
auto x_grad_data = param.X_grad->mutable_data<float>();
lite::arm::math::act_square_grad<float>(x_data,
out_grad_data,
x_grad_data,
out_grad_dims.production(),
ctx.threads());
}
} // namespace arm
namespace cuda {
template <typename T, PrecisionType PType>
class GRUCompute : public KernelLite<TARGET(kCUDA), PType> {
public:
using param_t = operators::GRUParam;
void PrepareForRun() override;
void Run() override;
virtual ~GRUCompute() = default;
private:
std::unique_ptr<lite::cuda::math::Gemm<T, T>> gemm_impl_{nullptr};
lite::Tensor ordered_h0_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(square_grad,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::SquareGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/gru_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class GRUTest : public ::testing::Test {
protected:
GRUTest()
: batch_(12),
frame_size_(128),
activation_("tanh"),
gate_activation_("sigmoid"),
is_reverse_(false),
origin_mode_(false),
x_shape_({batch_, frame_size_ * 3}),
w_shape_({frame_size_, frame_size_ * 3}),
out_shape_({batch_, frame_size_}),
lod_({{0, 4, 9, 12}}) {
x_ref_.Resize(lite::DDim(x_shape_));
x_gpu_.Resize(lite::DDim(x_shape_));
x_ref_.set_lod(lod_);
w_ref_.Resize(lite::DDim(w_shape_));
w_gpu_.Resize(lite::DDim(w_shape_));
auto x_ref_data = x_ref_.mutable_data<float>();
auto w_ref_data = w_ref_.mutable_data<float>();
for (int64_t i = 0; i < x_ref_.numel(); i++) {
x_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
for (int64_t i = 0; i < w_ref_.numel(); i++) {
w_ref_data[i] = static_cast<float>(i % 10 * 0.2);
}
out_ref_.Resize(lite::DDim(out_shape_));
out_cpu_.Resize(out_ref_.dims());
out_gpu_.Resize(out_ref_.dims());
batch_gate_gpu_.Resize(lite::DDim(x_shape_));
batch_hidden_gpu_.Resize(lite::DDim(out_shape_));
batch_reset_hidden_gpu_.Resize(lite::DDim(out_shape_));
RunBaseLine();
InitParamAndContext();
}
void InitParamAndContext() {
ctx_.reset(new KernelContext);
cudaStreamCreate(&stream_);
auto& context = ctx_->As<CUDAContext>();
context.SetExecStream(stream_);
param_.input = &x_gpu_;
param_.weight = &w_gpu_;
param_.gate_activation = gate_activation_;
param_.activation = activation_;
param_.is_reverse = is_reverse_;
param_.origin_mode = origin_mode_;
param_.hidden = &out_gpu_;
param_.batch_gate = &batch_gate_gpu_;
param_.batch_reset_hidden_prev = &batch_reset_hidden_gpu_;
param_.batch_hidden = &batch_hidden_gpu_;
}
void InitFloatInput() {
x_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(x_ref_.data<float>(),
x_gpu_.dims());
x_gpu_.set_lod(x_ref_.lod());
w_gpu_.Assign<float, lite::DDim, TARGET(kCUDA)>(w_ref_.data<float>(),
w_gpu_.dims());
}
void RunBaseLine() {}
int batch_, frame_size_;
std::string activation_, gate_activation_;
bool is_reverse_, origin_mode_;
std::vector<int64_t> x_shape_, w_shape_, out_shape_;
LoD lod_;
lite::Tensor x_ref_, w_ref_, out_ref_;
lite::Tensor x_gpu_, w_gpu_;
lite::Tensor x_half_, w_half_;
lite::Tensor batch_gate_gpu_;
lite::Tensor batch_hidden_gpu_;
lite::Tensor batch_reset_hidden_gpu_;
lite::Tensor out_cpu_, out_gpu_;
operators::GRUParam param_;
std::unique_ptr<KernelContext> ctx_;
cudaStream_t stream_;
};
TEST_F(GRUTest, TestFP32) {
InitFloatInput();
GRUCompute<float, PRECISION(kFloat)> kernel;
kernel.SetParam(param_);
kernel.SetContext(std::move(ctx_));
for (int i = 0; i < FLAGS_warmup; ++i) {
kernel.Launch();
cudaDeviceSynchronize();
}
auto start = GetCurrentUS();
kernel.PrepareForRun();
for (int i = 0; i < FLAGS_repeats; ++i) {
kernel.Run();
}
cudaDeviceSynchronize();
auto duration = (GetCurrentUS() - start) / 1000.0;
LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
<< ", repeats: " << FLAGS_repeats << ", spend "
<< duration / FLAGS_repeats << " ms in average.";
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -18,6 +18,7 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.
add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps})
add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps})
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/activation_grad_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void SquareGradCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(param.X);
auto out_grad_dims = param.Out_grad->dims();
auto out_grad_data = param.Out_grad->data<float>();
auto x_data = param.X->data<float>();
auto x_grad_data = param.X_grad->mutable_data<float>();
for (int i = 0; i < out_grad_dims.production(); i++) {
x_grad_data[i] = out_grad_data[i] * 2.0 * x_data[i];
}
}
void ReluGradCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(param.X);
auto out_grad_dims = param.Out_grad->dims();
auto out_grad_data = param.Out_grad->data<float>();
auto x_data = param.X->data<float>();
auto x_grad_data = param.X_grad->mutable_data<float>();
for (int i = 0; i < out_grad_dims.production(); i++) {
x_grad_data[i] = x_data[i] > 0 ? out_grad_data[i] : 0.0;
}
}
void TanhGradCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(param.Out);
auto out_grad_dims = param.Out_grad->dims();
auto out_grad_data = param.Out_grad->data<float>();
auto out_data = param.Out->data<float>();
auto x_grad_data = param.X_grad->mutable_data<float>();
for (int i = 0; i < out_grad_dims.production(); i++) {
x_grad_data[i] = out_grad_data[i] *
(static_cast<float>(1.0) - out_data[i] * out_data[i]);
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(square_grad,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::SquareGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(relu_grad,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::ReluGradCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(tanh_grad,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::TanhGradCompute,
def)
.BindInput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
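The gradient formulas in these host kernels (dx = 2·x·dout for square, dout gated by x > 0 for relu, dout·(1 − out²) for tanh) can be sanity-checked numerically. A standalone finite-difference check for the tanh case might look like this; it is only an illustration, not part of the actual test suite:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const float x = 0.3f, dout = 1.0f, eps = 1e-3f;
  const float out = std::tanh(x);
  // Analytic gradient used by TanhGradCompute: dout * (1 - out^2).
  const float analytic = dout * (1.f - out * out);
  // Central finite difference as a reference.
  const float numeric = dout * (std::tanh(x + eps) - std::tanh(x - eps)) / (2.f * eps);
  assert(std::fabs(analytic - numeric) < 1e-3f);
  std::printf("analytic=%f numeric=%f\n", analytic, numeric);
  return 0;
}
```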
......@@ -49,8 +49,7 @@ int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
<< "Unsuport WithinChannel";
}
auto local_size = op_info->GetAttr<int>("n");
CHECK(op_info->HasAttr("input_scale"));
auto input_scale = op_info->GetAttr<float>("input_scale");
auto input_scale = op_info->GetInputScale(x_var_name)[0];
VLOG(5) << "lrn input scale: " << input_scale;
cnmlLrnOpParam_t param;
......