diff --git a/cmake/device/xpu.cmake b/cmake/device/xpu.cmake
index 145a2394986e5cf03c75dfc367e5997c3ad75731..16fc7dcf4191a6b2a145d4d6e70e915fe5321a6b 100644
--- a/cmake/device/xpu.cmake
+++ b/cmake/device/xpu.cmake
@@ -39,7 +39,7 @@ else()
 endif()
 find_library(XPU_SDK_XPU_RT_FILE NAMES xpurt
-             PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib
+             PATHS ${XPU_SDK_ROOT}/XTDK/runtime/shlib ${XPU_SDK_ROOT}/XTDK/shlib # libxpurt.so may have been moved to XTDK/runtime/shlib
              NO_DEFAULT_PATH)
 if(NOT XPU_SDK_XPU_RT_FILE)
diff --git a/cmake/external/flatbuffers.cmake b/cmake/external/flatbuffers.cmake
index 12c6b162f686f0c08f1c90610767b3508130d0da..7c6374b40b92a8807c5bb9529d907c576f6ad05c 100644
--- a/cmake/external/flatbuffers.cmake
+++ b/cmake/external/flatbuffers.cmake
@@ -97,7 +97,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
     OUTPUT ${GEN_HEADER}
     COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
     --cpp --gen-mutable --gen-object-api --reflect-names
-    --cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs
+    --force-empty --force-empty-vectors
     ${OPT}
     -I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
     -o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
diff --git a/docs/demo_guides/python_demo.md b/docs/demo_guides/python_demo.md
index d6a7b15bd9be638ef586e6b589e35eecbf1613c2..59f81783c0b2e791f9623e84cf57c269cbb7d6f2 100644
--- a/docs/demo_guides/python_demo.md
+++ b/docs/demo_guides/python_demo.md
@@ -86,19 +86,30 @@ config.set_model_from_file(/YOU_MODEL_PATH/mobilenet_v1_opt.nb)
 predictor = create_paddle_predictor(config)
 ```
 
-(3) Set the input data
+(3) Read the input data from an image
+
+```python
+from PIL import Image  # requires Pillow
+import numpy as np
+image = Image.open('./example.jpg')
+resized_image = image.resize((224, 224), Image.BILINEAR)
+image_data = np.array(resized_image).flatten().tolist()
+```
+
+(4) Set the input data
+
 ```python
 input_tensor = predictor.get_input(0)
 input_tensor.resize([1, 3, 224, 224])
-input_tensor.set_float_data([1.] * 3 * 224 * 224)
+input_tensor.set_float_data(image_data)
 ```
 
-(4) Run inference
+(5) Run inference
 ```python
 predictor.run()
 ```
 
-(5) Get the output data
+(6) Get the output data
 ```python
 output_tensor = predictor.get_output(0)
 print(output_tensor.shape())
diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc
index 1452e17c93a4fa0cb04acf618c7ea6139060ed8f..2bcfa9be1f8a601ace71291c7d820bc77d1acde6 100644
--- a/lite/api/paddle_api.cc
+++ b/lite/api/paddle_api.cc
@@ -24,6 +24,9 @@
 #ifdef LITE_WITH_CUDA
 #include "lite/backends/cuda/target_wrapper.h"
 #endif
+#ifdef LITE_WITH_XPU
+#include "lite/backends/xpu/target_wrapper.h"
+#endif
 
 #ifdef LITE_WITH_MLU
 #include "lite/backends/mlu/target_wrapper.h"
@@ -272,7 +275,7 @@ CxxConfig::mlu_firstconv_param() const {
 
 void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
 #ifdef LITE_WITH_XPU
-  lite::Context<TargetType::kXPU>::SetWorkspaceL3Size(l3_size);
+  lite::TargetWrapperXPU::workspace_l3_size_per_thread = l3_size;
 #else
   LOG(WARNING) << "The invoking of the function "
                   "'set_xpu_workspace_l3_size_per_thread' is ignored, please "
@@ -282,7 +285,7 @@ void CxxConfig::set_xpu_workspace_l3_size_per_thread(int l3_size) {
 
 void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
 #ifdef LITE_WITH_XPU
-  lite::Context<TargetType::kXPU>::SetDev(dev_no);
+  lite::TargetWrapperXPU::SetDev(dev_no);
 #else
   LOG(WARNING) << "The invoking of the function 'set_xpu_dev_per_thread' is "
                   "ignored, please rebuild it with LITE_WITH_XPU=ON.";
@@ -291,7 +294,7 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
 
 void CxxConfig::set_xpu_multi_encoder_precision(const std::string &precision) {
 #ifdef LITE_WITH_XPU
-  lite::Context<TargetType::kXPU>::_multi_encoder_precision = precision;
+  lite::TargetWrapperXPU::multi_encoder_precision = precision;
 #else
   LOG(WARNING) << "The invoking of the function "
                   "'set_xpu_multi_encoder_precision' is "
diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 485bd10770d6e5a29963f336dfdf6d47302ccbc0..2ec4965d3d526c82c41b51954f9564488c5126e1 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -55,6 +55,8 @@ USE_MIR_PASS(apu_subgraph_pass);
 USE_MIR_PASS(quantized_op_attributes_inference_pass);
 USE_MIR_PASS(lite_scale_activation_fuse_pass);
 USE_MIR_PASS(__xpu__resnet_fuse_pass);
+USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
 USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
 USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
 USE_MIR_PASS(__xpu__fc_fuse_pass);
+USE_MIR_PASS(__xpu__mmdnn_fuse_pass);
diff --git a/lite/api/test_yolov3_lite_bm.cc b/lite/api/test_yolov3_lite_bm.cc
index d70ecf3c03955286244aa13cfe65f19569a55930..ded851d93313c3e155dd7f8860eee7446e56e715 100644
--- a/lite/api/test_yolov3_lite_bm.cc
+++ b/lite/api/test_yolov3_lite_bm.cc
@@ -59,9 +59,9 @@ void TestModel(const std::vector<Place>& valid_places) {
   }
   auto* image_tensor = predictor.GetInput(1);
   image_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 2})));
-  data = image_tensor->mutable_data<int>();
-  data[0] = FLAGS_im_height;
-  data[1] = FLAGS_im_width;
+  auto* data_1 = image_tensor->mutable_data<int>();
+  data_1[0] = FLAGS_im_height;
+  data_1[1] = FLAGS_im_width;
 
   for (int i = 0; i < FLAGS_warmup; ++i) {
     predictor.Run();
diff --git a/lite/backends/arm/math/activation.cc b/lite/backends/arm/math/activation.cc
index 8e94e212fcb5ff83e8dbfa9d70652cbdaca50656..01f25cbd36d327f7a3c252fdc675262d39748318 100644
--- a/lite/backends/arm/math/activation.cc
+++ b/lite/backends/arm/math/activation.cc
@@ -763,24 +763,6 @@ void act_thresholded_relu(
   }
 }
 
-#ifdef LITE_WITH_TRAIN
-template <>
-void act_square_grad(const float* din,
-                     const float*
dout_grad, - float* din_grad, - int size, - int threads) { - const float* ptr_out_grad = dout_grad; - float* ptr_in_grad = din_grad; - for (int i = 0; i < size; ++i) { - ptr_in_grad[0] = ptr_out_grad[0] * 2.0 * din[0]; - ptr_out_grad++; - ptr_in_grad++; - din++; - } -} -#endif - } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/activation.h b/lite/backends/arm/math/activation.h index 0a849e9ec711a8c554388d9b69a25b79a7b392ec..b0147040cd11a888ec045948f0914a13aa932a2f 100644 --- a/lite/backends/arm/math/activation.h +++ b/lite/backends/arm/math/activation.h @@ -90,12 +90,6 @@ template void act_thresholded_relu( const T* din, T* dout, int size, float threshold, int threads); -#ifdef LITE_WITH_TRAIN -template -void act_square_grad( - const T* din, const T* dout_grad, T* din_grad, int size, int threads); -#endif - } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 42a98bc9442b2a619cf5882783bb63f5c4ea7db4..9fe7fe4930fe069f9d29027ff316efe775b9e7c3 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -139,6 +139,91 @@ static bool conv_trans_weights_numc(const dtype* din, } return true; } +// for example: m = 4, n = 4 +// din = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9 , 10 ,11], [12, 13, 14, 15]] +// dout = [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]] +/* + m = 4 n = 4: 0 1 2 3 0 4 8 12 + 4 5 6 7 1 5 9 13 + 8 9 10 11 2 6 10 14 + 12 13 14 15 3 7 11 15 + m = 8 n = 4: 0 1 2 3 4 5 6 7 0 4 8 12 16 20 24 28 + 8 9 10 11 12 13 14 15 1 5 9 13 17 21 25 29 + 16 17 18 19 20 21 22 23 2 6 10 14 18 22 26 30 + 24 25 26 27 28 29 30 31 3 7 11 15 19 23 27 31 + m = 4 n = 8: 0 1 2 3 0 8 16 24 + 4 5 6 7 1 9 17 25 + 8 9 10 11 2 10 18 26 + 12 13 14 15 3 11 19 27 + 16 17 18 19 4 12 20 28 + ... + m = 8 n = 8: 0 1 2 3 4 5 6 7 0 8 16 24 32 40 48 56 + 8 9 10 11 12 13 14 15 1 9 17 25 33 41 49 57 + 16 17 18 19 20 21 22 23 2 10 18 26 34 42 50 58 + 24 25 26 27 28 29 30 31 3 11 19 27 35 43 51 59 + 32 33 34 35 36 37 38 39 4 12 20 28 36 44 52 60 + ... 
+  int k = 0;
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < m; ++j) {
+      dout[k++] = din[j * n + i];
+    }
+  }
+*/
+template <typename Dtype>
+void local_transpose(const Dtype* din, Dtype* dout, int m, int n) {
+  // n % 4 == 0 && m % 4 == 0
+  // n * m ==> n * m data trans
+  int offset_m = m << 2;
+  const Dtype* din_ptr = din;
+  Dtype* dout_ptr = dout;
+  for (int i = 0; i < n; i += 4) {
+    Dtype* out_ptr0 = dout_ptr;
+    Dtype* out_ptr1 = dout_ptr + m;
+    Dtype* out_ptr2 = out_ptr1 + m;
+    Dtype* out_ptr3 = out_ptr2 + m;
+    const Dtype* in_ptr0 = din_ptr;
+    const Dtype* in_ptr1 = din_ptr + m;
+    const Dtype* in_ptr2 = in_ptr1 + m;
+    const Dtype* in_ptr3 = in_ptr2 + m;
+    for (int j = 0; j < m; j += 4) {
+      float32x4_t vin0 = vld1q_f32(in_ptr0);
+      float32x4_t vin1 = vld1q_f32(in_ptr1);
+      float32x4_t vin2 = vld1q_f32(in_ptr2);
+      float32x4_t vin3 = vld1q_f32(in_ptr3);
+      // a00 b00 a02 b02 a01 b01 a03 b03
+      float32x4x2_t tmp0 = vtrnq_f32(vin0, vin1);
+      // c00 d00 c02 d02 c01 d01 c03 d03
+      float32x4x2_t tmp2 = vtrnq_f32(vin2, vin3);
+      in_ptr0 = in_ptr3 + m;
+      in_ptr1 = in_ptr3 + 2 * m;
+      float tmp_val1 = tmp0.val[0][2];
+      float tmp_val2 = tmp0.val[0][3];
+      tmp0.val[0][2] = tmp2.val[0][0];
+      tmp0.val[0][3] = tmp2.val[0][1];
+      float tmp_val3 = tmp0.val[1][2];
+      float tmp_val4 = tmp0.val[1][3];
+      tmp2.val[0][0] = tmp_val1;
+      tmp2.val[0][1] = tmp_val2;
+      tmp0.val[1][2] = tmp2.val[1][0];
+      tmp0.val[1][3] = tmp2.val[1][1];
+      tmp2.val[1][0] = tmp_val3;
+      tmp2.val[1][1] = tmp_val4;
+      in_ptr2 = in_ptr1 + m;
+      in_ptr3 = in_ptr1 + 2 * m;
+      vst1q_f32(out_ptr0, tmp0.val[0]);
+      vst1q_f32(out_ptr1, tmp0.val[1]);
+      out_ptr0 += 4;
+      out_ptr1 += 4;
+      vst1q_f32(out_ptr2, tmp2.val[0]);
+      vst1q_f32(out_ptr3, tmp2.val[1]);
+      out_ptr2 += 4;
+      out_ptr3 += 4;
+    }
+    dout_ptr += offset_m;
+    din_ptr += 4;
+  }
+}
 template <typename Dtype>
 void transpose(const Dtype* din, Dtype* dout, int m, int n) {
   // nxm == mxn
diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt
index 7f96308a5dcaf5742bd5dcef7c2e5f146cdb7c59..c23d3d0ed0351b59d4a373efb2474e9a73763659 100644
--- a/lite/backends/cuda/math/CMakeLists.txt
+++ b/lite/backends/cuda/math/CMakeLists.txt
@@ -11,10 +11,13 @@ nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
 nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps})
 nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
 nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps})
+nv_library(cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation ${cuda_static_deps})
+nv_library(cuda_sequence2batch SRCS sequence2batch.cu DEPS ${cuda_static_deps})
 nv_library(cuda_gemm SRCS gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_batched_gemm SRCS batched_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_strided_gemm SRCS strided_gemm.cc DEPS ${cuda_static_deps})
 nv_library(cuda_sequence_padding SRCS sequence_padding.cu DEPS ${cuda_static_deps})
+nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps})
 
 set (
   math_cuda
@@ -25,10 +28,13 @@ set (
   cuda_transpose
   cuda_elementwise
   cudnn_pool
+  cuda_gru_forward
+  cuda_sequence2batch
   cuda_gemm
   cuda_batched_gemm
   cuda_strided_gemm
   cuda_sequence_padding
+  cuda_bias
 )
 
 set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
diff --git a/lite/backends/cuda/math/activation.cu b/lite/backends/cuda/math/activation.cu
index a45e3eb378eefdbabce0b837891514dc659e0429..7524fbc4fbe34806358b06187639055703387e7b 100644
--- a/lite/backends/cuda/math/activation.cu
+++ b/lite/backends/cuda/math/activation.cu
@@ -21,6 +21,20 @@ namespace lite { namespace cuda { namespace math { +ActivationType GetActiveType(const std::string& act) { + if (act == "sigmoid") { + return kSigmoid; + } else if (act == "relu") { + return kReLU; + } else if (act == "tanh") { + return kTanh; + } else if (act == "identify") { + return kIdentity; + } else { + LOG(FATAL) << "not supported activation: " << act; + } +} + template __global__ void relu_kernel(const int num, const float alpha, diff --git a/lite/backends/cuda/math/activation.h b/lite/backends/cuda/math/activation.h index 887a222ee83878aa19fd6a94a76572e48ab4d954..0150a32865ca081aaa51d540e2a6ee5757c37d91 100644 --- a/lite/backends/cuda/math/activation.h +++ b/lite/backends/cuda/math/activation.h @@ -17,11 +17,22 @@ #include #include +#include "lite/utils/cp_logging.h" + namespace paddle { namespace lite { namespace cuda { namespace math { +enum ActivationType { + kSigmoid, + kReLU, + kTanh, + kIdentity, +}; + +ActivationType GetActiveType(const std::string& act); + // fp32 and half template void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream); diff --git a/lite/backends/cuda/math/bias.cu b/lite/backends/cuda/math/bias.cu new file mode 100644 index 0000000000000000000000000000000000000000..392abb70967954860a72d1c32f266f6c159fa587 --- /dev/null +++ b/lite/backends/cuda/math/bias.cu @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "lite/backends/cuda/math/bias.h" + +#include + +#include "lite/backends/cuda/cuda_utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +__global__ void RowwiseAddKernel( + const T* a, const T* b, T* c, int width, int num) { + CUDA_KERNEL_LOOP(i, num) { + int h = i / width; + int w = i - h * width; + c[i] = a[i] + b[w]; + } +} +template +void RowwiseAdd::operator()(const T* input, + const T* bias, + T* output, + const int width, + const int count, + const cudaStream_t& stream) { + RowwiseAddKernel<<>>( + input, bias, output, width, count); + CUDA_POST_KERNEL_CHECK; +} + +template struct RowwiseAdd; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/bias.h b/lite/backends/cuda/math/bias.h new file mode 100644 index 0000000000000000000000000000000000000000..98f805a013ff80b267301be4d47a9694c5ce642f --- /dev/null +++ b/lite/backends/cuda/math/bias.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "lite/backends/cuda/cuda_utils.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +struct RowwiseAdd { + void operator()(const T* input, + const T* bias, + T* output, + const int width, + const int count, + const cudaStream_t& stream); +}; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/gru_forward.cu b/lite/backends/cuda/math/gru_forward.cu new file mode 100644 index 0000000000000000000000000000000000000000..b5654c83ac2a80e2d6a6fddbf71eadd27a38ab69 --- /dev/null +++ b/lite/backends/cuda/math/gru_forward.cu @@ -0,0 +1,137 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "lite/backends/cuda/math/gru_forward.h" +#include "lite/core/device_info.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +__global__ void GruForwardResetOutput( + T* gate_value, + T* reset_output_value, + T* prev_output_value, + int frame_size, + int batch_size, + lite::cuda::math::ActivationType active_gate, + bool is_batch) { + const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (frame_idx >= frame_size) return; + int batch_idx = 0; + if (is_batch) { + batch_idx = blockIdx.y * blockDim.y + threadIdx.y; + if (batch_idx >= batch_size) return; + gate_value += batch_idx * 3 * frame_size; + reset_output_value += batch_idx * frame_size; + } + T prev_out = 0; + T reset_out_val; + T update_gate_value = gate_value[frame_idx + frame_size * 0]; + T reset_gate_value = gate_value[frame_idx + frame_size * 1]; + if (prev_output_value) { + if (is_batch) { + prev_output_value += batch_idx * frame_size; + } + prev_out = prev_output_value[frame_idx]; + } + if (active_gate == lite::cuda::math::ActivationType::kSigmoid) { + update_gate_value = Sigmoid(update_gate_value); + reset_gate_value = Sigmoid(reset_gate_value); + } else if (active_gate == lite::cuda::math::ActivationType::kReLU) { + update_gate_value = ReLU(update_gate_value); + reset_gate_value = ReLU(reset_gate_value); + } else if (active_gate == lite::cuda::math::ActivationType::kTanh) { + update_gate_value = Tanh(update_gate_value); + reset_gate_value = Tanh(reset_gate_value); + } + reset_out_val = prev_out * reset_gate_value; + gate_value[frame_idx + frame_size * 0] = update_gate_value; + gate_value[frame_idx + frame_size * 1] = reset_gate_value; + reset_output_value[frame_idx] = reset_out_val; +} + +template +__global__ void GruForwardFinalOutput( + T* gate_value, + T* prev_output_value, + T* output_value, + int frame_size, + int batch_size, + lite::cuda::math::ActivationType active_node, + bool origin_mode, + bool is_batch) { + const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (frame_idx >= frame_size) return; + int batch_idx = 0; + if 
(is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) {
+      return;
+    }
+    gate_value += batch_idx * 3 * frame_size;
+    output_value += batch_idx * frame_size;
+  }
+  T output;
+  T prev_out = 0;
+  T update_gate_value = gate_value[frame_idx + frame_size * 0];
+  T state_frame_value = gate_value[frame_idx + frame_size * 2];
+  if (prev_output_value) {
+    if (is_batch) prev_output_value += batch_idx * frame_size;
+    prev_out = prev_output_value[frame_idx];
+  }
+  if (active_node == lite::cuda::math::ActivationType::kSigmoid) {
+    state_frame_value = Sigmoid(state_frame_value);
+  } else if (active_node == lite::cuda::math::ActivationType::kReLU) {
+    state_frame_value = ReLU(state_frame_value);
+  } else if (active_node == lite::cuda::math::ActivationType::kTanh) {
+    state_frame_value = Tanh(state_frame_value);
+  }
+  // origin_mode: h = u * h_prev + (1 - u) * s, expanded below;
+  // otherwise:   h = (1 - u) * h_prev + u * s
+  if (origin_mode) {
+    output = update_gate_value * prev_out + state_frame_value -
+             update_gate_value * state_frame_value;
+  } else {
+    output = prev_out - update_gate_value * prev_out +
+             update_gate_value * state_frame_value;
+  }
+  gate_value[frame_idx + frame_size * 2] = state_frame_value;
+  output_value[frame_idx] = output;
+}
+
+template __global__ void GruForwardFinalOutput<float>(
+    float* gate_value,
+    float* prev_output_value,
+    float* output_value,
+    int frame_size,
+    int batch_size,
+    lite::cuda::math::ActivationType active_node,
+    bool origin_mode,
+    bool is_batch);
+template __global__ void GruForwardResetOutput<float>(
+    float* gate_value,
+    float* reset_output_value,
+    float* prev_output_value,
+    int frame_size,
+    int batch_size,
+    lite::cuda::math::ActivationType active_gate,
+    bool is_batch);
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/gru_forward.h b/lite/backends/cuda/math/gru_forward.h
new file mode 100644
index 0000000000000000000000000000000000000000..ed2a4895a865a398e2d3dd47fdc271e887f9221f
--- /dev/null
+++ b/lite/backends/cuda/math/gru_forward.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cuda.h>
+
+#include <string>
+#include <vector>
+
+#include "lite/api/paddle_place.h"
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/backends/cuda/math/activation.h"
+#include "lite/core/context.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/operators/op_params.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+template <typename Dtype>
+inline __device__ Dtype Sigmoid(const Dtype a) {
+  return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
+}
+template <typename Dtype>
+inline __device__ Dtype ReLU(const Dtype a) {
+  return a > static_cast<Dtype>(0.f) ?
a : static_cast<Dtype>(0.f);
+}
+template <typename Dtype>
+inline __device__ Dtype Tanh(const Dtype a) {
+  Dtype tmp = static_cast<Dtype>(-2.0) * a;
+  return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
+         static_cast<Dtype>(1.0);
+}
+
+template <typename T>
+__global__ void GruForwardResetOutput(
+    T* gate_value,
+    T* reset_output_value,
+    T* prev_output_value,
+    int frame_size,
+    int batch_size,
+    lite::cuda::math::ActivationType active_gate,
+    bool is_batch);
+template <typename T>
+__global__ void GruForwardFinalOutput(
+    T* gate_value,
+    T* prev_output_value,
+    T* output_value,
+    int frame_size,
+    int batch_size,
+    lite::cuda::math::ActivationType active_node,
+    bool origin_mode,
+    bool is_batch);
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/sequence2batch.cu b/lite/backends/cuda/math/sequence2batch.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b6c9c77d085be369913fa52cc5a2e0df9c78af92
--- /dev/null
+++ b/lite/backends/cuda/math/sequence2batch.cu
@@ -0,0 +1,86 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/backends/cuda/math/sequence2batch.h"
+#include "lite/backends/cuda/math/utils.h"
+
+namespace paddle {
+namespace lite {
+namespace cuda {
+namespace math {
+
+template <typename T>
+__global__ void CopyMatrixRowsKernel(const T* src,
+                                     T* dst,
+                                     const uint64_t* index,
+                                     int height,
+                                     int width,
+                                     bool is_src_index) {
+  int idx = threadIdx.x;
+  int idy = threadIdx.y;
+  // each thread-block row handles blockDim.y tensor rows
+  int row_id = blockDim.y * blockIdx.x + idy;
+  if (row_id < height) {
+    int src_idx = is_src_index ? index[row_id] : row_id;
+    int dst_idx = is_src_index ? row_id : index[row_id];
+    const T* src_data = src + src_idx * width;
+    T* dst_data = dst + dst_idx * width;
+    for (int i = idx; i < width; i += blockDim.x) {
+      dst_data[i] = src_data[i];
+    }
+  }
+}
+
+template <typename T>
+void CopyMatrixRowsFunctor<T>::operator()(
+    const lite::Tensor& src,
+    lite::Tensor* dst,
+    const std::vector<uint64_t>& index_lod,
+    bool is_src_index,
+    const cudaStream_t& stream) {
+  auto src_dims = src.dims();
+  auto dst_dims = dst->dims();
+  CHECK_EQ(src_dims.size(), 2) << "The src must be matrix with rank 2.";
+  CHECK_EQ(dst_dims.size(), 2) << "The dst must be matrix with rank 2.";
+  CHECK_EQ(src_dims[1], dst_dims[1])
+      << "The width of src and dst must be same.";
+  int height = dst_dims[0];
+  int width = dst_dims[1];
+  const auto* src_data = src.data<T>();
+  auto* dst_data = dst->template mutable_data<T>(TARGET(kCUDA));
+
+  index_tensor_.Resize({static_cast<int64_t>(index_lod.size())});
+  auto* index_tensor_data = index_tensor_.mutable_data<uint64_t>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(index_tensor_data,
+                                 index_lod.data(),
+                                 sizeof(uint64_t) * index_lod.size(),
+                                 IoDirection::HtoD,
+                                 stream);
+  dim3 threads(128, 8);
+  dim3 grids((height + threads.y - 1) / threads.y);
+  // the copy direction is selected by is_src_index
+  CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
+      src_data, dst_data, index_tensor_data, height, width, is_src_index);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template class CopyMatrixRowsFunctor<float>;
+template class LoDTensor2BatchFunctor<float>;
+template class Batch2LoDTensorFunctor<float>;
+
+}  // namespace math
+}  // namespace cuda
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/cuda/math/sequence2batch.h b/lite/backends/cuda/math/sequence2batch.h
new file mode 100644
index 0000000000000000000000000000000000000000..2fb333c83a3152e33cc8b3bd0929c9f10be9f2a1
--- /dev/null
+++ b/lite/backends/cuda/math/sequence2batch.h
@@ -0,0 +1,130 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
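+
+// How LoDTensor2BatchFunctor (declared below) reorders rows -- a worked
+// example with illustrative values, assuming a one-level LoD:
+//   lod = {0, 3, 5, 9}  -> sequence lengths {3, 2, 4} (rows 0-2, 3-4, 5-8)
+//   sorted by length, descending: seq2 (len 4), seq0 (len 3), seq1 (len 2)
+//   step n gathers the n-th row of every sequence that still has rows left:
+//     seq2batch_idx = {5, 0, 3, 6, 1, 4, 7, 2, 8}
+//     batch_starts  = {0, 3, 6, 8, 9}
+//     seq_order     = {2, 0, 1}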
+ +#pragma once +#include +#include + +#include +#include +#include + +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/context.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +class CopyMatrixRowsFunctor { + public: + void operator()(const lite::Tensor& src, + lite::Tensor* dst, + const std::vector& index_lod, + bool is_src_index, + const cudaStream_t& stream); + + private: + lite::Tensor index_tensor_; +}; + +template +class LoDTensor2BatchFunctor { + struct SeqInfo { + SeqInfo(size_t start, size_t length, size_t seq_idx) + : start_(start), length_(length), seq_idx_(seq_idx) {} + size_t start_; + size_t length_; + size_t seq_idx_; + }; + + public: + void operator()(const lite::Tensor& lod_tensor, + lite::Tensor* batch_tensor, + bool is_reverse, + const cudaStream_t& stream) const { + auto lods = lod_tensor.lod(); + CHECK_EQ(lods.size(), 1UL) << "Only support one level sequence now."; + const auto& lod = lods[0]; + std::vector seq_info; + for (int seq_id = 0; seq_id < static_cast(lod.size()) - 1; ++seq_id) { + size_t length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) { + return a.length_ > b.length_; + }); + LoD batch_lods; + batch_lods.emplace_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); + size_t max_seqlen = seq_info[0].length_; + batch_lods[0].resize(max_seqlen + 1); + batch_lods[1].resize(static_cast(lod_tensor.dims()[0])); + batch_lods[2].resize(seq_info.size()); + + auto* batch_starts = batch_lods[0].data(); + auto* seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (size_t n = 0; n < max_seqlen; ++n) { + size_t batch_id = batch_starts[n]; + for (size_t i = 0; i < seq_info.size(); ++i) { + size_t seq_len = seq_info[i].length_; + size_t start = seq_info[i].start_; + if (n < seq_len) { + seq2batch_idx[batch_id] = + is_reverse ? 
start + seq_len - 1 - n : start + n; + ++batch_id; + } else { + break; + } + } + batch_starts[n + 1] = batch_id; + } + auto* seq_order = batch_lods[2].data(); + for (size_t i = 0; i < seq_info.size(); ++i) { + seq_order[i] = seq_info[i].seq_idx_; + } + + batch_tensor->set_lod(batch_lods); + lite::cuda::math::CopyMatrixRowsFunctor to_batch; + to_batch(lod_tensor, batch_tensor, batch_lods[1], true, stream); + CUDA_POST_KERNEL_CHECK; + } +}; + +template +class Batch2LoDTensorFunctor { + public: + void operator()(const lite::Tensor& batch_tensor, + lite::Tensor* lod_tensor, + const cudaStream_t& stream) { + auto in_lod = batch_tensor.lod(); + CHECK_GT(in_lod.size(), 2UL) << "The LoD of LoDTensor should include at " + "least 2-level sequence infomation."; + CHECK_EQ(in_lod[1].size(), static_cast(lod_tensor->dims()[0])) + << "The LoD information should be consistent with the dims."; + lite::cuda::math::CopyMatrixRowsFunctor to_seq; + to_seq(batch_tensor, lod_tensor, in_lod[1], false, stream); + CUDA_POST_KERNEL_CHECK; + } +}; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/target_wrapper.h b/lite/backends/cuda/target_wrapper.h index 3eeee84c1c46a65782e38b998bcd8142e08cbec1..caa9b3077fe96bf73e50b33688b90b71e0cd5c23 100644 --- a/lite/backends/cuda/target_wrapper.h +++ b/lite/backends/cuda/target_wrapper.h @@ -15,6 +15,7 @@ #pragma once #include #include +#include "lite/backends/cuda/cuda_utils.h" #include "lite/core/target_wrapper.h" namespace paddle { @@ -31,6 +32,16 @@ class TargetWrapper { static size_t num_devices(); static size_t maximum_stream() { return 0; } + static int GetComputeCapability() { + int dev_id = GetCurDevice(); + int major, minor; + CUDA_CALL(cudaDeviceGetAttribute( + &major, cudaDevAttrComputeCapabilityMajor, dev_id)); + CUDA_CALL(cudaDeviceGetAttribute( + &minor, cudaDevAttrComputeCapabilityMinor, dev_id)); + return major * 10 + minor; + } + static size_t GetCurDevice() { int dev_id; cudaGetDevice(&dev_id); diff --git a/lite/backends/opencl/cl_context.cc b/lite/backends/opencl/cl_context.cc index 67d679fdd596b109b714bf7ba3cd45b2632b9420..002073517bc61af60da213db9af6e56da5f5b501 100644 --- a/lite/backends/opencl/cl_context.cc +++ b/lite/backends/opencl/cl_context.cc @@ -119,7 +119,7 @@ cl::NDRange CLContext::DefaultWorkSize(const CLImage &image) { } } -cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size, +cl::NDRange CLContext::LocalWorkSizeTune(cl::NDRange global_work_size, size_t max_work_size, int divisor) { int preferred_lws = 0; @@ -157,7 +157,7 @@ cl::NDRange CLContext::LocalWorkSizeTurn(cl::NDRange global_work_size, static_cast(gws0)}; #endif } -cl::NDRange CLContext::LocalWorkSizeTurnReverse(cl::NDRange global_work_size, +cl::NDRange CLContext::LocalWorkSizeTuneReverse(cl::NDRange global_work_size, size_t max_work_size, int divisor) { int preferred_lws = 0; diff --git a/lite/backends/opencl/cl_context.h b/lite/backends/opencl/cl_context.h index 82d15bee5ec460a1fb06430571f007fcef23f66f..c204a8510402b8741c761938c3b2c37ac07fe961 100644 --- a/lite/backends/opencl/cl_context.h +++ b/lite/backends/opencl/cl_context.h @@ -62,10 +62,10 @@ class CLContext { cl::NDRange LocalWorkSize(cl::NDRange global_work_size, size_t max_work_size); - cl::NDRange LocalWorkSizeTurn(cl::NDRange global_work_size, + cl::NDRange LocalWorkSizeTune(cl::NDRange global_work_size, size_t max_work_size, int divitor = 2); - cl::NDRange LocalWorkSizeTurnReverse(cl::NDRange global_work_size, + cl::NDRange 
LocalWorkSizeTuneReverse(cl::NDRange global_work_size, size_t max_work_size, int divitor = 2); bool IsArmMali(); diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_1x1_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_opt_kernel.cl index 1c808da68ddc923e12234bc4b6ac99b35bfffb0b..9209f0e0f8d04fad5e788f3742c7922af8e13f49 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_1x1_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_opt_kernel.cl @@ -6,9 +6,7 @@ __kernel void conv2d_1x1_opt( __private const int global_size_dim2, __read_only image2d_t input_image, __read_only image2d_t filter, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif #ifdef BATCH_NORM __read_only image2d_t new_scale, __read_only image2d_t new_biase, @@ -284,9 +282,7 @@ __kernel void conv2d_1x1_simple( __private const int global_size_dim2, __read_only image2d_t input_image, __read_only image2d_t filter, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif #ifdef BATCH_NORM __read_only image2d_t new_scale, __read_only image2d_t new_biase, diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl index 771765ea6063a08784ae824a757b28450d808f6d..6a3aa6455daf8d20430a434ff6f47dac382f1f74 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_kernel.cl @@ -19,9 +19,7 @@ __kernel void conv2d_3x3(__private const int global_size_dim0, __private const int global_size_dim2, __read_only image2d_t input_image, __read_only image2d_t filter, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int offset, diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl index 79f3922e89549fc15b7a849efb0e2b6595357102..739f852a7c6b60e4c38cb2523dfb745af65bc8df 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_3x3_opt_kernel.cl @@ -19,9 +19,7 @@ __kernel void conv2d_3x3_opt(__private const int item_ch, __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, @@ -264,9 +262,7 @@ __kernel void conv2d_3x3_multi_batch(__private const int item_ch, __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl index d856af6a1d4026b1595bc287901e53f64267dc81..f08d53fa4968d041337adfe3252529bca3b5c55e 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_kernel.cl @@ -5,9 +5,7 @@ __kernel void conv2d_5x5(__private const int global_size_dim0, __private const int global_size_dim2, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif #ifdef BATCH_NORM 
__read_only image2d_t new_scale, __read_only image2d_t new_biase, diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl index 4ed2e072022dc4b457a86d634bf4bc21ab62bc45..4cce039f27b750950a1475ac266e0f5117c6d259 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_5x5_opt_kernel.cl @@ -20,9 +20,7 @@ __kernel void conv2d_5x5_opt(__private const int item_ch, __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, @@ -268,9 +266,7 @@ __kernel void conv2d_5x5_multi_batch(__private const int item_ch, __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, @@ -513,4 +509,4 @@ __kernel void conv2d_5x5_multi_batch(__private const int item_ch, (int2)(out_w_base_id + out_w_id4, item_h_id), output[4]); } -} \ No newline at end of file +} diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl index 4998dc99279fffad8750ef3b6495597e9fc4ad65..2a2f210601e760651ee850686391af3c040fbe7f 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_kernel.cl @@ -5,9 +5,7 @@ __kernel void conv2d_7x7(__private const int global_size_dim0, __private const int global_size_dim2, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif #ifdef BATCH_NORM __read_only image2d_t new_scale, __read_only image2d_t new_biase, diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl index d82f4b4c96b586b6ecf948827402afd0766dcea4..4eadcd9f8032996abae04660b6878ab5beaff9a7 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_7x7_opt_kernel.cl @@ -20,9 +20,7 @@ __kernel void conv2d_7x7_opt(__private const int item_ch, __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, @@ -268,9 +266,7 @@ __kernel void conv2d_7x7_multi_batch(__private const int item_ch, __private const int item_h, __read_only image2d_t input_image, __read_only image2d_t filter_image, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, @@ -513,4 +509,4 @@ __kernel void conv2d_7x7_multi_batch(__private const int item_ch, (int2)(out_w_base_id + out_w_id4, item_h_id), output[4]); } -} \ No newline at end of file +} diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl index 27313aea23ed16ecc7a6763dfbbbe63bca18941a..465b9f8f925a130b4d1b059ab15e93bc29128ec7 100755 --- 
a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl @@ -19,9 +19,7 @@ __kernel void depth_conv2d(__private const int global_size_dim0, __private const int global_size_dim2, __read_only image2d_t input, __read_only image2d_t filter, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif #ifdef BATCH_NORM __read_only image2d_t new_scale, __read_only image2d_t new_biase, diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl index 5626fe6be7d451d4ffe22a2008affa7d82298bc3..6fbdc21f934f21dd26c3eb66885f7087e3d340c0 100755 --- a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl @@ -20,9 +20,7 @@ __kernel void depth_conv2d_3x3( __private const int global_size_dim2, __read_only image2d_t input, __read_only image2d_t filter, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int offset, @@ -249,9 +247,7 @@ __kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk, __private const int ou_nh, __read_only image2d_t input, __read_only image2d_t filter, -#if defined(BIASE_CH) || defined(BIASE_ELE) __read_only image2d_t bias, -#endif __write_only image2d_t output_image, __private const int stride, __private const int pad, diff --git a/lite/backends/xpu/debug.h b/lite/backends/xpu/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..75d18b6f4bf461a871c26c7665d8b48bc2f3db38 --- /dev/null +++ b/lite/backends/xpu/debug.h @@ -0,0 +1,131 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
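+
+// A minimal usage sketch for the dump helpers in this file; the tensor
+// names are hypothetical and only for illustration:
+//
+//   const float* host_ptr = host_tensor->data<float>();
+//   paddle::lite::xpu::DumpCPUMem(host_ptr, host_tensor->numel(), "fc_out");
+//
+//   const float* dev_ptr = dev_tensor->data<float>();  // XPU device memory
+//   paddle::lite::xpu::DumpXPUMem(dev_ptr, dev_tensor->numel(), "fc_out");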
+ +#pragma once + +#include +#include +#include +#include +#include +#include "lite/backends/xpu/xpu_header_sitter.h" + +namespace paddle { +namespace lite { +namespace xpu { + +template +void DumpCPUMem(const T* ptr, + size_t len, + const std::string& comment = "", + size_t stride = 1, + size_t item_per_line = 30) { + size_t after_stride_len = (len + stride - 1) / stride; + std::unique_ptr after_stride(new T[after_stride_len]); + for (size_t i = 0; i < after_stride_len; ++i) { + after_stride[i] = ptr[i * stride]; + } + double sum = 0; + for (size_t i = 0; i < len; ++i) { + sum += ptr[i]; + } + + printf( + "------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN " + "------------------------------\n", + comment.c_str(), + len, + stride, + sum); + size_t nline = (after_stride_len + item_per_line - 1) / item_per_line; + for (size_t i = 0; i < nline; ++i) { + size_t line_begin = i * item_per_line; + size_t line_end = line_begin + item_per_line; + printf("line[%04zd] -- ", i); + for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len); + ++ii) { + if (std::is_same::value) { + printf("%.6f, ", static_cast(after_stride[ii])); + } else if (std::is_same::value) { + printf("%d ", static_cast(after_stride[ii])); + } else { + // CHECK(false) << "unknown type"; + } + } + printf("\n"); + } + printf( + "------------------------------ [%s] len=%zd stride=%zd sum=%f END " + "------------------------------\n", + comment.c_str(), + len, + stride, + sum); +} + +template +void DumpXPUMem(const T* ptr, + size_t len, + const std::string& comment = "", + size_t stride = 1, + size_t item_per_line = 30) { + size_t after_stride_len = (len + stride - 1) / stride; + std::unique_ptr cpu_mem(new T[len]); + xpu_memcpy( + cpu_mem.get(), ptr, len * sizeof(T), XPUMemcpyKind::XPU_DEVICE_TO_HOST); + std::unique_ptr after_stride(new T[after_stride_len]); + for (size_t i = 0; i < after_stride_len; ++i) { + after_stride[i] = cpu_mem[i * stride]; + } + double sum = 0; + for (size_t i = 0; i < len; ++i) { + sum += cpu_mem[i]; + } + + printf( + "------------------------------ [%s] len=%zd stride=%zd sum=%f BEGIN " + "------------------------------\n", + comment.c_str(), + len, + stride, + sum); + size_t nline = (after_stride_len + item_per_line - 1) / item_per_line; + for (size_t i = 0; i < nline; ++i) { + size_t line_begin = i * item_per_line; + size_t line_end = line_begin + item_per_line; + printf("line[%04zd] -- ", i); + for (size_t ii = line_begin; (ii < line_end) && (ii < after_stride_len); + ++ii) { + if (std::is_same::value) { + printf("%.6f, ", static_cast(after_stride[ii])); + } else if (std::is_same::value) { + printf("%d ", static_cast(after_stride[ii])); + } else { + // CHECK(false) << "unknown type"; + } + } + printf("\n"); + } + printf( + "------------------------------ [%s] len=%zd stride=%zd sum=%f END " + "------------------------------\n", + comment.c_str(), + len, + stride, + sum); +} + +} // namespace xpu +} // namespace lite +} // namespace paddle diff --git a/lite/backends/xpu/target_wrapper.cc b/lite/backends/xpu/target_wrapper.cc index 5dcbc1e275cca8c32003cbef74dfb1f6d4caee93..85a0023590858ab72e9e4f258d62dce809888918 100644 --- a/lite/backends/xpu/target_wrapper.cc +++ b/lite/backends/xpu/target_wrapper.cc @@ -13,7 +13,6 @@ // limitations under the License. 
#include "lite/backends/xpu/target_wrapper.h" -#include "lite/backends/xpu/xpu_header_sitter.h" namespace paddle { namespace lite { @@ -42,5 +41,21 @@ void TargetWrapperXPU::MemcpySync(void* dst, } } +XPUScratchPadGuard TargetWrapperXPU::MallocScratchPad(size_t size, + bool use_l3) { + void* ptr{nullptr}; + if (use_l3) { + ptr = xdnn::alloc_workspace(GetRawContext(), size); + } else { + ptr = TargetWrapperXPU::Malloc(size); + } + CHECK(ptr != nullptr); + return XPUScratchPadGuard(new XPUScratchPad(ptr, use_l3)); +} + +std::string TargetWrapperXPU::multi_encoder_precision; // NOLINT +int TargetWrapperXPU::workspace_l3_size_per_thread{0}; +thread_local xdnn::Context* TargetWrapperXPU::tls_raw_ctx_{nullptr}; + } // namespace lite } // namespace paddle diff --git a/lite/backends/xpu/target_wrapper.h b/lite/backends/xpu/target_wrapper.h index c42d4139246085d8b9a367b45b60699209d0b668..b84b5d75e74a14e81091b003aa3ae5514e53a42c 100644 --- a/lite/backends/xpu/target_wrapper.h +++ b/lite/backends/xpu/target_wrapper.h @@ -14,6 +14,8 @@ #pragma once +#include // std::unique_ptr +#include "lite/backends/xpu/xpu_header_sitter.h" // xpu_free #include "lite/core/target_wrapper.h" namespace paddle { @@ -21,6 +23,24 @@ namespace lite { using TargetWrapperXPU = TargetWrapper; +struct XPUScratchPad { + XPUScratchPad(void* addr, bool is_l3) : addr_(addr), is_l3_(is_l3) {} + + void* addr_{nullptr}; + bool is_l3_{false}; +}; + +struct XPUScratchPadDeleter { + void operator()(XPUScratchPad* sp) const { + if (!sp->is_l3_) { + xpu_free(sp->addr_); + } + delete sp; + } +}; + +using XPUScratchPadGuard = std::unique_ptr; + template <> class TargetWrapper { public: @@ -34,6 +54,41 @@ class TargetWrapper { const void* src, size_t size, IoDirection dir); + + static XPUScratchPadGuard MallocScratchPad(size_t size, bool use_l3 = true); + + static xdnn::Context* GetRawContext() { + if (tls_raw_ctx_ == nullptr) { + tls_raw_ctx_ = xdnn::create_context(); + CHECK(tls_raw_ctx_); + int r = xdnn::set_workspace_l3_size(tls_raw_ctx_, + workspace_l3_size_per_thread); + if (r != 0) { + LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r + << ", workspace_l3_size_per_thread = " + << workspace_l3_size_per_thread; + } + } + return tls_raw_ctx_; + } + + // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker + // thread + static void SetDev(int dev_no = 0) { + const char* dev_env = getenv("LITE_XPU_DEV"); + if (dev_env) { + xpu_set_device(atoi(dev_env)); + return; + } + + xpu_set_device(dev_no); + } + + static std::string multi_encoder_precision; // NOLINT + static int workspace_l3_size_per_thread; + + private: + static thread_local xdnn::Context* tls_raw_ctx_; }; } // namespace lite diff --git a/lite/core/context.cc b/lite/core/context.cc index d59531b8232d864d3edf94d6c0302d1453c6af3d..f14d1dfddea806ab3839f6f897b9d4d3fe396ca8 100644 --- a/lite/core/context.cc +++ b/lite/core/context.cc @@ -21,12 +21,6 @@ namespace lite { std::string Context::subgraph_model_cache_dir_{""}; // NOLINT #endif -#ifdef LITE_WITH_XPU -std::string Context::_multi_encoder_precision; // NOLINT -thread_local xdnn::Context* Context::_tls_raw_ctx{nullptr}; -int Context::_workspace_l3_size_per_thread{0}; -#endif - #ifdef LITE_WITH_MLU int Context::next_queue_id_{0}; std::map Context::queue_id_map_; diff --git a/lite/core/context.h b/lite/core/context.h index 0b5c66374b84ed4cc8ccc39c95a3e73f3cdc187e..c3993d9589eeac442eaa827152fd1293852396db 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -144,45 +144,12 @@ class Context 
{ void CopySharedTo(XPUContext* ctx) {} + // TODO(miaotianxiang): remove this static xdnn::Context* GetRawContext() { - if (_tls_raw_ctx == nullptr) { - _tls_raw_ctx = xdnn::create_context(); - CHECK(_tls_raw_ctx); - int r = xdnn::set_workspace_l3_size(_tls_raw_ctx, - _workspace_l3_size_per_thread); - if (r != 0) { - LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r - << ", _workspace_l3_size_per_thread = " - << _workspace_l3_size_per_thread; - } - } - return _tls_raw_ctx; - } - - static void SetWorkspaceL3Size(int l3_size = 0xfffc00) { - _workspace_l3_size_per_thread = l3_size; - } - - // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker - // thread - static void SetDev(int dev_no = 0) { - const char* dev_env = getenv("LITE_XPU_DEV"); - if (dev_env) { - xpu_set_device(atoi(dev_env)); - return; - } - - xpu_set_device(dev_no); + return TargetWrapperXPU::GetRawContext(); } std::string name() const { return "XPUContext"; } - - public: - static std::string _multi_encoder_precision; // NOLINT - - private: - static thread_local xdnn::Context* _tls_raw_ctx; - static int _workspace_l3_size_per_thread; }; #endif diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 2540bb56d4082570c984e8eea009b5575825fec9..be09ed4b1a63154b8561f4d39cff7d987a9fcba7 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -23,9 +23,11 @@ lite_cc_library(mir_passes fusion/sequence_pool_concat_fuse_pass.cc fusion/scale_activation_fuse_pass.cc fusion/__xpu__resnet_fuse_pass.cc + fusion/__xpu__resnet_cbam_fuse_pass.cc fusion/__xpu__multi_encoder_fuse_pass.cc fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc fusion/__xpu__fc_fuse_pass.cc + fusion/__xpu__mmdnn_fuse_pass.cc elimination/identity_scale_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc diff --git a/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc b/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..61aeb2ab1f51ddcd6b153971253f8239472a1031 --- /dev/null +++ b/lite/core/mir/fusion/__xpu__mmdnn_fuse_pass.cc @@ -0,0 +1,1183 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
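+
+// The weight quantization performed by XPUMmdnnFloat2Fix below, sketched on
+// toy values (illustrative only): with max_f = FindMaxAbs(w, len), each
+// weight is mapped to roughly round(w[i] * 32767 / max_f). For
+// w = {0.5, -1.0, 0.25} and max_f = 1.0 the int16 weights become
+// {16384, -32767, 8192}; the XPU kernels rescale by max_f ("__xpu__w_max")
+// at run time.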
+ +#include +#include +#include "lite/backends/xpu/math.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/xpu_pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class XPUMmdnnFloat2Fix { + public: + void operator()(SSAGraph* graph) { + for (auto* node : graph->StmtTopologicalOrder()) { + CHECK(node->IsStmt()); + auto* op_info = node->stmt()->op_info(); + std::string op_type = op_info->Type(); + + static const std::vector target_ops{"var_conv_2d", + "search_fc"}; + if (std::find(target_ops.begin(), target_ops.end(), op_type) != + target_ops.end()) { + std::string weight_name = op_info->Input("W").front(); + auto* scope = node->stmt()->op()->scope(); + auto* weight_t = scope->FindMutableTensor(weight_name); + auto weight_dims = weight_t->dims(); + auto weight_len = weight_t->numel(); + float* weight_on_host = weight_t->mutable_data(); + float max_f = + paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len); + std::unique_ptr weight_int16(new int16_t[weight_len]); + paddle::lite::xpu::math::ConvertFP32ToInt16( + weight_on_host, weight_int16.get(), max_f, weight_len); + memcpy( + weight_on_host, weight_int16.get(), weight_len * sizeof(int16_t)); + + auto update_op_info = *op_info; + update_op_info.SetAttr("__xpu__float_to_fix", true); + update_op_info.SetAttr("__xpu__w_max", max_f); + node->stmt()->ResetOp(update_op_info, graph->valid_places()); + VLOG(3) << "Float2Fix, op_type=" << op_type + << ", weight_name=" << weight_name; + } else if (op_type == "match_matrix_tensor") { + std::string weight_name = op_info->Input("W").front(); + auto* scope = node->stmt()->op()->scope(); + auto* weight_t = scope->FindMutableTensor(weight_name); + auto weight_dims = weight_t->dims(); + auto weight_len = weight_t->numel(); + float* weight_on_host = weight_t->mutable_data(); + float max_f = + paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len); + std::unique_ptr weight_int16(new int16_t[weight_len]); + std::unique_ptr weight_trans_int16(new int16_t[weight_len]); + paddle::lite::xpu::math::ConvertFP32ToInt16( + weight_on_host, weight_int16.get(), max_f, weight_len); + paddle::lite::xpu::math::Transpose(weight_int16.get(), + weight_trans_int16.get(), + weight_dims[0], + weight_dims[1] * weight_dims[2]); + memcpy(weight_on_host, + weight_trans_int16.get(), + weight_len * sizeof(int16_t)); + + auto update_op_info = *op_info; + update_op_info.SetAttr("__xpu__float_to_fix", true); + update_op_info.SetAttr("__xpu__w_max", max_f); + node->stmt()->ResetOp(update_op_info, graph->valid_places()); + VLOG(3) << "Float2Fix && Transposed, op_type=" << op_type + << ", weight_name=" << weight_name; + } else if (op_type == "search_grnn") { + auto* scope = node->stmt()->op()->scope(); + + std::string wi_name = op_info->Input("Wi").front(); + auto* wi_t = scope->FindMutableTensor(wi_name); + auto wi_dims = wi_t->dims(); + auto wi_len = wi_t->numel(); + auto wi_stride_len = wi_len / 3; + float* wi_on_host = wi_t->mutable_data(); + std::unique_ptr wi_int16(new int16_t[wi_len]); + std::vector wi_max(3); + for (int i = 0; i < 3; ++i) { + float max_f = paddle::lite::xpu::math::FindMaxAbs( + wi_on_host + i * wi_stride_len, wi_stride_len); + paddle::lite::xpu::math::ConvertFP32ToInt16( + wi_on_host + i * wi_stride_len, + wi_int16.get() + i * wi_stride_len, + max_f, + wi_stride_len); + wi_max[i] = max_f; + } + memcpy(wi_on_host, wi_int16.get(), wi_len * sizeof(int16_t)); + + std::string wh_name = op_info->Input("Wh").front(); + auto* wh_t 
= scope->FindMutableTensor(wh_name); + auto wh_dims = wh_t->dims(); + auto wh_len = wh_t->numel(); + auto wh_stride_len = wh_len / 3; + float* wh_on_host = wh_t->mutable_data(); + std::unique_ptr wh_int16(new int16_t[wh_len]); + std::vector wh_max(3); + for (int i = 0; i < 3; ++i) { + float max_f = paddle::lite::xpu::math::FindMaxAbs( + wh_on_host + i * wh_stride_len, wh_stride_len); + paddle::lite::xpu::math::ConvertFP32ToInt16( + wh_on_host + i * wh_stride_len, + wh_int16.get() + i * wh_stride_len, + max_f, + wh_stride_len); + wh_max[i] = max_f; + } + memcpy(wh_on_host, wh_int16.get(), wh_len * sizeof(int16_t)); + + auto update_op_info = *op_info; + update_op_info.SetAttr("__xpu__float_to_fix", true); + update_op_info.SetAttr>("__xpu__wi_max", wi_max); + update_op_info.SetAttr>("__xpu__wh_max", wh_max); + node->stmt()->ResetOp(update_op_info, graph->valid_places()); + VLOG(3) << "Float2Fix, op_type=" << op_type << ", wi_name=" << wi_name + << ", wh_name=" << wh_name; + } + } + } +}; + +class XPUMmdnnSearchAttentionFuser : public FuseBase { + public: + void BuildPattern() override { + auto* input = VarNode("input")->AsInput(); + + auto* search_group_padding = + OpNode("search_group_padding", "search_group_padding"); + auto* out_emb_padding = + VarNode("out_emb_padding") + ->assert_is_op_output("search_group_padding", "Out_emb_padding") + ->AsIntermediate(); + auto* out_new = VarNode("out_new") + ->assert_is_op_output("search_group_padding", "Out_new") + ->AsIntermediate(); + auto* out_padding = + VarNode("out_padding") + ->assert_is_op_output("search_group_padding", "Out_padding") + ->AsIntermediate(); + + auto* search_seq_fc_w = VarNode("search_seq_fc_w") + ->assert_is_op_input("search_seq_fc", "W") + ->AsInput(); + auto* search_seq_fc_b = VarNode("search_seq_fc_b") + ->assert_is_op_input("search_seq_fc", "b") + ->AsInput(); + auto* search_seq_fc = + OpNode("search_seq_fc", "search_seq_fc")->AsIntermediate(); + auto* search_seq_fc_out = VarNode("search_seq_fc_out") + ->assert_is_op_output("search_seq_fc", "Out") + ->AsIntermediate(); + + auto* search_aligned_mat_mul = + OpNode("search_aligned_mat_mul", "search_aligned_mat_mul") + ->AsIntermediate(); + auto* search_aligned_mat_mul_out = + VarNode("search_aligned_mat_mul_out") + ->assert_is_op_output("search_aligned_mat_mul", "Out") + ->AsIntermediate(); + auto* search_aligned_mat_mul_a = + VarNode("search_aligned_mat_mul_a") + ->assert_is_op_output("search_aligned_mat_mul", "_a_addr") + ->AsIntermediate(); + auto* search_aligned_mat_mul_b = + VarNode("search_aligned_mat_mul_b") + ->assert_is_op_output("search_aligned_mat_mul", "_b_addr") + ->AsIntermediate(); + auto* search_aligned_mat_mul_c = + VarNode("search_aligned_mat_mul_c") + ->assert_is_op_output("search_aligned_mat_mul", "_c_addr") + ->AsIntermediate(); + + auto* search_attention_padding_mask = + OpNode("search_attention_padding_mask", "search_attention_padding_mask") + ->AsIntermediate(); + auto* search_attention_padding_mask_out = + VarNode("search_attention_padding_mask_out") + ->assert_is_op_output("search_attention_padding_mask", "Out") + ->AsIntermediate(); + auto* search_attention_padding_mask_pad_begin = + VarNode("search_attention_padding_mask_pad_begin") + ->assert_is_op_output("search_attention_padding_mask", "pad_begin") + ->AsIntermediate(); + + auto* search_seq_softmax = + OpNode("search_seq_softmax", "search_seq_softmax")->AsIntermediate(); + auto* search_seq_softmax_out = + VarNode("search_seq_softmax_out") + ->assert_is_op_output("search_seq_softmax", "Out") + 
+class XPUMmdnnSearchAttentionFuser : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* input = VarNode("input")->AsInput();
+
+    auto* search_group_padding =
+        OpNode("search_group_padding", "search_group_padding");
+    auto* out_emb_padding =
+        VarNode("out_emb_padding")
+            ->assert_is_op_output("search_group_padding", "Out_emb_padding")
+            ->AsIntermediate();
+    auto* out_new =
+        VarNode("out_new")
+            ->assert_is_op_output("search_group_padding", "Out_new")
+            ->AsIntermediate();
+    auto* out_padding =
+        VarNode("out_padding")
+            ->assert_is_op_output("search_group_padding", "Out_padding")
+            ->AsIntermediate();
+
+    auto* search_seq_fc_w = VarNode("search_seq_fc_w")
+                                ->assert_is_op_input("search_seq_fc", "W")
+                                ->AsInput();
+    auto* search_seq_fc_b = VarNode("search_seq_fc_b")
+                                ->assert_is_op_input("search_seq_fc", "b")
+                                ->AsInput();
+    auto* search_seq_fc =
+        OpNode("search_seq_fc", "search_seq_fc")->AsIntermediate();
+    auto* search_seq_fc_out =
+        VarNode("search_seq_fc_out")
+            ->assert_is_op_output("search_seq_fc", "Out")
+            ->AsIntermediate();
+
+    auto* search_aligned_mat_mul =
+        OpNode("search_aligned_mat_mul", "search_aligned_mat_mul")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_out =
+        VarNode("search_aligned_mat_mul_out")
+            ->assert_is_op_output("search_aligned_mat_mul", "Out")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_a =
+        VarNode("search_aligned_mat_mul_a")
+            ->assert_is_op_output("search_aligned_mat_mul", "_a_addr")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_b =
+        VarNode("search_aligned_mat_mul_b")
+            ->assert_is_op_output("search_aligned_mat_mul", "_b_addr")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_c =
+        VarNode("search_aligned_mat_mul_c")
+            ->assert_is_op_output("search_aligned_mat_mul", "_c_addr")
+            ->AsIntermediate();
+
+    auto* search_attention_padding_mask =
+        OpNode("search_attention_padding_mask",
+               "search_attention_padding_mask")
+            ->AsIntermediate();
+    auto* search_attention_padding_mask_out =
+        VarNode("search_attention_padding_mask_out")
+            ->assert_is_op_output("search_attention_padding_mask", "Out")
+            ->AsIntermediate();
+    auto* search_attention_padding_mask_pad_begin =
+        VarNode("search_attention_padding_mask_pad_begin")
+            ->assert_is_op_output("search_attention_padding_mask", "pad_begin")
+            ->AsIntermediate();
+
+    auto* search_seq_softmax =
+        OpNode("search_seq_softmax", "search_seq_softmax")->AsIntermediate();
+    auto* search_seq_softmax_out =
+        VarNode("search_seq_softmax_out")
+            ->assert_is_op_output("search_seq_softmax", "Out")
+            ->AsIntermediate();
+    auto* search_seq_softmax_out_log =
+        VarNode("search_seq_softmax_out_log")
+            ->assert_is_op_output("search_seq_softmax", "Out_log")
+            ->AsIntermediate();
+
+    auto* search_aligned_mat_mul_2 =
+        OpNode("search_aligned_mat_mul_2", "search_aligned_mat_mul")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_2_out =
+        VarNode("search_aligned_mat_mul_2_out")
+            ->assert_is_op_output("search_aligned_mat_mul", "Out")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_2_a =
+        VarNode("search_aligned_mat_mul_2_a")
+            ->assert_is_op_output("search_aligned_mat_mul", "_a_addr")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_2_b =
+        VarNode("search_aligned_mat_mul_2_b")
+            ->assert_is_op_output("search_aligned_mat_mul", "_b_addr")
+            ->AsIntermediate();
+    auto* search_aligned_mat_mul_2_c =
+        VarNode("search_aligned_mat_mul_2_c")
+            ->assert_is_op_output("search_aligned_mat_mul", "_c_addr")
+            ->AsIntermediate();
+
+    auto* search_seq_depadding =
+        OpNode("search_seq_depadding")->AsIntermediate();
+    auto* search_seq_depadding_out =
+        VarNode("search_seq_depadding_out")->AsOutput();
+
+    *input >> *search_group_padding >> *out_emb_padding;
+    *search_group_padding >> *out_new;
+    *search_group_padding >> *out_padding;
+
+    *search_seq_fc_w >> *search_seq_fc;
+    *search_seq_fc_b >> *search_seq_fc;
+    *out_emb_padding >> *search_seq_fc;
+    *search_seq_fc >> *search_seq_fc_out;
+
+    *search_seq_fc_out >> *search_aligned_mat_mul;
+    *out_emb_padding >> *search_aligned_mat_mul;
+    *search_aligned_mat_mul >> *search_aligned_mat_mul_out;
+    *search_aligned_mat_mul >> *search_aligned_mat_mul_a;
+    *search_aligned_mat_mul >> *search_aligned_mat_mul_b;
+    *search_aligned_mat_mul >> *search_aligned_mat_mul_c;
+
+    *search_aligned_mat_mul_out >> *search_attention_padding_mask;
+    *out_padding >> *search_attention_padding_mask;
+    *search_attention_padding_mask >> *search_attention_padding_mask_out;
+    *search_attention_padding_mask >> *search_attention_padding_mask_pad_begin;
+
+    *search_attention_padding_mask_out >> *search_seq_softmax;
+    *search_seq_softmax >> *search_seq_softmax_out;
+    *search_seq_softmax >> *search_seq_softmax_out_log;
+
+    *search_seq_softmax_out >> *search_aligned_mat_mul_2;
+    *out_emb_padding >> *search_aligned_mat_mul_2;
+    *search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_out;
+    *search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_a;
+    *search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_b;
+    *search_aligned_mat_mul_2 >> *search_aligned_mat_mul_2_c;
+
+    *search_aligned_mat_mul_2_out >> *search_seq_depadding;
+    *out_new >> *search_seq_depadding;
+    *search_seq_depadding >> *search_seq_depadding_out;
+  }
+
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    cpp::OpDesc op_desc;
+    op_desc.SetType("__xpu__mmdnn_search_attention");
+    op_desc.SetInput("X", {matched.at("input")->arg()->name});
+    op_desc.SetInput("W", {matched.at("search_seq_fc_w")->arg()->name});
+    op_desc.SetInput("b", {matched.at("search_seq_fc_b")->arg()->name});
+    op_desc.SetOutput("Out",
+                      {matched.at("search_seq_depadding_out")->arg()->name});
+
+    auto* padding_op_info =
+        matched.at("search_group_padding")->stmt()->op_info();
+    op_desc.SetAttr("pad_id", padding_op_info->GetAttr<int>("pad_id"));
+    auto* matmul_0_op_info =
+        matched.at("search_aligned_mat_mul")->stmt()->op_info();
+    op_desc.SetAttr("alpha0", matmul_0_op_info->GetAttr<float>("alpha"));
+    auto* matmul_1_op_info =
+        matched.at("search_aligned_mat_mul_2")->stmt()->op_info();
+    op_desc.SetAttr("alpha1",
+                    matmul_1_op_info->GetAttr<float>("alpha"));
+    auto* mask_op_info =
+        matched.at("search_attention_padding_mask")->stmt()->op_info();
+    op_desc.SetAttr("mask", mask_op_info->GetAttr<float>("mask"));
+
+    auto* new_stmt = matched.at("search_group_padding")->stmt();
+    auto* scope = new_stmt->op()->scope();
+    auto w_name = matched.at("search_seq_fc_w")->arg()->name;
+    auto* w_t = scope->FindMutableTensor(w_name);
+    auto w_dims = w_t->dims();
+    int w_len = w_t->numel();
+    float* w_on_host = w_t->mutable_data<float>();
+
+    float max_f = paddle::lite::xpu::math::FindMaxAbs(w_on_host, w_len);
+    std::unique_ptr<int16_t[]> w_int16(new int16_t[w_len]);
+    paddle::lite::xpu::math::ConvertFP32ToInt16(
+        w_on_host, w_int16.get(), max_f, w_len);
+    memcpy(w_on_host, w_int16.get(), w_len * sizeof(int16_t));
+    op_desc.SetAttr("W_max", max_f);
+
+    auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
+    new_op->Attach(op_desc, scope);
+    new_op->SetValidPlaces(new_stmt->op()->valid_places());
+    auto kernels = new_op->CreateKernels(new_op->valid_places());
+    new_stmt->SetOp(new_op);
+    new_stmt->SetKernels(std::move(kernels));
+
+    DirectedLink(matched.at("search_seq_fc_w"),
+                 matched.at("search_group_padding"));
+    DirectedLink(matched.at("search_seq_fc_b"),
+                 matched.at("search_group_padding"));
+    IR_OP_VAR_LINK(matched.at("search_group_padding"),
+                   matched.at("search_seq_depadding_out"));
+  }
+};
+
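+// Fuses match_matrix_tensor + relu + var_conv_2d + relu + sequence_concat +
+// sequence_topk_avg_pooling into one __xpu__mmdnn_match_conv_topk op; the
+// matched ops' "w_max" attributes are forwarded to the fused op.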
->assert_is_op_output("sequence_topk_avg_pooling", "pos") + ->AsIntermediate(); + + *input_x >> *match_matrix_tensor; + *input_y >> *match_matrix_tensor; + *input_w >> *match_matrix_tensor; + *match_matrix_tensor >> *match_out >> *relu0 >> *relu0_out; + *match_matrix_tensor >> *match_tmp; + + *relu0_out >> *conv >> *conv_out >> *relu1 >> *relu1_out; + *conv_w >> *conv; + *conv >> *conv_col; + + *relu0_out >> *seq_concat; + *relu1_out >> *seq_concat; + *seq_concat >> *seq_concat_out >> *topk >> *topk_out; + *topk_col >> *topk; + *topk_row >> *topk; + *topk >> *topk_pos; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__mmdnn_match_conv_topk"); + op_desc.SetInput("input_x", {matched.at("input_x")->arg()->name}); + op_desc.SetInput("input_y", {matched.at("input_y")->arg()->name}); + op_desc.SetInput("input_w", {matched.at("input_w")->arg()->name}); + op_desc.SetInput("conv_w", {matched.at("conv_w")->arg()->name}); + op_desc.SetOutput("topk_out", {matched.at("topk_out")->arg()->name}); + + auto* match_op_info = matched.at("match_matrix_tensor")->stmt()->op_info(); + op_desc.SetAttr("input_w_max", + match_op_info->GetAttr("w_max")); + op_desc.SetAttr("dim_t", match_op_info->GetAttr("dim_t")); + auto* conv_op_info = matched.at("conv")->stmt()->op_info(); + op_desc.SetAttr("conv_w_max", conv_op_info->GetAttr("w_max")); + auto* topk_op_info = matched.at("topk")->stmt()->op_info(); + op_desc.SetAttr>( + "topks", topk_op_info->GetAttr>("topks")); + op_desc.SetAttr("channel_num", + topk_op_info->GetAttr("channel_num")); + + auto* new_stmt = matched.at("match_matrix_tensor")->stmt(); + auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); + new_op->Attach(op_desc, new_stmt->op()->scope()); + new_op->SetValidPlaces(new_stmt->op()->valid_places()); + auto kernels = new_op->CreateKernels(new_op->valid_places()); + new_stmt->SetOp(new_op); + new_stmt->SetKernels(std::move(kernels)); + + // XXX(miaotianxiang): redundant links around |topk| are automatically + // removed as |topk| is + // marked intermediate. 
+    // RemoveDirectedLink(matched.at("topk_col"), matched.at("topk"));
+    // RemoveDirectedLink(matched.at("topk_row"), matched.at("topk"));
+    std::vector<std::string> arg_names{"conv_w"};
+    for (auto name : arg_names) {
+      DirectedLink(matched.at(name), matched.at("match_matrix_tensor"));
+    }
+    std::vector<std::string> out_names{"topk_out"};
+    for (auto name : out_names) {
+      IR_OP_VAR_LINK(matched.at("match_matrix_tensor"), matched.at(name));
+    }
+  }
+};
+
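+// The bidirectional embedding prologue computes the same eltwise sum twice,
+// once on the original and once on the reversed sequence. The fuser below
+// keeps the forward branch and replaces the whole reversed branch with a
+// single sequence_reverse applied to the forward result.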
->assert_is_op_output("lookup_table", "Out") + ->AsIntermediate(); + auto* emb1 = OpNode("emb1", "lookup_table")->AsIntermediate(); + auto* emb1_out = VarNode("emb1_out") + ->assert_is_op_output("lookup_table", "Out") + ->AsIntermediate(); + auto* eltwise01 = + OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate(); + auto* eltwise01_out = + VarNode("eltwise01_out") + ->assert_is_op_output("search_seq_arithmetic", "Out") + ->AsOutput(); + + auto* att_2in1_w = + VarNode("att_2in1_w") + ->assert_is_op_input("__xpu__mmdnn_search_attention", "W") + ->AsInput(); + auto* att_2in1_b = + VarNode("att_2in1_b") + ->assert_is_op_input("__xpu__mmdnn_search_attention", "b") + ->AsInput(); + auto* att_2in1 = + OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate(); + auto* att_2in1_out = + VarNode("att_2in1_out") + ->assert_is_op_output("__xpu__mmdnn_search_attention", "Out") + ->AsIntermediate(); + auto* seq_pool_2in1 = + OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate(); + auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_2in1_max_idx = + VarNode("seq_pool_2in1_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + *input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out; + *emb_tbl >> *emb0; + *input1 >> *emb1 >> *emb1_out >> *eltwise01; + *emb_tbl >> *emb1; + + *eltwise01_out >> *att_2in1 >> *att_2in1_out >> *seq_pool_2in1 >> + *seq_pool_2in1_out; + *seq_pool_2in1 >> *seq_pool_2in1_max_idx; + *att_2in1_w >> *att_2in1; + *att_2in1_b >> *att_2in1; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__mmdnn_bid_emb_att"); + op_desc.SetInput("id0", {matched.at("input0")->arg()->name}); + op_desc.SetInput("id1", {matched.at("input1")->arg()->name}); + op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name}); + op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name}); + op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name}); + op_desc.SetOutput("att_pool_out", + {matched.at("seq_pool_2in1_out")->arg()->name}); + op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name}); + + auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info(); + op_desc.SetAttr("att_fc_w_max", + att_fc_op_info->GetAttr("W_max")); + + auto* new_stmt = matched.at("emb0")->stmt(); + auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); + new_op->Attach(op_desc, new_stmt->op()->scope()); + new_op->SetValidPlaces(new_stmt->op()->valid_places()); + auto kernels = new_op->CreateKernels(new_op->valid_places()); + new_stmt->SetOp(new_op); + new_stmt->SetKernels(std::move(kernels)); + + std::vector arg_names{ + "input1", "att_2in1_w", "att_2in1_b", + }; + for (auto name : arg_names) { + DirectedLink(matched.at(name), matched.at("emb0")); + } + std::vector out_names{ + "seq_pool_2in1_out", "eltwise01_out", + }; + for (auto name : out_names) { + IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name)); + } + } +}; + +class XPUMmdnnBidEmbGrnnAttFuser : public FuseBase { + public: + void BuildPattern() override { + auto* input0 = VarNode("input0")->AsInput(); + auto* input1 = VarNode("input1")->AsInput(); + auto* emb_tbl = VarNode("emb_tbl")->AsInput(); + + auto* emb0 = OpNode("emb0", "lookup_table"); + auto* emb0_out = VarNode("emb0_out") + ->assert_is_op_output("lookup_table", "Out") + ->AsIntermediate(); + auto* emb1 = OpNode("emb1", 
"lookup_table")->AsIntermediate(); + auto* emb1_out = VarNode("emb1_out") + ->assert_is_op_output("lookup_table", "Out") + ->AsIntermediate(); + auto* eltwise01 = + OpNode("eltwise01", "search_seq_arithmetic")->AsIntermediate(); + auto* eltwise01_out = + VarNode("eltwise01_out") + ->assert_is_op_output("search_seq_arithmetic", "Out") + ->AsOutput(); + + auto* seq_rev_right0 = + OpNode("seq_rev_right0", "sequence_reverse")->AsIntermediate(); + auto* seq_rev_right0_out = + VarNode("seq_rev_right0_out") + ->assert_is_op_output("sequence_reverse", "Y") + ->AsIntermediate(); + auto* grnn_right_wh = VarNode("grnn_right_wh") + ->assert_is_op_input("search_grnn", "Wh") + ->AsInput(); + auto* grnn_right_wi = VarNode("grnn_right_wi") + ->assert_is_op_input("search_grnn", "Wi") + ->AsInput(); + auto* grnn_right = OpNode("grnn_right", "search_grnn")->AsIntermediate(); + auto* grnn_right_out = VarNode("grnn_right_out") + ->assert_is_op_output("search_grnn", "Out") + ->AsIntermediate(); + auto* grnn_right_idx_sorted_by_width = + VarNode("grnn_right_idx_sorted_by_width") + ->assert_is_op_output("search_grnn", "idx_sorted_by_width") + ->AsIntermediate(); + auto* grnn_right_layout_input = + VarNode("grnn_right_layout_input") + ->assert_is_op_output("search_grnn", "layout_input") + ->AsIntermediate(); + auto* grnn_right_tmp_buffer = + VarNode("grnn_right_tmp_buffer") + ->assert_is_op_output("search_grnn", "tmp_buffer") + ->AsIntermediate(); + auto* seq_rev_right1 = + OpNode("seq_rev_right1", "sequence_reverse")->AsIntermediate(); + auto* seq_rev_right1_out = + VarNode("seq_rev_right1_out") + ->assert_is_op_output("sequence_reverse", "Y") + ->AsIntermediate(); + auto* seq_pool_right = + OpNode("seq_pool_right", "sequence_pool")->AsIntermediate(); + auto* seq_pool_right_out = VarNode("seq_pool_right_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_right_max_idx = + VarNode("seq_pool_right_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* grnn_left_wh = VarNode("grnn_left_wh") + ->assert_is_op_input("search_grnn", "Wh") + ->AsInput(); + auto* grnn_left_wi = VarNode("grnn_left_wi") + ->assert_is_op_input("search_grnn", "Wi") + ->AsInput(); + auto* grnn_left = OpNode("grnn_left", "search_grnn")->AsIntermediate(); + auto* grnn_left_out = VarNode("grnn_left_out") + ->assert_is_op_output("search_grnn", "Out") + ->AsIntermediate(); + auto* grnn_left_idx_sorted_by_width = + VarNode("grnn_left_idx_sorted_by_width") + ->assert_is_op_output("search_grnn", "idx_sorted_by_width") + ->AsIntermediate(); + auto* grnn_left_layout_input = + VarNode("grnn_left_layout_input") + ->assert_is_op_output("search_grnn", "layout_input") + ->AsIntermediate(); + auto* grnn_left_tmp_buffer = + VarNode("grnn_left_tmp_buffer") + ->assert_is_op_output("search_grnn", "tmp_buffer") + ->AsIntermediate(); + auto* seq_pool_left = + OpNode("seq_pool_left", "sequence_pool")->AsIntermediate(); + auto* seq_pool_left_out = VarNode("seq_pool_left_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_left_max_idx = + VarNode("seq_pool_left_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate(); + auto* concat_2in1_out = VarNode("concat_2in1_out") + ->assert_is_op_output("concat", "Out") + ->AsIntermediate(); + auto* att_2in1_w = + VarNode("att_2in1_w") + ->assert_is_op_input("__xpu__mmdnn_search_attention", "W") + ->AsInput(); + auto* 
att_2in1_b = + VarNode("att_2in1_b") + ->assert_is_op_input("__xpu__mmdnn_search_attention", "b") + ->AsInput(); + auto* att_2in1 = + OpNode("att_2in1", "__xpu__mmdnn_search_attention")->AsIntermediate(); + auto* att_2in1_out = + VarNode("att_2in1_out") + ->assert_is_op_output("__xpu__mmdnn_search_attention", "Out") + ->AsIntermediate(); + auto* seq_pool_2in1 = + OpNode("seq_pool_2in1", "sequence_pool")->AsIntermediate(); + auto* seq_pool_2in1_out = VarNode("seq_pool_2in1_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsOutput(); + auto* seq_pool_2in1_max_idx = + VarNode("seq_pool_2in1_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* concat_3in1 = OpNode("concat_3in1", "concat")->AsIntermediate(); + auto* concat_3in1_out = VarNode("concat_3in1_out") + ->assert_is_op_output("concat", "Out") + ->AsOutput(); + + *input0 >> *emb0 >> *emb0_out >> *eltwise01 >> *eltwise01_out; + *emb_tbl >> *emb0; + *input1 >> *emb1 >> *emb1_out >> *eltwise01; + *emb_tbl >> *emb1; + + *eltwise01_out >> *seq_rev_right0 >> *seq_rev_right0_out >> *grnn_right >> + *grnn_right_out >> *seq_rev_right1 >> *seq_rev_right1_out; + *grnn_right_out >> *seq_pool_right >> *seq_pool_right_out; + *seq_pool_right >> *seq_pool_right_max_idx; + *grnn_right_wh >> *grnn_right; + *grnn_right_wi >> *grnn_right; + *grnn_right >> *grnn_right_idx_sorted_by_width; + *grnn_right >> *grnn_right_layout_input; + *grnn_right >> *grnn_right_tmp_buffer; + + *eltwise01_out >> *grnn_left >> *grnn_left_out >> *seq_pool_left >> + *seq_pool_left_out; + *seq_pool_left >> *seq_pool_left_max_idx; + *grnn_left_wh >> *grnn_left; + *grnn_left_wi >> *grnn_left; + *grnn_left >> *grnn_left_idx_sorted_by_width; + *grnn_left >> *grnn_left_layout_input; + *grnn_left >> *grnn_left_tmp_buffer; + + *seq_rev_right1_out >> *concat_2in1; + *grnn_left_out >> *concat_2in1; + *concat_2in1 >> *concat_2in1_out >> *att_2in1 >> *att_2in1_out >> + *seq_pool_2in1 >> *seq_pool_2in1_out; + *seq_pool_2in1 >> *seq_pool_2in1_max_idx; + *att_2in1_w >> *att_2in1; + *att_2in1_b >> *att_2in1; + + *eltwise01_out >> *concat_3in1; + *seq_rev_right1_out >> *concat_3in1; + *grnn_left_out >> *concat_3in1; + *concat_3in1 >> *concat_3in1_out; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__mmdnn_bid_emb_grnn_att"); + op_desc.SetInput("id0", {matched.at("input0")->arg()->name}); + op_desc.SetInput("id1", {matched.at("input1")->arg()->name}); + op_desc.SetInput("emb_tbl", {matched.at("emb_tbl")->arg()->name}); + op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_left_wh")->arg()->name}); + op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_left_wi")->arg()->name}); + op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_right_wh")->arg()->name}); + op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_right_wi")->arg()->name}); + op_desc.SetInput("att_fc_w", {matched.at("att_2in1_w")->arg()->name}); + op_desc.SetInput("att_fc_b", {matched.at("att_2in1_b")->arg()->name}); + op_desc.SetOutput("grnn_fw_pool_out", + {matched.at("seq_pool_left_out")->arg()->name}); + op_desc.SetOutput("grnn_rv_pool_out", + {matched.at("seq_pool_right_out")->arg()->name}); + op_desc.SetOutput("att_pool_out", + {matched.at("seq_pool_2in1_out")->arg()->name}); + op_desc.SetOutput("concat_3in1_out", + {matched.at("concat_3in1_out")->arg()->name}); + op_desc.SetOutput("emb_fw_out", {matched.at("eltwise01_out")->arg()->name}); + + auto* grnn_fw_op_info = matched.at("grnn_left")->stmt()->op_info(); + 
op_desc.SetAttr<std::vector<float>>(
+        "grnn_fw_wh_maxs",
+        grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max"));
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_fw_wi_maxs",
+        grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max"));
+    auto* grnn_rv_op_info = matched.at("grnn_right")->stmt()->op_info();
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_rv_wh_maxs",
+        grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max"));
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_rv_wi_maxs",
+        grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max"));
+    auto* att_fc_op_info = matched.at("att_2in1")->stmt()->op_info();
+    op_desc.SetAttr("att_fc_w_max",
+                    att_fc_op_info->GetAttr<float>("W_max"));
+
+    auto* new_stmt = matched.at("emb0")->stmt();
+    auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
+    new_op->Attach(op_desc, new_stmt->op()->scope());
+    new_op->SetValidPlaces(new_stmt->op()->valid_places());
+    auto kernels = new_op->CreateKernels(new_op->valid_places());
+    new_stmt->SetOp(new_op);
+    new_stmt->SetKernels(std::move(kernels));
+
+    std::vector<std::string> arg_names{
+        "input1",
+        "grnn_left_wh",
+        "grnn_left_wi",
+        "grnn_right_wh",
+        "grnn_right_wi",
+        "att_2in1_w",
+        "att_2in1_b",
+    };
+    for (auto name : arg_names) {
+      DirectedLink(matched.at(name), matched.at("emb0"));
+    }
+    std::vector<std::string> out_names{
+        "seq_pool_left_out",
+        "seq_pool_right_out",
+        "seq_pool_2in1_out",
+        "concat_3in1_out",
+        "eltwise01_out",
+    };
+    for (auto name : out_names) {
+      IR_OP_VAR_LINK(matched.at("emb0"), matched.at(name));
+    }
+  }
+};
+
+class XPUMmdnnMergeAllFuser : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* concat_7in1_input0 = VarNode("concat_7in1_input0")
+                                   ->assert_is_op_nth_input("concat", "X", 0)
+                                   ->AsInput();
+    auto* concat_7in1_input1 = VarNode("concat_7in1_input1")
+                                   ->assert_is_op_nth_input("concat", "X", 1)
+                                   ->AsInput();
+    auto* concat_7in1_input2 = VarNode("concat_7in1_input2")
+                                   ->assert_is_op_nth_input("concat", "X", 2)
+                                   ->AsInput();
+    auto* concat_7in1_input3 = VarNode("concat_7in1_input3")
+                                   ->assert_is_op_nth_input("concat", "X", 3)
+                                   ->AsInput();
+    auto* concat_7in1_input4 = VarNode("concat_7in1_input4")
+                                   ->assert_is_op_nth_input("concat", "X", 4)
+                                   ->AsInput();
+    auto* concat_7in1_input5 = VarNode("concat_7in1_input5")
+                                   ->assert_is_op_nth_input("concat", "X", 5)
+                                   ->AsInput();
+    auto* concat_7in1_input6 = VarNode("concat_7in1_input6")
+                                   ->assert_is_op_nth_input("concat", "X", 6)
+                                   ->AsInput();
+    auto* concat_7in1 = OpNode("concat_7in1", "concat");
+    auto* concat_7in1_out = VarNode("concat_7in1_out")
+                                ->assert_is_op_output("concat", "Out")
+                                ->AsIntermediate();
+    auto* search_fc0_w = VarNode("search_fc0_w")
+                             ->assert_is_op_input("search_fc", "W")
+                             ->AsInput();
+    auto* search_fc0_b = VarNode("search_fc0_b")
+                             ->assert_is_op_input("search_fc", "b")
+                             ->AsInput();
+    auto* search_fc0 = OpNode("search_fc0", "search_fc")->AsIntermediate();
+    auto* search_fc0_out = VarNode("search_fc0_out")
+                               ->assert_is_op_output("search_fc", "Out")
+                               ->AsIntermediate();
+    auto* relu0 = OpNode("relu0", "relu")->AsIntermediate();
+    auto* relu0_out = VarNode("relu0_out")
+                          ->assert_is_op_output("relu", "Out")
+                          ->AsIntermediate();
+
+    auto* concat_2in1_input0 = VarNode("concat_2in1_input0")
+                                   ->assert_is_op_nth_input("concat", "X", 0)
+                                   ->AsInput();
+    auto* concat_2in1_input1 = VarNode("concat_2in1_input1")
+                                   ->assert_is_op_nth_input("concat", "X", 1)
+                                   ->AsInput();
+    auto* concat_2in1 = OpNode("concat_2in1", "concat")->AsIntermediate();
+    auto* concat_2in1_out = VarNode("concat_2in1_out")
+                                ->assert_is_op_output("concat", "Out")
+                                ->AsIntermediate();
+    auto* seq_rev = OpNode("seq_rev", "sequence_reverse")->AsIntermediate();
+    auto* seq_rev_out =
VarNode("seq_rev_out") + ->assert_is_op_output("sequence_reverse", "Y") + ->AsIntermediate(); + + auto* grnn_rv_wh = VarNode("grnn_rv_wh") + ->assert_is_op_input("search_grnn", "Wh") + ->AsInput(); + auto* grnn_rv_wi = VarNode("grnn_rv_wi") + ->assert_is_op_input("search_grnn", "Wi") + ->AsInput(); + auto* grnn_rv = OpNode("grnn_rv", "search_grnn")->AsIntermediate(); + auto* grnn_rv_out = VarNode("grnn_rv_out") + ->assert_is_op_output("search_grnn", "Out") + ->AsIntermediate(); + auto* grnn_rv_idx_sorted_by_width = + VarNode("grnn_rv_idx_sorted_by_width") + ->assert_is_op_output("search_grnn", "idx_sorted_by_width") + ->AsIntermediate(); + auto* grnn_rv_layout_input = + VarNode("grnn_rv_layout_input") + ->assert_is_op_output("search_grnn", "layout_input") + ->AsIntermediate(); + auto* grnn_rv_tmp_buffer = + VarNode("grnn_rv_tmp_buffer") + ->assert_is_op_output("search_grnn", "tmp_buffer") + ->AsIntermediate(); + auto* seq_pool_rv = + OpNode("seq_pool_rv", "sequence_pool")->AsIntermediate(); + auto* seq_pool_rv_out = VarNode("seq_pool_rv_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsIntermediate(); + auto* seq_pool_rv_max_idx = + VarNode("seq_pool_rv_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* grnn_fw_wh = VarNode("grnn_fw_wh") + ->assert_is_op_input("search_grnn", "Wh") + ->AsInput(); + auto* grnn_fw_wi = VarNode("grnn_fw_wi") + ->assert_is_op_input("search_grnn", "Wi") + ->AsInput(); + auto* grnn_fw = OpNode("grnn_fw", "search_grnn")->AsIntermediate(); + auto* grnn_fw_out = VarNode("grnn_fw_out") + ->assert_is_op_output("search_grnn", "Out") + ->AsIntermediate(); + auto* grnn_fw_idx_sorted_by_width = + VarNode("grnn_fw_idx_sorted_by_width") + ->assert_is_op_output("search_grnn", "idx_sorted_by_width") + ->AsIntermediate(); + auto* grnn_fw_layout_input = + VarNode("grnn_fw_layout_input") + ->assert_is_op_output("search_grnn", "layout_input") + ->AsIntermediate(); + auto* grnn_fw_tmp_buffer = + VarNode("grnn_fw_tmp_buffer") + ->assert_is_op_output("search_grnn", "tmp_buffer") + ->AsIntermediate(); + auto* seq_pool_fw = + OpNode("seq_pool_fw", "sequence_pool")->AsIntermediate(); + auto* seq_pool_fw_out = VarNode("seq_pool_fw_out") + ->assert_is_op_output("sequence_pool", "Out") + ->AsIntermediate(); + auto* seq_pool_fw_max_idx = + VarNode("seq_pool_fw_max_idx") + ->assert_is_op_output("sequence_pool", "MaxIndex") + ->AsIntermediate(); + + auto* rv_fw_concat = OpNode("rv_fw_concat", "concat")->AsIntermediate(); + auto* rv_fw_concat_out = VarNode("rv_fw_concat_out") + ->assert_is_op_output("concat", "Out") + ->AsIntermediate(); + + auto* last_concat = OpNode("last_concat", "concat")->AsIntermediate(); + auto* last_concat_out = VarNode("last_concat_out") + ->assert_is_op_output("concat", "Out") + ->AsIntermediate(); + auto* search_fc1_w = VarNode("search_fc1_w") + ->assert_is_op_input("search_fc", "W") + ->AsInput(); + auto* search_fc1_b = VarNode("search_fc1_b") + ->assert_is_op_input("search_fc", "b") + ->AsInput(); + auto* search_fc1 = OpNode("search_fc1", "search_fc")->AsIntermediate(); + auto* search_fc1_out = VarNode("search_fc1_out") + ->assert_is_op_output("search_fc", "Out") + ->AsIntermediate(); + auto* relu1 = OpNode("relu1", "relu")->AsIntermediate(); + auto* relu1_out = VarNode("relu1_out") + ->assert_is_op_output("relu", "Out") + ->AsIntermediate(); + auto* search_fc2_w = VarNode("search_fc2_w") + ->assert_is_op_input("search_fc", "W") + ->AsInput(); + auto* search_fc2_b = VarNode("search_fc2_b") + 
->assert_is_op_input("search_fc", "b") + ->AsInput(); + auto* search_fc2 = OpNode("search_fc2", "search_fc")->AsIntermediate(); + auto* search_fc2_out = VarNode("search_fc2_out") + ->assert_is_op_output("search_fc", "Out") + ->AsOutput(); + + *concat_7in1_input0 >> *concat_7in1; + *concat_7in1_input1 >> *concat_7in1; + *concat_7in1_input2 >> *concat_7in1; + *concat_7in1_input3 >> *concat_7in1; + *concat_7in1_input4 >> *concat_7in1; + *concat_7in1_input5 >> *concat_7in1; + *concat_7in1_input6 >> *concat_7in1; + *concat_7in1 >> *concat_7in1_out >> *search_fc0 >> *search_fc0_out >> + *relu0 >> *relu0_out; + *search_fc0_w >> *search_fc0; + *search_fc0_b >> *search_fc0; + + *concat_2in1_input0 >> *concat_2in1; + *concat_2in1_input1 >> *concat_2in1; + *concat_2in1 >> *concat_2in1_out >> *seq_rev >> *seq_rev_out; + + *seq_rev_out >> *grnn_rv >> *grnn_rv_out >> *seq_pool_rv >> + *seq_pool_rv_out; + *seq_pool_rv >> *seq_pool_rv_max_idx; + *grnn_rv_wh >> *grnn_rv; + *grnn_rv_wi >> *grnn_rv; + *grnn_rv >> *grnn_rv_idx_sorted_by_width; + *grnn_rv >> *grnn_rv_layout_input; + *grnn_rv >> *grnn_rv_tmp_buffer; + + *concat_2in1_out >> *grnn_fw >> *grnn_fw_out >> *seq_pool_fw >> + *seq_pool_fw_out; + *seq_pool_fw >> *seq_pool_fw_max_idx; + *grnn_fw_wh >> *grnn_fw; + *grnn_fw_wi >> *grnn_fw; + *grnn_fw >> *grnn_fw_idx_sorted_by_width; + *grnn_fw >> *grnn_fw_layout_input; + *grnn_fw >> *grnn_fw_tmp_buffer; + + *seq_pool_rv_out >> *rv_fw_concat; + *seq_pool_fw_out >> *rv_fw_concat; + *rv_fw_concat >> *rv_fw_concat_out; + + *rv_fw_concat_out >> *last_concat; + *relu0_out >> *last_concat; + *last_concat >> *last_concat_out >> *search_fc1 >> *search_fc1_out >> + *relu1 >> *relu1_out >> *search_fc2 >> *search_fc2_out; + *search_fc1_w >> *search_fc1; + *search_fc1_b >> *search_fc1; + *search_fc2_w >> *search_fc2; + *search_fc2_b >> *search_fc2; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__mmdnn_merge_all"); + auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info(); + op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X")); + auto* concat_2in1_op_info = matched.at("concat_2in1")->stmt()->op_info(); + op_desc.SetInput("concat_2in1_x", concat_2in1_op_info->Input("X")); + op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name}); + op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name}); + op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name}); + op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_rv_wi")->arg()->name}); + op_desc.SetInput("fc0_w", {matched.at("search_fc0_w")->arg()->name}); + op_desc.SetInput("fc0_b", {matched.at("search_fc0_b")->arg()->name}); + op_desc.SetInput("fc1_w", {matched.at("search_fc1_w")->arg()->name}); + op_desc.SetInput("fc1_b", {matched.at("search_fc1_b")->arg()->name}); + op_desc.SetInput("fc2_w", {matched.at("search_fc2_w")->arg()->name}); + op_desc.SetInput("fc2_b", {matched.at("search_fc2_b")->arg()->name}); + + op_desc.SetOutput("out", {matched.at("search_fc2_out")->arg()->name}); + + auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info(); + op_desc.SetAttr>( + "grnn_fw_wh_maxs", + grnn_fw_op_info->GetAttr>("wh_max")); + op_desc.SetAttr>( + "grnn_fw_wi_maxs", + grnn_fw_op_info->GetAttr>("wi_max")); + auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info(); + op_desc.SetAttr>( + "grnn_rv_wh_maxs", + grnn_rv_op_info->GetAttr>("wh_max")); + op_desc.SetAttr>( + "grnn_rv_wi_maxs", + 
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    cpp::OpDesc op_desc;
+    op_desc.SetType("__xpu__mmdnn_merge_all");
+    auto* concat_7in1_op_info = matched.at("concat_7in1")->stmt()->op_info();
+    op_desc.SetInput("concat_7in1_x", concat_7in1_op_info->Input("X"));
+    auto* concat_2in1_op_info = matched.at("concat_2in1")->stmt()->op_info();
+    op_desc.SetInput("concat_2in1_x", concat_2in1_op_info->Input("X"));
+    op_desc.SetInput("grnn_fw_wh", {matched.at("grnn_fw_wh")->arg()->name});
+    op_desc.SetInput("grnn_fw_wi", {matched.at("grnn_fw_wi")->arg()->name});
+    op_desc.SetInput("grnn_rv_wh", {matched.at("grnn_rv_wh")->arg()->name});
+    op_desc.SetInput("grnn_rv_wi", {matched.at("grnn_rv_wi")->arg()->name});
+    op_desc.SetInput("fc0_w", {matched.at("search_fc0_w")->arg()->name});
+    op_desc.SetInput("fc0_b", {matched.at("search_fc0_b")->arg()->name});
+    op_desc.SetInput("fc1_w", {matched.at("search_fc1_w")->arg()->name});
+    op_desc.SetInput("fc1_b", {matched.at("search_fc1_b")->arg()->name});
+    op_desc.SetInput("fc2_w", {matched.at("search_fc2_w")->arg()->name});
+    op_desc.SetInput("fc2_b", {matched.at("search_fc2_b")->arg()->name});
+
+    op_desc.SetOutput("out", {matched.at("search_fc2_out")->arg()->name});
+
+    auto* grnn_fw_op_info = matched.at("grnn_fw")->stmt()->op_info();
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_fw_wh_maxs",
+        grnn_fw_op_info->GetAttr<std::vector<float>>("wh_max"));
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_fw_wi_maxs",
+        grnn_fw_op_info->GetAttr<std::vector<float>>("wi_max"));
+    auto* grnn_rv_op_info = matched.at("grnn_rv")->stmt()->op_info();
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_rv_wh_maxs",
+        grnn_rv_op_info->GetAttr<std::vector<float>>("wh_max"));
+    op_desc.SetAttr<std::vector<float>>(
+        "grnn_rv_wi_maxs",
+        grnn_rv_op_info->GetAttr<std::vector<float>>("wi_max"));
+    auto* fc0_op_info = matched.at("search_fc0")->stmt()->op_info();
+    op_desc.SetAttr("fc0_w_max", fc0_op_info->GetAttr<float>("w_max"));
+    auto* fc1_op_info = matched.at("search_fc1")->stmt()->op_info();
+    op_desc.SetAttr("fc1_w_max", fc1_op_info->GetAttr<float>("w_max"));
+    auto* fc2_op_info = matched.at("search_fc2")->stmt()->op_info();
+    op_desc.SetAttr("fc2_w_max", fc2_op_info->GetAttr<float>("w_max"));
+
+    auto* new_stmt = matched.at("concat_7in1")->stmt();
+    auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
+    new_op->Attach(op_desc, new_stmt->op()->scope());
+    new_op->SetValidPlaces(new_stmt->op()->valid_places());
+    auto kernels = new_op->CreateKernels(new_op->valid_places());
+    new_stmt->SetOp(new_op);
+    new_stmt->SetKernels(std::move(kernels));
+
+    std::vector<std::string> arg_names{
+        "concat_2in1_input0",
+        "concat_2in1_input1",
+        "grnn_fw_wh",
+        "grnn_fw_wi",
+        "grnn_rv_wh",
+        "grnn_rv_wi",
+        "search_fc0_w",
+        "search_fc0_b",
+        "search_fc1_w",
+        "search_fc1_b",
+        "search_fc2_w",
+        "search_fc2_b",
+    };
+    for (auto name : arg_names) {
+      DirectedLink(matched.at(name), matched.at("concat_7in1"));
+    }
+    std::vector<std::string> out_names{
+        "search_fc2_out",
+    };
+    for (auto name : out_names) {
+      IR_OP_VAR_LINK(matched.at("concat_7in1"), matched.at(name));
+    }
+  }
+};
+
+}  // namespace fusion
+
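+// Fuser order below appears load-bearing: Float2Fix rewrites weights to
+// int16 before any fuser runs, and the search attention fuser must precede
+// the bid-emb fusers, whose patterns match the already-fused
+// __xpu__mmdnn_search_attention op.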
+class XPUMmdnnFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
+
+    fusion::XPUMmdnnFloat2Fix float_2_fix;
+    float_2_fix(graph.get());
+    fusion::XPUMmdnnSearchAttentionFuser search_att_fuser;
+    search_att_fuser(graph.get());
+    fusion::XPUMmdnnMatchConvTopkFuser match_conv_topk_fuser;
+    match_conv_topk_fuser(graph.get());
+
+    fusion::XPUMmdnnBidSeqRevEmbEltwiseFuser bi_seq_rev_emb_eltwise_fuser;
+    bi_seq_rev_emb_eltwise_fuser(graph.get());
+    fusion::XPUMmdnnBidEmbGrnnAttFuser bid_emb_grnn_att_fuser;
+    bid_emb_grnn_att_fuser(graph.get());
+    fusion::XPUMmdnnBidEmbAttFuser bid_emb_att_fuser;
+    bid_emb_att_fuser(graph.get());
+    fusion::XPUMmdnnMergeAllFuser merge_all_fuser;
+    merge_all_fuser(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(__xpu__mmdnn_fuse_pass, paddle::lite::mir::XPUMmdnnFusePass)
+    .BindTargets({TARGET(kXPU)})
+    .BindKernel("__xpu__mmdnn_search_attention")
+    .BindKernel("__xpu__mmdnn_bid_emb_grnn_att")
+    .BindKernel("__xpu__mmdnn_bid_emb_att")
+    .BindKernel("__xpu__mmdnn_match_conv_topk")
+    .BindKernel("__xpu__mmdnn_merge_all");
diff --git a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
index 525042e44b2997013943f392f592d812bd68fa0b..04988612192b79824b1294428fa9b1c38d784979 100644
--- a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
+++ b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -639,20 +639,21 @@ class XPUMultiEncoderFusePass : public ProgramPass {
   std::set<int> fc_int31_ids;
 #ifdef LITE_WITH_XPU
   // TODO(miaotianxiang): core/mir/*_pass.cc are compiled anyway and need to
-  // access Context::_multi_encoder_precision, but this static member
-  // variable in class specialization defined in lite/core/context.cc
-  // is only compiled iff LITE_WITH_XPU==ON. To suppress linkage error, we use
+  // access TargetWrapperXPU::multi_encoder_precision, but this static member
+  // variable in class specialization defined in
+  // lite/backends/xpu/target_wrapper.cc is only compiled iff
+  // LITE_WITH_XPU==ON. To suppress linkage error, we use
   // #ifdef here. Any better idea?
   if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" ||
-      lite::Context<TargetType::kXPU>::_multi_encoder_precision == "int31") {
+      lite::TargetWrapperXPU::multi_encoder_precision == "int31") {
     fc_int31_ids = {0, 1, 2, 3, 4, 5};
     VLOG(3) << "Use int31 in XPUMultiEncoderOp, "
-            << "lite::Context<>::_multi_encoder_precision="
-            << lite::Context<TargetType::kXPU>::_multi_encoder_precision;
+            << "lite::TargetWrapperXPU::multi_encoder_precision="
+            << lite::TargetWrapperXPU::multi_encoder_precision;
   } else {
     VLOG(3) << "Use int16 in XPUMultiEncoderOp, "
-            << "lite::Context<>::_multi_encoder_precision="
-            << lite::Context<TargetType::kXPU>::_multi_encoder_precision;
+            << "lite::TargetWrapperXPU::multi_encoder_precision="
+            << lite::TargetWrapperXPU::multi_encoder_precision;
   }
 #endif
diff --git a/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc b/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b25eb084f286fccfa4afe8832f9dc1ff8384d552
--- /dev/null
+++ b/lite/core/mir/fusion/__xpu__resnet_cbam_fuse_pass.cc
@@ -0,0 +1,1389 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include "lite/backends/xpu/math.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
+#include "lite/operators/subgraph_op.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class XPUResNetCbamBlock0Fuser : public FuseBase {
+ public:
+  XPUResNetCbamBlock0Fuser() {}
+
+  void BuildPattern() override {
+    auto* input =
+        VarNode("input")->assert_is_op_input("conv2d", "Input")->AsInput();
+
+    auto* left_conv1_weight = VarNode("left_conv1_weight")
+                                  ->assert_is_op_input("conv2d", "Filter")
+                                  ->AsInput();
+    auto* left_conv1 = OpNode("left_conv1", "conv2d");
+    auto* left_conv1_out = VarNode("left_conv1_out")
+                               ->assert_is_op_output("conv2d", "Output")
+                               ->assert_is_op_input("batch_norm", "X")
+                               ->AsIntermediate();
+    auto* left_bn1_scale = VarNode("left_bn1_scale")
+                               ->assert_is_op_input("batch_norm", "Scale")
+                               ->AsIntermediate();
+    auto* left_bn1_bias = VarNode("left_bn1_bias")
+                              ->assert_is_op_input("batch_norm", "Bias")
+                              ->AsInput();
+    auto* left_bn1_mean = VarNode("left_bn1_mean")
+                              ->assert_is_op_input("batch_norm", "Mean")
+                              ->AsIntermediate();
+    auto* left_bn1_var = VarNode("left_bn1_variance")
+                             ->assert_is_op_input("batch_norm", "Variance")
+                             ->AsIntermediate();
+    auto* left_bn1 = OpNode("left_bn1", "batch_norm")->AsIntermediate();
+    auto* left_bn1_out = VarNode("left_bn1_out")
+                             ->assert_is_op_output("batch_norm", "Y")
+                             ->assert_is_op_input("relu", "X")
+                             ->AsIntermediate();
+    auto* left_bn1_mean_out = VarNode("left_bn1_mean_out")
+                                  ->assert_is_op_output("batch_norm", "MeanOut")
+                                  ->AsIntermediate();
+    auto* left_bn1_var_out =
+        VarNode("left_bn1_var_out")
+            ->assert_is_op_output("batch_norm", "VarianceOut")
+            ->AsIntermediate();
+    auto*
left_bn1_saved_mean = + VarNode("left_bn1_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* left_bn1_saved_var = + VarNode("left_bn1_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* left_relu1 = OpNode("left_relu1", "relu")->AsIntermediate(); + auto* left_relu1_out = VarNode("left_relu1_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* left_conv2_weight = VarNode("left_conv2_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* left_conv2 = OpNode("left_conv2", "conv2d")->AsIntermediate(); + auto* left_conv2_out = VarNode("left_conv2_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* left_bn2_scale = VarNode("left_bn2_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* left_bn2_bias = VarNode("left_bn2_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* left_bn2_mean = VarNode("left_bn2_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* left_bn2_var = VarNode("left_bn2_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* left_bn2 = OpNode("left_bn2", "batch_norm")->AsIntermediate(); + auto* left_bn2_out = VarNode("left_bn2_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* left_bn2_mean_out = VarNode("left_bn2_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* left_bn2_var_out = + VarNode("left_bn2_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* left_bn2_saved_mean = + VarNode("left_bn2_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* left_bn2_saved_var = + VarNode("left_bn2_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* left_relu2 = OpNode("left_relu2", "relu")->AsIntermediate(); + auto* left_relu2_out = VarNode("left_relu2_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* left_conv3_weight = VarNode("left_conv3_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* left_conv3 = OpNode("left_conv3", "conv2d")->AsIntermediate(); + auto* left_conv3_out = VarNode("left_conv3_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* left_bn3_scale = VarNode("left_bn3_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* left_bn3_bias = VarNode("left_bn3_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* left_bn3_mean = VarNode("left_bn3_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* left_bn3_var = VarNode("left_bn3_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* left_bn3 = OpNode("left_bn3", "batch_norm")->AsIntermediate(); + auto* left_bn3_out = VarNode("left_bn3_out") + ->assert_is_op_output("batch_norm", "Y") + ->AsIntermediate(); + auto* left_bn3_mean_out = VarNode("left_bn3_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* left_bn3_var_out = + VarNode("left_bn3_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* 
left_bn3_saved_mean = + VarNode("left_bn3_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* left_bn3_saved_var = + VarNode("left_bn3_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + + // cbam specific + auto* reduce_mean = OpNode("reduce_mean", "reduce_mean")->AsIntermediate(); + auto* reduce_mean_out = VarNode("reduce_mean_out") + ->assert_is_op_output("reduce_mean", "Out") + ->assert_is_op_input("concat") + ->AsIntermediate(); + auto* reduce_max = OpNode("reduce_max", "reduce_max")->AsIntermediate(); + auto* reduce_max_out = VarNode("reduce_max_out") + ->assert_is_op_output("reduce_max", "Out") + ->assert_is_op_input("concat") + ->AsIntermediate(); + auto* concat = OpNode("concat", "concat")->AsIntermediate(); + auto* concat_out = VarNode("concat_out") + ->assert_is_op_output("concat", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + auto* left_conv4_weight = VarNode("left_conv4_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* left_conv4 = OpNode("left_conv4", "conv2d")->AsIntermediate(); + auto* left_conv4_out = VarNode("left_conv4_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("sigmoid", "X") + ->AsIntermediate(); + auto* sigmoid = OpNode("sigmoid", "sigmoid")->AsIntermediate(); + auto* sigmoid_out = VarNode("sigmoid_out") + ->assert_is_op_output("sigmoid", "Out") + ->assert_is_op_input("elementwise_mul") + ->AsIntermediate(); + auto* reshape = OpNode("reshape", "reshape2")->AsIntermediate(); + auto* reshape_out = VarNode("reshape_out") + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("elementwise_mul") + ->AsIntermediate(); + auto* reshape_xshape = VarNode("reshape_xshape") + ->assert_is_op_output("reshape2", "XShape") + ->AsIntermediate(); + auto* eltwise_mul = + OpNode("eltwise_mul", "elementwise_mul")->AsIntermediate(); + auto* eltwise_mul_out = VarNode("eltwise_mul_out") + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_add") + ->AsIntermediate(); + + auto* right_conv1_weight = VarNode("right_conv1_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* right_conv1 = OpNode("right_conv1", "conv2d")->AsIntermediate(); + auto* right_conv1_out = VarNode("right_conv1_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* right_bn1_scale = VarNode("right_bn1_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* right_bn1_bias = VarNode("right_bn1_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* right_bn1_mean = VarNode("right_bn1_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* right_bn1_var = VarNode("right_bn1_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* right_bn1 = OpNode("right_bn1", "batch_norm")->AsIntermediate(); + auto* right_bn1_out = VarNode("right_bn1_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("elementwise_add") + ->AsIntermediate(); + auto* right_bn1_mean_out = + VarNode("right_bn1_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* right_bn1_var_out = + VarNode("right_bn1_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* right_bn1_saved_mean = + VarNode("right_bn1_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + 
->AsIntermediate(); + auto* right_bn1_saved_var = + VarNode("right_bn1_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + + auto* add = OpNode("add", "elementwise_add")->AsIntermediate(); + auto* add_out = VarNode("add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* relu = OpNode("relu", "relu")->AsIntermediate(); + auto* relu_out = + VarNode("relu_out")->assert_is_op_output("relu", "Out")->AsOutput(); + + *input >> *left_conv1 >> *left_conv1_out >> *left_bn1 >> *left_bn1_out >> + *left_relu1 >> *left_relu1_out >> *left_conv2 >> *left_conv2_out >> + *left_bn2 >> *left_bn2_out >> *left_relu2 >> *left_relu2_out >> + *left_conv3 >> *left_conv3_out >> *left_bn3 >> + *left_bn3_out /* >> *add*/; + + *left_bn3_out >> *reduce_mean >> *reduce_mean_out >> *concat; + *left_bn3_out >> *reduce_max >> *reduce_max_out >> *concat; + *concat >> *concat_out >> *left_conv4 >> *left_conv4_out >> *sigmoid >> + *sigmoid_out >> *eltwise_mul; + *left_conv4_weight >> *left_conv4; + *left_bn3_out >> *reshape >> *reshape_out >> *eltwise_mul; + *reshape >> *reshape_xshape; + *eltwise_mul >> *eltwise_mul_out >> *add; + + *left_conv1_weight >> *left_conv1; + *left_bn1_scale >> *left_bn1; + *left_bn1_bias >> *left_bn1; + *left_bn1_mean >> *left_bn1; + *left_bn1_var >> *left_bn1; + *left_bn1 >> *left_bn1_mean_out; + *left_bn1 >> *left_bn1_var_out; + *left_bn1 >> *left_bn1_saved_mean; + *left_bn1 >> *left_bn1_saved_var; + + *left_conv2_weight >> *left_conv2; + *left_bn2_scale >> *left_bn2; + *left_bn2_bias >> *left_bn2; + *left_bn2_mean >> *left_bn2; + *left_bn2_var >> *left_bn2; + *left_bn2 >> *left_bn2_mean_out; + *left_bn2 >> *left_bn2_var_out; + *left_bn2 >> *left_bn2_saved_mean; + *left_bn2 >> *left_bn2_saved_var; + + *left_conv3_weight >> *left_conv3; + *left_bn3_scale >> *left_bn3; + *left_bn3_bias >> *left_bn3; + *left_bn3_mean >> *left_bn3; + *left_bn3_var >> *left_bn3; + *left_bn3 >> *left_bn3_mean_out; + *left_bn3 >> *left_bn3_var_out; + *left_bn3 >> *left_bn3_saved_mean; + *left_bn3 >> *left_bn3_saved_var; + + *input >> *right_conv1 >> *right_conv1_out >> *right_bn1 >> + *right_bn1_out >> *add; + + *right_conv1_weight >> *right_conv1; + *right_bn1_scale >> *right_bn1; + *right_bn1_bias >> *right_bn1; + *right_bn1_mean >> *right_bn1; + *right_bn1_var >> *right_bn1; + *right_bn1 >> *right_bn1_mean_out; + *right_bn1 >> *right_bn1_var_out; + *right_bn1 >> *right_bn1_saved_mean; + *right_bn1 >> *right_bn1_saved_var; + + *add >> *add_out >> *relu >> *relu_out; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("resnet_cbam_block0"); + op_desc.SetInput("Inputs", {matched.at("input")->arg()->name}); + op_desc.SetInput("Filter", + { + matched.at("left_conv1_weight")->arg()->name, + matched.at("left_conv2_weight")->arg()->name, + matched.at("left_conv3_weight")->arg()->name, + matched.at("left_conv4_weight")->arg()->name, + matched.at("right_conv1_weight")->arg()->name, + }); + op_desc.SetInput("Scale", + { + matched.at("left_bn1_scale")->arg()->name, + matched.at("left_bn2_scale")->arg()->name, + matched.at("left_bn3_scale")->arg()->name, + "placeholder_sa_conv", + matched.at("right_bn1_scale")->arg()->name, + }); + op_desc.SetInput("Bias", + { + matched.at("left_bn1_bias")->arg()->name, + matched.at("left_bn2_bias")->arg()->name, + matched.at("left_bn3_bias")->arg()->name, + "placeholder_sa_conv", + 
matched.at("right_bn1_bias")->arg()->name, + }); + op_desc.SetInput("Mean", + { + matched.at("left_bn1_mean")->arg()->name, + matched.at("left_bn2_mean")->arg()->name, + matched.at("left_bn3_mean")->arg()->name, + "placeholder_sa_conv", + matched.at("right_bn1_mean")->arg()->name, + }); + op_desc.SetInput("Var", + { + matched.at("left_bn1_variance")->arg()->name, + matched.at("left_bn2_variance")->arg()->name, + matched.at("left_bn3_variance")->arg()->name, + "placeholder_sa_conv", + matched.at("right_bn1_variance")->arg()->name, + }); + op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name}); + // XXX: keep these to fool SubgraphOp::AttachImpl() + op_desc.SetAttr("sub_block", 0); + op_desc.SetAttr>("input_data_names", {}); + op_desc.SetAttr>("output_data_names", {}); + + auto block0_stmt = matched.at("left_conv1")->stmt(); + // block0_stmt->ResetOp(op_desc, graph->valid_places()); + auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); + // XXX: memleak? + auto sub_block_desc = new cpp::BlockDesc(); + static_cast(fake_subgraph_op.get()) + ->SetSubBlock(sub_block_desc); + fake_subgraph_op->Attach(op_desc, block0_stmt->op()->scope()); + fake_subgraph_op->SetValidPlaces(block0_stmt->op()->valid_places()); + block0_stmt->SetOp(fake_subgraph_op); + + std::vector froms = { + "left_conv2_weight", + "left_conv3_weight", + "left_conv4_weight", + "right_conv1_weight", + "left_bn1_bias", + "left_bn2_bias", + "left_bn3_bias", + "right_bn1_bias", + }; + for (auto& from : froms) { + IR_NODE_LINK_TO(matched.at(from), matched.at("left_conv1")); + } + IR_OP_VAR_LINK(matched.at("left_conv1"), matched.at("relu_out")); + } +}; + +class XPUResNetCbamBlock1Fuser : public FuseBase { + public: + XPUResNetCbamBlock1Fuser() {} + + void BuildPattern() override { + auto* input = VarNode("input") + ->assert_is_op_input("conv2d", "Input") + ->assert_is_op_input("elementwise_add") + ->AsInput(); + + auto* right_conv1_weight = VarNode("right_conv1_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* right_conv1 = OpNode("right_conv1", "conv2d"); + auto* right_conv1_out = VarNode("right_conv1_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* right_bn1_scale = VarNode("right_bn1_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* right_bn1_bias = VarNode("right_bn1_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* right_bn1_mean = VarNode("right_bn1_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* right_bn1_var = VarNode("right_bn1_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* right_bn1 = OpNode("right_bn1", "batch_norm")->AsIntermediate(); + auto* right_bn1_out = VarNode("right_bn1_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* right_bn1_mean_out = + VarNode("right_bn1_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* right_bn1_var_out = + VarNode("right_bn1_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* right_bn1_saved_mean = + VarNode("right_bn1_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* right_bn1_saved_var = + VarNode("right_bn1_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* right_relu1 = OpNode("right_relu1", 
"relu")->AsIntermediate(); + auto* right_relu1_out = VarNode("right_relu1_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* right_conv2_weight = VarNode("right_conv2_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* right_conv2 = OpNode("right_conv2", "conv2d")->AsIntermediate(); + auto* right_conv2_out = VarNode("right_conv2_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* right_bn2_scale = VarNode("right_bn2_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* right_bn2_bias = VarNode("right_bn2_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* right_bn2_mean = VarNode("right_bn2_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* right_bn2_var = VarNode("right_bn2_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* right_bn2 = OpNode("right_bn2", "batch_norm")->AsIntermediate(); + auto* right_bn2_out = VarNode("right_bn2_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* right_bn2_mean_out = + VarNode("right_bn2_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* right_bn2_var_out = + VarNode("right_bn2_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* right_bn2_saved_mean = + VarNode("right_bn2_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* right_bn2_saved_var = + VarNode("right_bn2_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* right_relu2 = OpNode("right_relu2", "relu")->AsIntermediate(); + auto* right_relu2_out = VarNode("right_relu2_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + + auto* right_conv3_weight = VarNode("right_conv3_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* right_conv3 = OpNode("right_conv3", "conv2d")->AsIntermediate(); + auto* right_conv3_out = VarNode("right_conv3_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* right_bn3_scale = VarNode("right_bn3_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* right_bn3_bias = VarNode("right_bn3_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* right_bn3_mean = VarNode("right_bn3_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* right_bn3_var = VarNode("right_bn3_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* right_bn3 = OpNode("right_bn3", "batch_norm")->AsIntermediate(); + auto* right_bn3_out = VarNode("right_bn3_out") + ->assert_is_op_output("batch_norm", "Y") + ->AsIntermediate(); + auto* right_bn3_mean_out = + VarNode("right_bn3_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* right_bn3_var_out = + VarNode("right_bn3_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* right_bn3_saved_mean = + VarNode("right_bn3_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* right_bn3_saved_var = + VarNode("right_bn3_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + 
->AsIntermediate(); + + // cbam specific + auto* reduce_mean = OpNode("reduce_mean", "reduce_mean")->AsIntermediate(); + auto* reduce_mean_out = VarNode("reduce_mean_out") + ->assert_is_op_output("reduce_mean", "Out") + ->assert_is_op_input("concat") + ->AsIntermediate(); + auto* reduce_max = OpNode("reduce_max", "reduce_max")->AsIntermediate(); + auto* reduce_max_out = VarNode("reduce_max_out") + ->assert_is_op_output("reduce_max", "Out") + ->assert_is_op_input("concat") + ->AsIntermediate(); + auto* concat = OpNode("concat", "concat")->AsIntermediate(); + auto* concat_out = VarNode("concat_out") + ->assert_is_op_output("concat", "Out") + ->assert_is_op_input("conv2d", "Input") + ->AsIntermediate(); + auto* right_conv4_weight = VarNode("right_conv4_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* right_conv4 = OpNode("right_conv4", "conv2d")->AsIntermediate(); + auto* right_conv4_out = VarNode("right_conv4_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("sigmoid", "X") + ->AsIntermediate(); + auto* sigmoid = OpNode("sigmoid", "sigmoid")->AsIntermediate(); + auto* sigmoid_out = VarNode("sigmoid_out") + ->assert_is_op_output("sigmoid", "Out") + ->assert_is_op_input("elementwise_mul") + ->AsIntermediate(); + auto* reshape = OpNode("reshape", "reshape2")->AsIntermediate(); + auto* reshape_out = VarNode("reshape_out") + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("elementwise_mul") + ->AsIntermediate(); + auto* reshape_xshape = VarNode("reshape_xshape") + ->assert_is_op_output("reshape2", "XShape") + ->AsIntermediate(); + auto* eltwise_mul = + OpNode("eltwise_mul", "elementwise_mul")->AsIntermediate(); + auto* eltwise_mul_out = VarNode("eltwise_mul_out") + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_add") + ->AsIntermediate(); + + auto* add = OpNode("add", "elementwise_add")->AsIntermediate(); + auto* add_out = VarNode("add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* relu = OpNode("relu", "relu")->AsIntermediate(); + auto* relu_out = + VarNode("relu_out")->assert_is_op_output("relu", "Out")->AsOutput(); + + *input >> *right_conv1 >> *right_conv1_out >> *right_bn1 >> + *right_bn1_out >> *right_relu1 >> *right_relu1_out >> *right_conv2 >> + *right_conv2_out >> *right_bn2 >> *right_bn2_out >> *right_relu2 >> + *right_relu2_out >> *right_conv3 >> *right_conv3_out >> *right_bn3 >> + *right_bn3_out /* >> *add*/; + + *right_bn3_out >> *reduce_mean >> *reduce_mean_out >> *concat; + *right_bn3_out >> *reduce_max >> *reduce_max_out >> *concat; + *concat >> *concat_out >> *right_conv4 >> *right_conv4_out >> *sigmoid >> + *sigmoid_out >> *eltwise_mul; + *right_conv4_weight >> *right_conv4; + *right_bn3_out >> *reshape >> *reshape_out >> *eltwise_mul; + *reshape >> *reshape_xshape; + *eltwise_mul >> *eltwise_mul_out >> *add; + + *right_conv1_weight >> *right_conv1; + *right_bn1_scale >> *right_bn1; + *right_bn1_bias >> *right_bn1; + *right_bn1_mean >> *right_bn1; + *right_bn1_var >> *right_bn1; + *right_bn1 >> *right_bn1_mean_out; + *right_bn1 >> *right_bn1_var_out; + *right_bn1 >> *right_bn1_saved_mean; + *right_bn1 >> *right_bn1_saved_var; + + *right_conv2_weight >> *right_conv2; + *right_bn2_scale >> *right_bn2; + *right_bn2_bias >> *right_bn2; + *right_bn2_mean >> *right_bn2; + *right_bn2_var >> *right_bn2; + *right_bn2 >> *right_bn2_mean_out; + *right_bn2 >> *right_bn2_var_out; + *right_bn2 >> *right_bn2_saved_mean; 
+ *right_bn2 >> *right_bn2_saved_var; + + *right_conv3_weight >> *right_conv3; + *right_bn3_scale >> *right_bn3; + *right_bn3_bias >> *right_bn3; + *right_bn3_mean >> *right_bn3; + *right_bn3_var >> *right_bn3; + *right_bn3 >> *right_bn3_mean_out; + *right_bn3 >> *right_bn3_var_out; + *right_bn3 >> *right_bn3_saved_mean; + *right_bn3 >> *right_bn3_saved_var; + + *input >> *add; + + *add >> *add_out >> *relu >> *relu_out; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("resnet_cbam_block1"); + op_desc.SetInput("Inputs", {matched.at("input")->arg()->name}); + op_desc.SetInput("Filter", + { + matched.at("right_conv1_weight")->arg()->name, + matched.at("right_conv2_weight")->arg()->name, + matched.at("right_conv3_weight")->arg()->name, + matched.at("right_conv4_weight")->arg()->name, + }); + op_desc.SetInput("Scale", + { + matched.at("right_bn1_scale")->arg()->name, + matched.at("right_bn2_scale")->arg()->name, + matched.at("right_bn3_scale")->arg()->name, + "placeholder_sa_conv", + }); + op_desc.SetInput("Bias", + { + matched.at("right_bn1_bias")->arg()->name, + matched.at("right_bn2_bias")->arg()->name, + matched.at("right_bn3_bias")->arg()->name, + "placeholder_sa_conv", + }); + op_desc.SetInput("Mean", + { + matched.at("right_bn1_mean")->arg()->name, + matched.at("right_bn2_mean")->arg()->name, + matched.at("right_bn3_mean")->arg()->name, + "placeholder_sa_conv", + }); + op_desc.SetInput("Var", + { + matched.at("right_bn1_variance")->arg()->name, + matched.at("right_bn2_variance")->arg()->name, + matched.at("right_bn3_variance")->arg()->name, + "placeholder_sa_conv", + }); + op_desc.SetOutput("Outputs", {matched.at("relu_out")->arg()->name}); + // XXX: keep these to fool SubgraphOp::AttachImpl() + op_desc.SetAttr("sub_block", 0); + op_desc.SetAttr>("input_data_names", {}); + op_desc.SetAttr>("output_data_names", {}); + + auto block1_stmt = matched.at("right_conv1")->stmt(); + auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); + // XXX: memleak? 
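+    // The BlockDesc allocated below is left empty on purpose: it appears to
+    // exist only so that SubgraphOp::AttachImpl() finds a sub-block; the
+    // fused resnet_cbam_block1 op is not expected to ever execute it.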
+ auto sub_block_desc = new cpp::BlockDesc(); + static_cast(fake_subgraph_op.get()) + ->SetSubBlock(sub_block_desc); + fake_subgraph_op->Attach(op_desc, block1_stmt->op()->scope()); + fake_subgraph_op->SetValidPlaces(block1_stmt->op()->valid_places()); + block1_stmt->SetOp(fake_subgraph_op); + + std::vector froms = { + "right_conv2_weight", + "right_conv3_weight", + "right_conv4_weight", + "right_bn1_bias", + "right_bn2_bias", + "right_bn3_bias", + }; + for (auto& from : froms) { + IR_NODE_LINK_TO(matched.at(from), matched.at("right_conv1")); + } + IR_OP_VAR_LINK(matched.at("right_conv1"), matched.at("relu_out")); + } +}; + +class XPUResNetCbamBlock2Fuser : public FuseBase { + public: + XPUResNetCbamBlock2Fuser() {} + + void BuildPattern() override { + auto* input = VarNode("input")->assert_is_op_input("clip", "X")->AsInput(); + + auto* clip = OpNode("clip", "clip"); + auto* clip_out = VarNode("clip_out") + ->assert_is_op_output("clip", "Out") + ->assert_is_op_input("elementwise_pow") + ->AsIntermediate(); + auto* eltwise_y = VarNode("eltwise_y") + ->assert_is_op_input("elementwise_pow") + ->assert_is_op_input("elementwise_div") + ->AsIntermediate(); + auto* eltwise_pow = + OpNode("eltwise_pow", "elementwise_pow")->AsIntermediate(); + auto* eltwise_pow_out = VarNode("eltwise_pow_out") + ->assert_is_op_output("elementwise_pow", "Out") + ->assert_is_op_input("pad2d", "X") + ->AsIntermediate(); + auto* pad2d = OpNode("pad2d", "pad2d")->AsIntermediate(); + auto* pad2d_out = VarNode("pad2d_out") + ->assert_is_op_output("pad2d", "Out") + ->assert_is_op_input("pool2d", "X") + ->AsIntermediate(); + auto* pool2d = OpNode("pool2d", "pool2d")->AsIntermediate(); + auto* pool2d_out = VarNode("pool2d_out") + ->assert_is_op_output("pool2d", "Out") + ->assert_is_op_input("elementwise_pow") + ->AsIntermediate(); + + auto* fill_const = OpNode("fill_const", "fill_constant")->AsIntermediate(); + auto* fill_const_out = VarNode("fill_const_out") + ->assert_is_op_output("fill_constant", "Out") + ->assert_is_op_input("elementwise_div") + ->AsIntermediate(); + auto* eltwise_div = + OpNode("eltwise_div", "elementwise_div")->AsIntermediate(); + auto* eltwise_div_out = VarNode("eltwise_div_out") + ->assert_is_op_output("elementwise_div", "Out") + ->assert_is_op_input("elementwise_pow") + ->AsIntermediate(); + + auto* eltwise_pow2 = + OpNode("eltwise_pow2", "elementwise_pow")->AsIntermediate(); + auto* eltwise_pow2_out = VarNode("eltwise_pow2_out") + ->assert_is_op_output("elementwise_pow", "Out") + ->AsIntermediate(); + + auto* shape = OpNode("shape", "shape")->AsIntermediate(); + auto* shape_out = VarNode("shape_out") + ->assert_is_op_output("shape", "Out") + ->assert_is_op_input("gather") + ->AsIntermediate(); + auto* fill_const2 = + OpNode("fill_const2", "fill_constant")->AsIntermediate(); + auto* fill_const2_out = VarNode("fill_const2_out") + ->assert_is_op_output("fill_constant", "Out") + ->assert_is_op_input("gather") + ->AsIntermediate(); + auto* gather = OpNode("gather", "gather")->AsIntermediate(); + auto* gather_out = VarNode("gather_out") + ->assert_is_op_output("gather", "Out") + ->assert_is_op_input("assign", "X") + ->AsIntermediate(); + auto* assign = OpNode("assign", "assign")->AsIntermediate(); + auto* assign_out = VarNode("assign_out") + ->assert_is_op_output("assign", "Out") + ->assert_is_op_input("concat") + ->AsIntermediate(); + + auto* fill_const3 = + OpNode("fill_const3", "fill_constant")->AsIntermediate(); + auto* fill_const3_out = VarNode("fill_const3_out") + 
->assert_is_op_output("fill_constant", "Out") + ->assert_is_op_input("assign") + ->AsIntermediate(); + auto* assign2 = OpNode("assign2", "assign")->AsIntermediate(); + auto* assign2_out = VarNode("assign2_out") + ->assert_is_op_output("assign", "Out") + ->assert_is_op_input("concat") + ->AsIntermediate(); + + auto* concat = OpNode("concat", "concat")->AsIntermediate(); + auto* concat_out = VarNode("concat_out") + ->assert_is_op_output("concat", "Out") + ->assert_is_op_input("cast", "X") + ->AsIntermediate(); + auto* cast = OpNode("cast", "cast")->AsIntermediate(); + auto* cast_out = VarNode("cast_out") + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("reshape2", "Shape") + ->AsIntermediate(); + + auto* reshape2 = OpNode("reshape2", "reshape2")->AsIntermediate(); + auto* reshape2_out = VarNode("reshape2_out") + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("matmul", "X") + ->AsIntermediate(); + auto* reshape2_xshape = VarNode("reshape2_xshape") + ->assert_is_op_output("reshape2", "XShape") + ->AsIntermediate(); + auto* matmul_y = + VarNode("matmul_y")->assert_is_op_input("matmul", "Y")->AsInput(); + auto* matmul = OpNode("matmul", "matmul")->AsIntermediate(); + auto* matmul_out = VarNode("matmul_out") + ->assert_is_op_output("matmul", "Out") + ->assert_is_op_input("elementwise_add") + ->AsIntermediate(); + auto* eltwise_add_y = VarNode("eltwise_add_y") + ->assert_is_op_input("elementwise_add") + ->AsInput(); + auto* eltwise_add = + OpNode("eltwise_add", "elementwise_add")->AsIntermediate(); + auto* eltwise_add_out = VarNode("eltwise_add_out") + ->assert_is_op_output("elementwise_add", "Out") + ->AsIntermediate(); + + auto* norm = OpNode("norm", "norm")->AsIntermediate(); + auto* norm_out = VarNode("norm_out") + ->assert_is_op_output("norm", "Out") + ->assert_is_op_input("elementwise_add") + ->AsIntermediate(); + auto* norm_norm = VarNode("norm_norm") + ->assert_is_op_output("norm", "Norm") + ->AsIntermediate(); + auto* fill_const4 = + OpNode("fill_const4", "fill_constant")->AsIntermediate(); + auto* fill_const4_out = VarNode("fill_const4_out") + ->assert_is_op_output("fill_constant", "Out") + ->assert_is_op_input("elementwise_add") + ->AsIntermediate(); + auto* eltwise_add2 = + OpNode("eltwise_add2", "elementwise_add")->AsIntermediate(); + auto* eltwise_add2_out = VarNode("eltwise_add2_out") + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("elementwise_mul") + ->AsIntermediate(); + auto* fill_const5 = + OpNode("fill_const5", "fill_constant")->AsIntermediate(); + auto* fill_const5_out = VarNode("fill_const5_out") + ->assert_is_op_output("fill_constant", "Out") + ->assert_is_op_input("elementwise_mul") + ->AsIntermediate(); + auto* eltwise_mul = + OpNode("eltwise_mul", "elementwise_mul")->AsIntermediate(); + auto* eltwise_mul_out = VarNode("eltwise_mul_out") + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_div") + ->AsIntermediate(); + + auto* eltwise_div2 = + OpNode("eltwise_div2", "elementwise_div")->AsIntermediate(); + auto* eltwise_div2_out = VarNode("eltwise_div2_out") + ->assert_is_op_output("elementwise_div", "Out") + ->AsOutput(); + + *input >> *clip >> *clip_out >> *eltwise_pow >> *eltwise_pow_out >> + *pad2d >> *pad2d_out >> *pool2d >> *pool2d_out >> *eltwise_pow2; + *eltwise_y >> *eltwise_pow; + + *fill_const >> *fill_const_out >> *eltwise_div >> *eltwise_div_out >> + *eltwise_pow2; + *eltwise_y >> *eltwise_div; + + *eltwise_pow2 >> *eltwise_pow2_out >> *shape >> *shape_out >> *gather >> + 
*gather_out >> *assign >> *assign_out >> *concat >> *concat_out >>
+        *cast >> *cast_out >> *reshape2;
+    *fill_const2 >> *fill_const2_out >> *gather;
+    *fill_const3 >> *fill_const3_out >> *assign2 >> *assign2_out >> *concat;
+    *eltwise_pow2_out >> *reshape2;
+
+    *reshape2 >> *reshape2_out >> *matmul >> *matmul_out >> *eltwise_add >>
+        *eltwise_add_out;
+    *reshape2 >> *reshape2_xshape;
+    *matmul_y >> *matmul;
+    *eltwise_add_y >> *eltwise_add;
+
+    *eltwise_add_out >> *norm >> *norm_out >> *eltwise_add2 >>
+        *eltwise_add2_out >> *eltwise_mul >> *eltwise_mul_out >>
+        *eltwise_div2 >> *eltwise_div2_out;
+    *norm >> *norm_norm;
+    *fill_const4 >> *fill_const4_out >> *eltwise_add2;
+    *fill_const5 >> *fill_const5_out >> *eltwise_mul;
+    *eltwise_add_out >> *eltwise_div2;
+  }
+
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    cpp::OpDesc op_desc;
+    op_desc.SetType("resnet_cbam_block2");
+    op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
+    op_desc.SetInput("Filter", {matched.at("matmul_y")->arg()->name});
+    op_desc.SetInput("Scale", {"placeholder_last_fc"});
+    op_desc.SetInput("Bias", {matched.at("eltwise_add_y")->arg()->name});
+    op_desc.SetInput("Mean", {"placeholder_last_fc"});
+    op_desc.SetInput("Var", {"placeholder_last_fc"});
+    op_desc.SetOutput("Outputs", {matched.at("eltwise_div2_out")->arg()->name});
+    // XXX: keep these to fool SubgraphOp::AttachImpl()
+    op_desc.SetAttr<int32_t>("sub_block", 0);
+    op_desc.SetAttr<std::vector<std::string>>("input_data_names", {});
+    op_desc.SetAttr<std::vector<std::string>>("output_data_names", {});
+
+    // extra traits to distill
+    auto block2_stmt = matched.at("clip")->stmt();
+    auto* scope = block2_stmt->op()->scope();
+    auto pow_tensor_name = matched.at("eltwise_y")->arg()->name;
+    auto* pow_tensor = scope->FindTensor(pow_tensor_name);
+    float pool_p = pow_tensor->data<float>()[0];
+    op_desc.SetAttr<float>("pool_p", pool_p);
+    auto* matmul_op_info = matched.at("matmul")->stmt()->op_info();
+    CHECK(matmul_op_info->GetAttr<bool>("transpose_Y") == true)
+        << "Y of last fc must have been transposed";
+
+    auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
+    // XXX: memleak?
+ auto sub_block_desc = new cpp::BlockDesc(); + static_cast(fake_subgraph_op.get()) + ->SetSubBlock(sub_block_desc); + fake_subgraph_op->Attach(op_desc, scope); + fake_subgraph_op->SetValidPlaces(block2_stmt->op()->valid_places()); + block2_stmt->SetOp(fake_subgraph_op); + + std::vector froms = { + "matmul_y", "eltwise_add_y", + }; + for (auto& from : froms) { + IR_NODE_LINK_TO(matched.at(from), matched.at("clip")); + } + IR_OP_VAR_LINK(matched.at("clip"), matched.at("eltwise_div2_out")); + } +}; + +class XPUResNetCbamFuser : public xpu::XPUFuseBase { + public: + XPUResNetCbamFuser() {} + + void BuildPattern() override { + auto* input = + VarNode("input")->assert_is_op_input("conv2d", "Input")->AsInput(); + + auto* top_conv_weight = VarNode("top_conv_weight") + ->assert_is_op_input("conv2d", "Filter") + ->AsInput(); + auto* top_conv = OpNode("top_conv", "conv2d"); + auto* top_conv_out = VarNode("top_conv_out") + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("batch_norm", "X") + ->AsIntermediate(); + auto* top_bn_scale = VarNode("top_bn_scale") + ->assert_is_op_input("batch_norm", "Scale") + ->AsIntermediate(); + auto* top_bn_bias = VarNode("top_bn_bias") + ->assert_is_op_input("batch_norm", "Bias") + ->AsInput(); + auto* top_bn_mean = VarNode("top_bn_mean") + ->assert_is_op_input("batch_norm", "Mean") + ->AsIntermediate(); + auto* top_bn_var = VarNode("top_bn_variance") + ->assert_is_op_input("batch_norm", "Variance") + ->AsIntermediate(); + auto* top_bn = OpNode("top_bn", "batch_norm")->AsIntermediate(); + auto* top_bn_out = VarNode("top_bn_out") + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input("relu", "X") + ->AsIntermediate(); + auto* top_bn_mean_out = VarNode("top_bn_mean_out") + ->assert_is_op_output("batch_norm", "MeanOut") + ->AsIntermediate(); + auto* top_bn_var_out = + VarNode("top_bn_var_out") + ->assert_is_op_output("batch_norm", "VarianceOut") + ->AsIntermediate(); + auto* top_bn_saved_mean = + VarNode("top_bn_saved_mean") + ->assert_is_op_output("batch_norm", "SavedMean") + ->AsIntermediate(); + auto* top_bn_saved_var = + VarNode("top_bn_saved_var") + ->assert_is_op_output("batch_norm", "SavedVariance") + ->AsIntermediate(); + auto* top_relu = OpNode("top_relu", "relu")->AsIntermediate(); + auto* top_relu_out = VarNode("top_relu_out") + ->assert_is_op_output("relu", "Out") + ->assert_is_op_input("pool2d", "X") + ->AsIntermediate(); + auto* top_pool = OpNode("top_pool", "pool2d")->AsIntermediate(); + auto* top_pool_out = + VarNode("top_pool_out") + ->assert_is_op_output("pool2d", "Out") + ->assert_is_op_input("resnet_cbam_block0", "Inputs") + ->AsIntermediate(); + + // args are left out + auto* resnet_block0_1 = + OpNode("resnet_block0_1", "resnet_cbam_block0")->AsIntermediate(); + auto* resnet_block0_1_out = + VarNode("resnet_block0_1_out") + ->assert_is_op_output("resnet_cbam_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_1_1 = + OpNode("resnet_block1_1_1", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_1_1_out = + VarNode("resnet_block1_1_1_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_1_2 = + OpNode("resnet_block1_1_2", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_1_2_out = + VarNode("resnet_block1_1_2_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block0_2 = + OpNode("resnet_block0_2", "resnet_cbam_block0")->AsIntermediate(); + auto* resnet_block0_2_out = + 
VarNode("resnet_block0_2_out") + ->assert_is_op_output("resnet_cbam_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_2_1 = + OpNode("resnet_block1_2_1", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_2_1_out = + VarNode("resnet_block1_2_1_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_2_2 = + OpNode("resnet_block1_2_2", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_2_2_out = + VarNode("resnet_block1_2_2_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_2_3 = + OpNode("resnet_block1_2_3", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_2_3_out = + VarNode("resnet_block1_2_3_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block0_3 = + OpNode("resnet_block0_3", "resnet_cbam_block0")->AsIntermediate(); + auto* resnet_block0_3_out = + VarNode("resnet_block0_3_out") + ->assert_is_op_output("resnet_cbam_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_1 = + OpNode("resnet_block1_3_1", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_3_1_out = + VarNode("resnet_block1_3_1_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_2 = + OpNode("resnet_block1_3_2", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_3_2_out = + VarNode("resnet_block1_3_2_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_3 = + OpNode("resnet_block1_3_3", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_3_3_out = + VarNode("resnet_block1_3_3_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_4 = + OpNode("resnet_block1_3_4", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_3_4_out = + VarNode("resnet_block1_3_4_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_3_5 = + OpNode("resnet_block1_3_5", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_3_5_out = + VarNode("resnet_block1_3_5_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block0_4 = + OpNode("resnet_block0_4", "resnet_cbam_block0")->AsIntermediate(); + auto* resnet_block0_4_out = + VarNode("resnet_block0_4_out") + ->assert_is_op_output("resnet_cbam_block0", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_4_1 = + OpNode("resnet_block1_4_1", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_4_1_out = + VarNode("resnet_block1_4_1_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + auto* resnet_block1_4_2 = + OpNode("resnet_block1_4_2", "resnet_cbam_block1")->AsIntermediate(); + auto* resnet_block1_4_2_out = + VarNode("resnet_block1_4_2_out") + ->assert_is_op_output("resnet_cbam_block1", "Outputs") + ->AsIntermediate(); + + auto* resnet_block2 = + OpNode("resnet_block2", "resnet_cbam_block2")->AsIntermediate(); + auto* resnet_block2_out = + VarNode("resnet_block2_out") + ->assert_is_op_output("resnet_cbam_block2", "Outputs") + ->AsOutput(); + + *input >> *top_conv >> *top_conv_out >> *top_bn >> *top_bn_out >> + *top_relu >> *top_relu_out >> *top_pool >> *top_pool_out >> + *resnet_block0_1 >> *resnet_block0_1_out >> *resnet_block1_1_1 >> + *resnet_block1_1_1_out >> *resnet_block1_1_2 >> + *resnet_block1_1_2_out >> 
*resnet_block0_2 >> *resnet_block0_2_out >>
+        *resnet_block1_2_1 >> *resnet_block1_2_1_out >> *resnet_block1_2_2 >>
+        *resnet_block1_2_2_out >> *resnet_block1_2_3 >>
+        *resnet_block1_2_3_out >> *resnet_block0_3 >> *resnet_block0_3_out >>
+        *resnet_block1_3_1 >> *resnet_block1_3_1_out >> *resnet_block1_3_2 >>
+        *resnet_block1_3_2_out >> *resnet_block1_3_3 >>
+        *resnet_block1_3_3_out >> *resnet_block1_3_4 >>
+        *resnet_block1_3_4_out >> *resnet_block1_3_5 >>
+        *resnet_block1_3_5_out >> *resnet_block0_4 >> *resnet_block0_4_out >>
+        *resnet_block1_4_1 >> *resnet_block1_4_1_out >> *resnet_block1_4_2 >>
+        *resnet_block1_4_2_out >> *resnet_block2 >> *resnet_block2_out;
+
+    *top_conv_weight >> *top_conv;
+    *top_bn_scale >> *top_bn;
+    *top_bn_bias >> *top_bn;
+    *top_bn_mean >> *top_bn;
+    *top_bn_var >> *top_bn;
+    *top_bn >> *top_bn_mean_out;
+    *top_bn >> *top_bn_var_out;
+    *top_bn >> *top_bn_saved_mean;
+    *top_bn >> *top_bn_saved_var;
+  }
+
+  void handle_placeholder_sa_conv(SSAGraph* graph,
+                                  const key2nodes_t& matched,
+                                  paddle::lite::Scope* scope,
+                                  const std::string& filter_name,
+                                  std::vector<std::string>* max_filter_name) {
+    auto* filter_t = scope->FindMutableTensor(filter_name);
+    int filter_len = filter_t->numel();
+    float* filter_on_host = filter_t->mutable_data<float>();
+
+    float max_f =
+        paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
+    std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
+    paddle::lite::xpu::math::ConvertFP32ToInt16(
+        filter_on_host, filter_int16.get(), max_f, filter_len);
+    memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
+
+    // create new arg in graph and scope
+    std::string max_name = filter_name + "_max";
+    max_filter_name->push_back(max_name);
+    auto* max_filter_node = graph->NewArgumentNode(max_name);
+    max_filter_node->arg()->is_weight = true;
+    max_filter_node->arg()->type = LiteType::GetTensorTy(
+        TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+    DirectedLink(max_filter_node, matched.at("top_conv"));
+    auto* max_filter_t = scope->NewTensor(max_name);
+    max_filter_t->Resize({4});
+    float* max_ptr = max_filter_t->mutable_data<float>();
+    max_ptr[0] = max_f;
+    max_ptr[1] = max_f;
+    max_ptr[2] = max_f;
+    max_ptr[3] = max_f;
+  }
+
+  void handle_placeholder_last_fc(SSAGraph* graph,
+                                  const key2nodes_t& matched,
+                                  paddle::lite::Scope* scope,
+                                  const std::string& filter_name,
+                                  std::vector<std::string>* max_filter_name) {
+    auto* filter_t = scope->FindMutableTensor(filter_name);
+    auto filter_dims = filter_t->dims();
+    int filter_len = filter_t->numel();
+    float* filter_on_host = filter_t->mutable_data<float>();
+
+    // XXX(miaotianxiang): Y has already been transposed in model...
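+    // Quantize the fc weight to int16 in place: take the per-tensor absolute
+    // max, convert, and publish the max through a companion "<name>_max"
+    // tensor of four identical floats (presumably read back by the XPU
+    // kernel to rescale the int16 result).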
+    float max_f =
+        paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
+    std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
+    paddle::lite::xpu::math::ConvertFP32ToInt16(
+        filter_on_host, filter_int16.get(), max_f, filter_len);
+    memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
+
+    // create new arg in graph and scope
+    std::string max_name = filter_name + "_max";
+    max_filter_name->push_back(max_name);
+    auto* max_filter_node = graph->NewArgumentNode(max_name);
+    max_filter_node->arg()->is_weight = true;
+    max_filter_node->arg()->type = LiteType::GetTensorTy(
+        TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+    DirectedLink(max_filter_node, matched.at("top_conv"));
+    auto* max_filter_t = scope->NewTensor(max_name);
+    max_filter_t->Resize({4});
+    float* max_ptr = max_filter_t->mutable_data<float>();
+    max_ptr[0] = max_f;
+    max_ptr[1] = max_f;
+    max_ptr[2] = max_f;
+    max_ptr[3] = max_f;
+  }
+
+  void InsertNewNode(SSAGraph* graph,
+                     const key2nodes_t& matched,
+                     const std::vector<Node*>& extra_input_vars) override {
+    cpp::OpDesc op_desc;
+    op_desc.SetType("__xpu__resnet_cbam");
+    op_desc.SetInput("Input", {matched.at("input")->arg()->name});
+    std::vector<std::string> filter_name = {
+        matched.at("top_conv_weight")->arg()->name};
+    std::vector<std::string> scale_name = {
+        matched.at("top_bn_scale")->arg()->name};
+    std::vector<std::string> bias_name = {
+        matched.at("top_bn_bias")->arg()->name};
+    std::vector<std::string> mean_name = {
+        matched.at("top_bn_mean")->arg()->name};
+    std::vector<std::string> var_name = {
+        matched.at("top_bn_variance")->arg()->name};
+    std::vector<std::string> max_filter_name;
+    std::vector<std::string> resnet_block_vec = {
+        "resnet_block0_1",
+        "resnet_block1_1_1",
+        "resnet_block1_1_2",
+        "resnet_block0_2",
+        "resnet_block1_2_1",
+        "resnet_block1_2_2",
+        "resnet_block1_2_3",
+        "resnet_block0_3",
+        "resnet_block1_3_1",
+        "resnet_block1_3_2",
+        "resnet_block1_3_3",
+        "resnet_block1_3_4",
+        "resnet_block1_3_5",
+        "resnet_block0_4",
+        "resnet_block1_4_1",
+        "resnet_block1_4_2",
+        "resnet_block2",
+    };
+    for (auto& block : resnet_block_vec) {
+      auto* block_op_info = matched.at(block)->stmt()->op_info();
+      auto block_filter_name = block_op_info->Input("Filter");
+      std::copy(block_filter_name.begin(),
+                block_filter_name.end(),
+                std::back_inserter(filter_name));
+      auto block_scale_name = block_op_info->Input("Scale");
+      std::copy(block_scale_name.begin(),
+                block_scale_name.end(),
+                std::back_inserter(scale_name));
+      auto block_bias_name = block_op_info->Input("Bias");
+      std::copy(block_bias_name.begin(),
+                block_bias_name.end(),
+                std::back_inserter(bias_name));
+      auto block_mean_name = block_op_info->Input("Mean");
+      std::copy(block_mean_name.begin(),
+                block_mean_name.end(),
+                std::back_inserter(mean_name));
+      auto block_var_name = block_op_info->Input("Var");
+      std::copy(block_var_name.begin(),
+                block_var_name.end(),
+                std::back_inserter(var_name));
+    }
+
+    auto* resnet_cbam_stmt = matched.at("top_conv")->stmt();
+    auto* scope = resnet_cbam_stmt->op()->scope();
+    for (size_t i = 0; i < filter_name.size(); ++i) {
+      if (scale_name[i] == "placeholder_sa_conv") {
+        handle_placeholder_sa_conv(
+            graph, matched, scope, filter_name[i], &max_filter_name);
+        continue;
+      } else if (scale_name[i] == "placeholder_last_fc") {
+        handle_placeholder_last_fc(
+            graph, matched, scope, filter_name[i], &max_filter_name);
+        continue;
+      }
+
+      auto* filter_t = scope->FindMutableTensor(filter_name[i]);
+      auto* scale_t = scope->FindMutableTensor(scale_name[i]);
+      auto* bias_t = scope->FindMutableTensor(bias_name[i]);
+      auto* mean_t = scope->FindMutableTensor(mean_name[i]);
+      auto* var_t = scope->FindMutableTensor(var_name[i]);
+
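+      // Fold the batch_norm into the preceding conv weights below:
+      //   scale' = scale / sqrt(var + eps)  (eps = 1e-5)
+      //   filter'[c][...] = filter[c][...] * scale'[c]
+      //   bias' = bias - mean * scale'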
+      int mean_len = mean_t->numel();
+      int filter_len = filter_t->numel();
+      int filter_stride = filter_len / mean_len;
+
+      float* filter_on_host = filter_t->mutable_data<float>();
+      float* scale_on_host = scale_t->mutable_data<float>();
+      float* bias_on_host = bias_t->mutable_data<float>();
+      float* mean_on_host = mean_t->mutable_data<float>();
+      float* var_on_host = var_t->mutable_data<float>();
+
+      // Perform preprocess
+      for (int i = 0; i < mean_len; ++i) {
+        scale_on_host[i] = scale_on_host[i] / sqrtf(var_on_host[i] + 0.00001f);
+      }
+      for (int i = 0; i < mean_len; ++i) {
+        for (int j = 0; j < filter_stride; ++j) {
+          filter_on_host[i * filter_stride + j] *= scale_on_host[i];
+        }
+      }
+      for (int i = 0; i < mean_len; ++i) {
+        bias_on_host[i] += -mean_on_host[i] * scale_on_host[i];
+      }
+
+      float max_f =
+          paddle::lite::xpu::math::FindMaxAbs(filter_on_host, filter_len);
+      std::unique_ptr<int16_t[]> filter_int16(new int16_t[filter_len]);
+      paddle::lite::xpu::math::ConvertFP32ToInt16(
+          filter_on_host, filter_int16.get(), max_f, filter_len);
+      memcpy(filter_on_host, filter_int16.get(), filter_len * sizeof(int16_t));
+
+      // create new arg in graph and scope
+      std::string max_name = filter_name[i] + "_max";
+      max_filter_name.push_back(max_name);
+      auto* max_filter_node = graph->NewArgumentNode(max_name);
+      max_filter_node->arg()->is_weight = true;
+      max_filter_node->arg()->type = LiteType::GetTensorTy(
+          TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
+      DirectedLink(max_filter_node, matched.at("top_conv"));
+      auto* max_filter_t = scope->NewTensor(max_name);
+      max_filter_t->Resize({4});
+      float* max_ptr = max_filter_t->mutable_data<float>();
+      max_ptr[0] = max_f;
+      max_ptr[1] = max_f;
+      max_ptr[2] = max_f;
+      max_ptr[3] = max_f;
+    }
+    op_desc.SetInput("Filter", filter_name);
+    op_desc.SetInput("Bias", bias_name);
+    op_desc.SetInput("MaxFilter", max_filter_name);
+    op_desc.SetOutput("Output", {matched.at("resnet_block2_out")->arg()->name});
+    op_desc.SetAttr<int>("xpu", 1);
+    auto* block2_op_info = matched.at("resnet_block2")->stmt()->op_info();
+    op_desc.SetAttr<float>("pool_p", block2_op_info->GetAttr<float>("pool_p"));
+
+    auto resnet_cbam_op = LiteOpRegistry::Global().Create(op_desc.Type());
+    resnet_cbam_op->Attach(op_desc, scope);
+    resnet_cbam_op->SetValidPlaces(resnet_cbam_stmt->op()->valid_places());
+    auto kernels =
+        resnet_cbam_op->CreateKernels(resnet_cbam_op->valid_places());
+    resnet_cbam_stmt->SetOp(resnet_cbam_op);
+    resnet_cbam_stmt->SetKernels(std::move(kernels));
+
+    IR_NODE_LINK_TO(matched.at("top_bn_bias"), matched.at("top_conv"));
+    for (auto* node : extra_input_vars) {
+      IR_NODE_LINK_TO(node, matched.at("top_conv"));
+    }
+    IR_OP_VAR_LINK(matched.at("top_conv"), matched.at("resnet_block2_out"));
+  }
+};
+
+}  // namespace fusion
+
+class XPUResNetCbamFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
+    fusion::XPUResNetCbamBlock0Fuser block0_fuser;
+    block0_fuser(graph.get());
+    fusion::XPUResNetCbamBlock1Fuser block1_fuser;
+    block1_fuser(graph.get());
+    fusion::XPUResNetCbamBlock2Fuser block2_fuser;
+    block2_fuser(graph.get());
+    fusion::XPUResNetCbamFuser resnet_fuser;
+    resnet_fuser(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(__xpu__resnet_cbam_fuse_pass,
+                  paddle::lite::mir::XPUResNetCbamFusePass)
+    .BindTargets({TARGET(kXPU)})
+    .BindKernel("__xpu__resnet_cbam");
diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc
index a8a5a5deb2a57982587d9db9f94cadb367af8595..a05f8fe8da5ee72581a9254b4d39354a0c5180e6 100644
--- a/lite/core/mir/fusion/conv_bn_fuser.cc
+++ b/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -192,7 +192,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   } else if (is_weight_quantization) {
     std::string scale_name = conv_weight_name + "_quant_scale";
     if (conv_op_desc->HasAttr(scale_name)) {
-      auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name);
+      std::vector<float> scale =
+          conv_op_desc->GetAttr<std::vector<float>>(scale_name);
       CHECK_EQ(scale.size(), alpha_tensor.numel());
       for (size_t i = 0; i < scale.size(); i++) {
         scale[i] *= alpha_data[i];
diff --git a/lite/core/mir/fusion/conv_conv_fuse_pass.cc b/lite/core/mir/fusion/conv_conv_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a86903d5d8056340605683ccefe607b0e4909a1c
--- /dev/null
+++ b/lite/core/mir/fusion/conv_conv_fuse_pass.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "lite/core/mir/fusion/conv_conv_fuser.h"
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // initialize fuser params
+  std::vector<bool> conv_has_bias_cases{true, false};
+  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  bool has_arm = false;
+  for (auto& place : graph->valid_places()) {
+    if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
+      has_arm = true;
+      break;
+    }
+  }
+  if (!has_arm) {
+    return;
+  }
+  // only support fp32 fusion
+  for (auto conv_has_bias0 : conv_has_bias_cases) {
+    for (auto conv_has_bias1 : conv_has_bias_cases) {
+      for (auto conv_type0 : conv_type_cases) {
+        for (auto conv_type1 : conv_type_cases) {
+          VLOG(4) << "conv_has_bias0:" << conv_has_bias0
+                  << " conv_type0:" << conv_type0;
+          VLOG(4) << "conv_has_bias1:" << conv_has_bias1
+                  << " conv_type1:" << conv_type1;
+          fusion::ConvConvFuser fuser(
+              conv_type0, conv_type1, conv_has_bias0, conv_has_bias1);
+          fuser(graph.get());
+        }
+      }
+    }
+  }
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_conv_conv_fuse_pass, paddle::lite::mir::ConvConvFusePass)
+    .BindTargets({TARGET(kARM)});
diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc
index d578b725ec42c926e5f0581fd8eeef855e586bdc..68417783e932f3c882eaae38e620b8b651b937dd 100644
--- a/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc
+++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc
@@ -84,11 +84,12 @@ cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc(
   op_desc.SetInput("X", {matched.at("x1")->arg()->name});
   op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
   op_desc.SetAttr("axis",
-                  matched.at("transpose1")
-                      ->stmt()
-                      ->op_info()
-                      ->GetAttr<std::vector<int>>("axis")
-                      .back());
+                  *(matched.at("transpose1")
+                        ->stmt()
+                        ->op_info()
+                        ->GetAttr<std::vector<int>>("axis")
+                        .end() -
+                    1));
   return op_desc;
 }
diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc
index 55b7a004567ec5a5298e084839d6dcf5a8591882..98b1597b49b9a7e151c86d11843e45163890191a 100644
--- a/lite/core/mir/graph_visualize_pass.cc
+++ b/lite/core/mir/graph_visualize_pass.cc
@@ -62,15 +62,17 @@ std::string Visualize(mir::SSAGraph* graph) {
            << string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
         break;
       case AttrType::FLOATS: {
-        auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
+        std::vector<float> vals =
+            op_info->GetAttr<std::vector<float>>(attr_name);
         os << ":floats: {" + Join(vals, ",") << "}";
       } break;
       case AttrType::INTS: {
-        auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
+        std::vector<int> vals = op_info->GetAttr<std::vector<int>>(attr_name);
         os << ":ints: {" + Join(vals, ",") + "}";
       } break;
       case AttrType::STRINGS: {
-        auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
+        std::vector<std::string> vals =
+            op_info->GetAttr<std::vector<std::string>>(attr_name);
         os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
       } break;
       default:
diff --git
a/lite/core/optimizer.h b/lite/core/optimizer.h index a498df3d4424e34ab0ab61b972eaf88b88f9bfaa..70905c96f08d74fc5e27c85c7ccf3d395420a5e9 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -94,6 +94,8 @@ class Optimizer { #endif "identity_dropout_eliminate_pass", "__xpu__resnet_fuse_pass", + "__xpu__resnet_cbam_fuse_pass", + "__xpu__mmdnn_fuse_pass", "__xpu__multi_encoder_fuse_pass", "__xpu__embedding_with_eltwise_add_fuse_pass", "__xpu__fc_fuse_pass", diff --git a/lite/core/program.cc b/lite/core/program.cc index acda3a642d44361c71860469d0901fe512aedb6f..2d0d4c8b66138e40d6986fcaa39e35e82322ece5 100644 --- a/lite/core/program.cc +++ b/lite/core/program.cc @@ -195,7 +195,7 @@ void Program::Build(const cpp::ProgramDesc& prog) { CHECK(ops_.empty()) << "Executor duplicate Build found"; // Create operators. - auto program = prog; + auto& program = prog; CHECK(program.BlocksSize()); auto& main_block = *program.GetBlock(0); for (size_t i = 0; i < main_block.OpsSize(); ++i) { @@ -262,7 +262,7 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog, } }; - auto program = prog; + auto& program = prog; CHECK(program.BlocksSize()); for (size_t b = 0; b < program.BlocksSize(); ++b) { auto& main_block = *program.GetBlock(b); diff --git a/lite/core/program.h b/lite/core/program.h index fb4265623af3d423821d0b60b5ea93bdc2b7115b..544795af2e6642baedcb6b3d1333f43b428f819d 100644 --- a/lite/core/program.h +++ b/lite/core/program.h @@ -46,7 +46,8 @@ struct Program { const std::shared_ptr& root, const std::vector& valid_places, const std::vector& var_names = {}) - : scope_(root), valid_places_(valid_places), desc_(desc) { + : scope_(root), valid_places_(valid_places) { + desc_.CopyFrom(desc); CHECK(scope_) << "scope should be init first"; VLOG(4) << "prepare work"; PrepareWorkspace(desc, var_names); diff --git a/lite/kernels/apu/subgraph_compute.cc b/lite/kernels/apu/subgraph_compute.cc index d5599e959d97d505b4d368d4000274b529dc9536..21373811dd91d009d834a16d2c437bc722cd676a 100644 --- a/lite/kernels/apu/subgraph_compute.cc +++ b/lite/kernels/apu/subgraph_compute.cc @@ -28,7 +28,7 @@ namespace lite { namespace kernels { namespace apu { -int SubgraphEngine::BuildDeviceProgram() { +bool SubgraphEngine::BuildDeviceProgram() { unsigned int version; Neuron_getVersion(&version); VLOG(3) << "Neuron Adapter version: " << version; @@ -38,7 +38,7 @@ int SubgraphEngine::BuildDeviceProgram() { int neuron_errCode = NeuronModel_create(&model_); if (NEURON_NO_ERROR != neuron_errCode) { LOG(WARNING) << "Fail to create model"; - return subgraph::FAILED; + return false; } graph.set_model(model_); graph.set_input_names(input_names_); @@ -46,6 +46,9 @@ int SubgraphEngine::BuildDeviceProgram() { // Convert all of ops and their input vars and weights and added into the APU // NIR graph + if (origin_program_.empty()) { + BuildOriginProgram(); + } const auto& bridges = subgraph::Registry::Instance(); for (auto& inst : origin_program_) { auto op = const_cast(inst.op()); @@ -54,7 +57,7 @@ int SubgraphEngine::BuildDeviceProgram() { op->InferShape(); std::string op_type = op->op_info()->Type(); if (!bridges.Exists(op_type, TARGET(kAPU))) { - return subgraph::FAILED; + return false; } auto kernel = inst.kernel(); @@ -63,7 +66,7 @@ int SubgraphEngine::BuildDeviceProgram() { const_cast(op), const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { - return subgraph::FAILED; + return false; } } @@ -84,7 +87,7 @@ int SubgraphEngine::BuildDeviceProgram() { VLOG(3) << "input idx: " << graph.Get(input_names_[i])->index(); } 
else { LOG(WARNING) << "Fail to find input: " << input_names_[i]; - return subgraph::FAILED; + return false; } } @@ -105,7 +108,7 @@ int SubgraphEngine::BuildDeviceProgram() { VLOG(3) << "output idx: " << graph.Get(output_names_[i])->index(); } else { LOG(WARNING) << "Fail to find output: " << output_names_[i]; - return subgraph::FAILED; + return false; } } @@ -116,7 +119,7 @@ int SubgraphEngine::BuildDeviceProgram() { neuron_errCode = NeuronModel_finish(model_); if (NEURON_NO_ERROR != neuron_errCode) { LOG(WARNING) << "Fail to create NIR model:" << neuron_errCode; - return subgraph::FAILED; + return false; } VLOG(3) << "[APU] APU NIR model created!"; @@ -129,15 +132,14 @@ int SubgraphEngine::BuildDeviceProgram() { compilation_ = lite::apu::Device::Global().Build(model_); if (compilation_ == nullptr) { LOG(WARNING) << "[APU] Build APU DLA model failed!"; - return subgraph::FAILED; + return false; } VLOG(3) << "[APU] APU DLA model created, Build cost " << GetCurrentUS() - start_time << " us"; - - return status; + return true; } -int SubgraphEngine::LaunchDeviceProgram() { +bool SubgraphEngine::LaunchDeviceProgram() { auto GetCurrentUS = []() -> double { struct timeval time; gettimeofday(&time, NULL); @@ -149,7 +151,7 @@ int SubgraphEngine::LaunchDeviceProgram() { int neuron_errCode = NeuronExecution_create(compilation_, &run); if (NEURON_NO_ERROR != neuron_errCode) { LOG(WARNING) << "[APU] Build APU runtime failed!"; - return subgraph::FAILED; + return false; } // Set input buffer @@ -177,7 +179,7 @@ int SubgraphEngine::LaunchDeviceProgram() { neuron_errCode = NeuronExecution_compute(run); if (NEURON_NO_ERROR != neuron_errCode) { LOG(WARNING) << "Fail to run execution!" << neuron_errCode; - return subgraph::FAILED; + return false; } for (size_t i = 0; i < origin_otensors_.size(); i++) { @@ -190,7 +192,7 @@ int SubgraphEngine::LaunchDeviceProgram() { } NeuronExecution_free(run); VLOG(3) << "[APU] Process cost " << GetCurrentUS() - start_time << " us"; - return 0; + return true; } SubgraphEngine::~SubgraphEngine() { @@ -211,12 +213,11 @@ void SubgraphCompute::PrepareForRun() { param.output_data_names, param.scope)); CHECK(engine_); - engine_->Build(); } void SubgraphCompute::Run() { CHECK(engine_); - engine_->Launch(); + engine_->Run(); } } // namespace apu diff --git a/lite/kernels/apu/subgraph_compute.h b/lite/kernels/apu/subgraph_compute.h index ecd8a38343cd1f62bb5a3bf8e948384b90cfe826..beb582b8cc16e456491c28ace5e2d1695143216a 100644 --- a/lite/kernels/apu/subgraph_compute.h +++ b/lite/kernels/apu/subgraph_compute.h @@ -41,8 +41,8 @@ class SubgraphEngine : public subgraph::Engine { ~SubgraphEngine(); protected: - int BuildDeviceProgram() override; - int LaunchDeviceProgram() override; + bool BuildDeviceProgram() override; + bool LaunchDeviceProgram() override; NeuronModel *model_; NeuronCompilation *compilation_; diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 720e5be37a430d33427520ca1edafdd542784bb6..6d1d24adcb4cf74b3c6bb991a33316e974dc0110 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -103,7 +103,6 @@ add_kernel(deformable_conv_compute_arm ARM extra SRCS deformable_conv_compute.cc add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(mean_grad_compute_arm ARM train SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(activation_grad_compute_arm ARM train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) 
add_kernel(elementwise_grad_compute_arm ARM train SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(mul_grad_compute_arm ARM train SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sgd_compute_arm ARM train SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/activation_grad_compute.cc b/lite/kernels/arm/activation_grad_compute.cc deleted file mode 100644 index 137668fa5e0d1bd07e838b3040a31e084a7475c8..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/activation_grad_compute.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/kernels/arm/activation_grad_compute.h" -#include "lite/backends/arm/math/funcs.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -void SquareGradCompute::Run() { - auto& param = this->Param(); - auto& ctx = this->ctx_->template As(); - auto out_grad_dims = param.Out_grad->dims(); - auto out_grad_data = param.Out_grad->data(); - - auto x_data = param.X->data(); - auto x_grad_data = param.X_grad->mutable_data(); - lite::arm::math::act_square_grad(x_data, - out_grad_data, - x_grad_data, - out_grad_dims.production(), - ctx.threads()); -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -REGISTER_LITE_KERNEL(square_grad, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::SquareGradCompute, - def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) - .Finalize(); diff --git a/lite/kernels/arm/sequence_conv_compute.cc b/lite/kernels/arm/sequence_conv_compute.cc index d4685b2d3f0afb6980b46cf5b6fa8ad64c8df324..f1e7f678cd69abd84b45cf5031556eea4b837ca3 100644 --- a/lite/kernels/arm/sequence_conv_compute.cc +++ b/lite/kernels/arm/sequence_conv_compute.cc @@ -103,10 +103,19 @@ void SequenceConvCompute::Run() { 1, 1, // stride_h, stride_w, dilation_h, dilation_w tmp_data); - local_naive_transpose(tmp_data, - sub_col_data, - kernel_size * hidden_dim, - input_row_end - input_row_begin); + int cols = kernel_size * hidden_dim; + int rows = input_row_end - input_row_begin; + if (cols % 4 == 0 && rows % 4 == 0) { + paddle::lite::arm::math::local_transpose(tmp_data, + sub_col_data, + cols, + rows); + } else { + local_naive_transpose(tmp_data, + sub_col_data, + cols, + rows); + } } } diff --git a/lite/kernels/bm/bridges/batch_norm_op.cc b/lite/kernels/bm/bridges/batch_norm_op.cc index fbf70178fdd971edce34b3253b02febfa3e3b85c..f5ecc0825a17f26b1cf65605ea2e8c0c93338f39 100644 --- a/lite/kernels/bm/bridges/batch_norm_op.cc +++ b/lite/kernels/bm/bridges/batch_norm_op.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include #include "lite/kernels/bm/bridges/graph.h" #include "lite/kernels/bm/bridges/utility.h" #include "lite/kernels/npu/bridges/registry.h" @@ -64,10 +65,16 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto* bias_data = bias->mutable_data(); auto* mean_data = mean->mutable_data(); auto* variance_data = variance->mutable_data(); + + float* new_bias = static_cast(malloc(bias->memory_size())); + float* new_scale = static_cast(malloc(scale->memory_size())); + CHECK(new_bias != nullptr); + CHECK(new_scale != nullptr); + for (int c = 0; c < channel_size; c++) { float inv_scale = 1.f / (std::sqrt(variance_data[c] + epsilon)); - bias_data[c] = bias_data[c] - inv_scale * scale_data[c] * mean_data[c]; - scale_data[c] = inv_scale * scale_data[c]; + new_bias[c] = bias_data[c] - inv_scale * scale_data[c] * mean_data[c]; + new_scale[c] = inv_scale * scale_data[c]; } const int input_num = 1; @@ -86,11 +93,13 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { output_dims.size(), static_cast(output_var_name.c_str()), static_cast(unique_op_name.c_str()), - static_cast(scale->mutable_data()), - static_cast(bias->mutable_data()), + static_cast(new_scale), + static_cast(new_bias), 1, 1, 1); + free(new_scale); + free(new_bias); delete[] shape; delete[] name; delete[] dim; diff --git a/lite/kernels/bm/bridges/density_prior_box_op.cc b/lite/kernels/bm/bridges/density_prior_box_op.cc index 137c5142d5ae544226dbe5d6cd7c872fc272b71a..895901d94e2b2077f530e196ef8f30d4f57df793 100644 --- a/lite/kernels/bm/bridges/density_prior_box_op.cc +++ b/lite/kernels/bm/bridges/density_prior_box_op.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "lite/kernels/bm/bridges/graph.h" #include "lite/kernels/bm/bridges/utility.h" #include "lite/kernels/npu/bridges/registry.h" diff --git a/lite/kernels/bm/bridges/interpolate_op.cc b/lite/kernels/bm/bridges/interpolate_op.cc index 8c2d39b16ac0206d83199fdeac6c30a0a352856e..a77ec4e8f788e581d9d226369210a449ec50840c 100644 --- a/lite/kernels/bm/bridges/interpolate_op.cc +++ b/lite/kernels/bm/bridges/interpolate_op.cc @@ -76,6 +76,8 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) { static_cast(output_var_name.c_str()), 0, 0, + 0, + 0, type); } graph->AddNode(output_var_name); diff --git a/lite/kernels/bm/subgraph_compute.cc b/lite/kernels/bm/subgraph_compute.cc index d7640e1ac7326d9764380469dc97a7806b044437..664198cf9fb45664fdc088df382b9b94a1924e9b 100644 --- a/lite/kernels/bm/subgraph_compute.cc +++ b/lite/kernels/bm/subgraph_compute.cc @@ -28,12 +28,35 @@ namespace lite { namespace kernels { namespace bm { -int SubgraphEngine::BuildDeviceProgram() { +bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() { + // Obtain the origin input tensors, and create the origin output + // tensors(Don't try to access them before launch the device program or the + // origin program) + PrepareWorkspaceForOriginProgram(); + // Create the device input and output tensors, but don't initialize them + // with the dimensions + device_inputs_.resize(input_names_.size()); + for (int i = 0; i < input_names_.size(); i++) { + device_inputs_[i].reset(new hiai::AiTensor); + CHECK(device_inputs_[i]); + } + device_outputs_.resize(output_names_.size()); + for (int i = 0; i < output_names_.size(); i++) { + device_outputs_[i].reset(new hiai::AiTensor); + CHECK(device_outputs_[i]); + } + return true; +} + +bool SubgraphEngine::BuildDeviceProgram() { int status = 0; subgraph::bm::Graph graph; const auto& 
bridges = subgraph::Registry::Instance(); graph.CreateCompilerHandle(); auto& ctx = this->ctx_->template As(); + if (origin_program_.empty()) { + BuildOriginProgram(); + } for (auto& inst : origin_program_) { auto op = const_cast(inst.op()); CHECK(op); @@ -42,7 +65,7 @@ int SubgraphEngine::BuildDeviceProgram() { std::string op_type = op->op_info()->Type(); LOG(INFO) << op_type; if (!bridges.Exists(op_type, TARGET(kBM))) { - return subgraph::FAILED; + return false; } auto kernel = inst.kernel(); status |= @@ -50,12 +73,13 @@ int SubgraphEngine::BuildDeviceProgram() { const_cast(op), const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { - return subgraph::FAILED; + return false; } } - std::string net_name = "bmnetc_f32umodel"; + std::string net_name = "bmnet_f32bmodel"; + auto unique_net_name = lite::subgraph::bm::UniqueName(net_name); __bmcompile_opt( - graph.GetCompilerHandle(), const_cast(net_name.c_str()), 1); + graph.GetCompilerHandle(), const_cast(unique_net_name.c_str()), 2); void* bmodel_data = nullptr; unsigned int data_size = 0; bm_hd_ = static_cast(ctx.GetHandle()); @@ -63,7 +87,7 @@ int SubgraphEngine::BuildDeviceProgram() { graph.UnlockCompilerMutex(); bmrt_hd_ = bmrt_create(bm_hd_); if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) { - return subgraph::FAILED; + return false; } bmrt_get_network_names(bmrt_hd_, &net_names_); net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]); @@ -116,10 +140,10 @@ int SubgraphEngine::BuildDeviceProgram() { net_info_->output_dtypes[i], stage.output_shapes[i]); } - return status; + return true; } -int SubgraphEngine::LaunchDeviceProgram() { +bool SubgraphEngine::LaunchDeviceProgram() { for (size_t i = 0; i < device_inputs_.size(); i++) { bm_memcpy_s2d(bm_hd_, device_inputs_[i].device_mem, @@ -143,7 +167,7 @@ int SubgraphEngine::LaunchDeviceProgram() { out_index++; } } - return 0; + return true; } void SubgraphCompute::PrepareForRun() { @@ -155,12 +179,11 @@ void SubgraphCompute::PrepareForRun() { param.output_data_names, param.scope)); CHECK(engine_); - engine_->Build(); } void SubgraphCompute::Run() { CHECK(engine_); - engine_->Launch(); + engine_->Run(); } } // namespace bm diff --git a/lite/kernels/bm/subgraph_compute.h b/lite/kernels/bm/subgraph_compute.h index 60f7661c7990d90020dbfc7ec3a6e0d178dceb70..7a5b2552ff95681da09346ba11f40f1a6acb7f01 100644 --- a/lite/kernels/bm/subgraph_compute.h +++ b/lite/kernels/bm/subgraph_compute.h @@ -44,8 +44,9 @@ class SubgraphEngine : public subgraph::Engine { ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: - int BuildDeviceProgram() override; - int LaunchDeviceProgram() override; + bool PrepareWorkspaceForDeviceProgram() override; + bool BuildDeviceProgram() override; + bool LaunchDeviceProgram() override; private: void *bmrt_hd_; diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 22bb4345fe744df9a06997d366310e2cc24a7a12..76e2d1545ea74b9e6a2e72ed6a7088e52cf53d3f 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -7,6 +7,7 @@ message(STATUS "compile with lite CUDA kernels") # basic kernels add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(fc_compute_cuda CUDA basic SRCS fc_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) +add_kernel(gru_compute_cuda CUDA basic SRCS gru_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(matmul_compute_cuda CUDA basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} 
${math_cuda})
 add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
@@ -69,6 +70,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
 #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
 nv_test(fc_compute_cuda_test SRCS fc_compute_test.cc DEPS fc_compute_cuda)
+nv_test(gru_compute_cuda_test SRCS gru_compute_test.cc DEPS gru_compute_cuda)
 nv_test(matmul_compute_cuda_test SRCS matmul_compute_test.cc DEPS matmul_compute_cuda)
 nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
 nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
diff --git a/lite/kernels/cuda/gru_compute.cu b/lite/kernels/cuda/gru_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c9fa6c9d5c73baa7ae4ced77172a9362a9e3218b
--- /dev/null
+++ b/lite/kernels/cuda/gru_compute.cu
@@ -0,0 +1,236 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+
+#include "lite/backends/cuda/cuda_utils.h"
+#include "lite/backends/cuda/math/bias.h"
+#include "lite/backends/cuda/math/gru_forward.h"
+#include "lite/backends/cuda/math/sequence2batch.h"
+#include "lite/backends/cuda/target_wrapper.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/cuda/gru_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+template <typename T>
+struct GRUMetaValue {
+  T* gate_weight;
+  T* state_weight;
+  T* gate_value;
+  T* reset_output_value;
+  T* output_value;
+  T* prev_out_value;
+};
+
+template <typename T>
+struct GRUUnitFunctor {
+  static void compute(GRUMetaValue<T> value,
+                      int frame_size,
+                      int batch_size,
+                      const lite::cuda::math::ActivationType& active_node,
+                      const lite::cuda::math::ActivationType& active_gate,
+                      bool origin_mode,
+                      lite::cuda::math::Gemm<T, T>* blas,
+                      CUDAContext* context) {
+    dim3 threads, grids;
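+    // Launch configuration: roughly one thread per frame element. A
+    // single-sequence batch uses a 1-D grid over frame_size; larger batches
+    // use 32x32 tiles over (frame_size, batch_size).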
frame_size : 1024; + int frame_blocks = (frame_size + 1024 - 1) / 1024; + threads = dim3(frame_per_block, 1); + grids = dim3(frame_blocks, 1); + } else { + threads = dim3(32, 32); + grids = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); + } + + if (value.prev_out_value) { + CHECK(blas->init(false, + false, + batch_size, + frame_size * 2, + frame_size, + frame_size, + frame_size * 2, + frame_size * 3, + context)); + blas->run(1.0f, + 1.0f, + value.prev_out_value, + value.gate_weight, + value.gate_value, + context); + } + CUDA_POST_KERNEL_CHECK; + + lite::cuda::math::GruForwardResetOutput< + T><<exec_stream()>>>( + value.gate_value, + value.reset_output_value, + value.prev_out_value, + frame_size, + batch_size, + active_gate, + batch_size == 1); + CUDA_POST_KERNEL_CHECK; + + if (value.prev_out_value) { + CHECK(blas->init(false, + false, + batch_size, + frame_size, + frame_size, + frame_size, + frame_size, + frame_size * 3, + context)); + blas->run(1.0f, + 1.0f, + value.reset_output_value, + value.state_weight, + value.gate_value + frame_size * 2, + context); + } + CUDA_POST_KERNEL_CHECK; + + lite::cuda::math::GruForwardFinalOutput< + T><<exec_stream()>>>(value.gate_value, + value.prev_out_value, + value.output_value, + frame_size, + batch_size, + active_node, + origin_mode, + batch_size == 1); + CUDA_POST_KERNEL_CHECK; + } +}; + +template struct GRUUnitFunctor; + +template +void GRUCompute::PrepareForRun() { + gemm_impl_.reset(new lite::cuda::math::Gemm); +} + +template +void GRUCompute::Run() { + auto& context = this->ctx_->template As(); + auto stream = context.exec_stream(); + auto& param = this->template Param(); + + auto* input = param.input; + lite::Tensor* h0{nullptr}; + if (param.h0) { + h0 = const_cast(param.h0); + } + lite::Tensor* bias{nullptr}; + if (param.bias) { + bias = const_cast(param.bias); + } + auto* weight = param.weight; + auto* weight_data = const_cast(weight->template data()); + auto* batch_gate = param.batch_gate; + auto* batch_reset_hidden_prev = param.batch_reset_hidden_prev; + auto* batch_hidden = param.batch_hidden; + auto* hidden = param.hidden; + auto* batch_reset_hidden_prev_data = + batch_reset_hidden_prev->template mutable_data(TARGET(kCUDA)); + hidden->template mutable_data(TARGET(kCUDA)); + auto* batch_gate_data = batch_gate->template mutable_data(TARGET(kCUDA)); + auto* batch_hidden_data = + batch_hidden->template mutable_data(TARGET(kCUDA)); + bool is_reverse = param.is_reverse; + auto active_node = lite::cuda::math::GetActiveType(param.activation); + auto active_gate = lite::cuda::math::GetActiveType(param.gate_activation); + bool origin_mode = param.origin_mode; + + auto hidden_dims = hidden->dims(); + int frame_size = hidden_dims[1]; + + lite::cuda::math::LoDTensor2BatchFunctor batch_func; + batch_func(*input, batch_gate, is_reverse, stream); + + if (bias) { + lite::cuda::math::RowwiseAdd add_bias; + add_bias(batch_gate_data, + bias->template data(), + batch_gate_data, + frame_size, + batch_gate->numel(), + stream); + } + GRUMetaValue gru_value; + gru_value.gate_weight = weight_data; + gru_value.state_weight = weight_data + 2 * frame_size * frame_size; + + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. 
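+    // The sort order (indices of the original sequences, longest first)
+    // is recorded in lod()[2] of batch_gate by LoDTensor2BatchFunctor,
+    // and CopyMatrixRowsFunctor below uses it to shuffle the rows of h0
+    // into the same batch order.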
+ ordered_h0_.Resize(h0->dims()); + lite::cuda::math::CopyMatrixRowsFunctor row_shuffle; + row_shuffle(*h0, &ordered_h0_, batch_gate->lod()[2], true, stream); + gru_value.prev_out_value = ordered_h0_.mutable_data(TARGET(kCUDA)); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + for (size_t n = 0; n < num_batch; ++n) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + gru_value.output_value = batch_hidden_data + bstart * frame_size; + gru_value.gate_value = batch_gate_data + bstart * frame_size * 3; + gru_value.reset_output_value = + batch_reset_hidden_prev_data + bstart * frame_size; + + GRUUnitFunctor::compute(gru_value, + frame_size, + cur_batch_size, + active_node, + active_gate, + origin_mode, + gemm_impl_.get(), + &context); + gru_value.prev_out_value = gru_value.output_value; + } + + lite::cuda::math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(*batch_hidden, hidden, stream); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +using GRUFp32 = + paddle::lite::kernels::cuda::GRUCompute; + +REGISTER_LITE_KERNEL(gru, kCUDA, kFloat, kNCHW, GRUFp32, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("H0", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/gru_compute.h b/lite/kernels/cuda/gru_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..070deca2c54b919d1afeb856633d94fe5919eabd --- /dev/null +++ b/lite/kernels/cuda/gru_compute.h @@ -0,0 +1,46 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
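+//
+// CUDA GRU kernel declaration. PrepareForRun caches a gemm instance for
+// the gate projections, and ordered_h0_ holds the initial hidden state
+// after it has been reordered to match the batch layout built in
+// gru_compute.cu.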
+ +#pragma once +#include + +#include "lite/backends/cuda/math/gemm.h" +#include "lite/core/kernel.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +template +class GRUCompute : public KernelLite { + public: + using param_t = operators::GRUParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~GRUCompute() = default; + + private: + std::unique_ptr> gemm_impl_{nullptr}; + lite::Tensor ordered_h0_; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/gru_compute_test.cc b/lite/kernels/cuda/gru_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff728fd317753a5aa9108808e5f2250fe97310c3 --- /dev/null +++ b/lite/kernels/cuda/gru_compute_test.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/gru_compute.h" + +#include + +#include +#include +#include +#include + +#include "lite/api/test_helper.h" +#include "lite/utils/float16.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class GRUTest : public ::testing::Test { + protected: + GRUTest() + : batch_(12), + frame_size_(128), + activation_("tanh"), + gate_activation_("sigmoid"), + is_reverse_(false), + origin_mode_(false), + x_shape_({batch_, frame_size_ * 3}), + w_shape_({frame_size_, frame_size_ * 3}), + out_shape_({batch_, frame_size_}), + lod_({{0, 4, 9, 12}}) { + x_ref_.Resize(lite::DDim(x_shape_)); + x_gpu_.Resize(lite::DDim(x_shape_)); + x_ref_.set_lod(lod_); + w_ref_.Resize(lite::DDim(w_shape_)); + w_gpu_.Resize(lite::DDim(w_shape_)); + auto x_ref_data = x_ref_.mutable_data(); + auto w_ref_data = w_ref_.mutable_data(); + for (int64_t i = 0; i < x_ref_.numel(); i++) { + x_ref_data[i] = static_cast(i % 10 * 0.2); + } + for (int64_t i = 0; i < w_ref_.numel(); i++) { + w_ref_data[i] = static_cast(i % 10 * 0.2); + } + + out_ref_.Resize(lite::DDim(out_shape_)); + out_cpu_.Resize(out_ref_.dims()); + out_gpu_.Resize(out_ref_.dims()); + batch_gate_gpu_.Resize(lite::DDim(x_shape_)); + batch_hidden_gpu_.Resize(lite::DDim(out_shape_)); + batch_reset_hidden_gpu_.Resize(lite::DDim(out_shape_)); + RunBaseLine(); + InitParamAndContext(); + } + + void InitParamAndContext() { + ctx_.reset(new KernelContext); + cudaStreamCreate(&stream_); + auto& context = ctx_->As(); + context.SetExecStream(stream_); + param_.input = &x_gpu_; + param_.weight = &w_gpu_; + param_.gate_activation = gate_activation_; + param_.activation = activation_; + param_.is_reverse = is_reverse_; + param_.origin_mode = origin_mode_; + param_.hidden = &out_gpu_; + param_.batch_gate = &batch_gate_gpu_; + param_.batch_reset_hidden_prev = &batch_reset_hidden_gpu_; + param_.batch_hidden = &batch_hidden_gpu_; + } + + void InitFloatInput() { + x_gpu_.Assign(x_ref_.data(), + x_gpu_.dims()); + x_gpu_.set_lod(x_ref_.lod()); + 
w_gpu_.Assign(w_ref_.data(), + w_gpu_.dims()); + } + + void RunBaseLine() {} + + int batch_, frame_size_; + std::string activation_, gate_activation_; + bool is_reverse_, origin_mode_; + std::vector x_shape_, w_shape_, out_shape_; + LoD lod_; + lite::Tensor x_ref_, w_ref_, out_ref_; + lite::Tensor x_gpu_, w_gpu_; + lite::Tensor x_half_, w_half_; + lite::Tensor batch_gate_gpu_; + lite::Tensor batch_hidden_gpu_; + lite::Tensor batch_reset_hidden_gpu_; + lite::Tensor out_cpu_, out_gpu_; + + operators::GRUParam param_; + std::unique_ptr ctx_; + cudaStream_t stream_; +}; + +TEST_F(GRUTest, TestFP32) { + InitFloatInput(); + GRUCompute kernel; + kernel.SetParam(param_); + kernel.SetContext(std::move(ctx_)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp32, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 401470808259d9603218913638a494202b9f1339..cd91d2dc90f9f48668e1d5ab9fbe5d065cb0e191 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -18,6 +18,7 @@ add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute. add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps}) add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps}) add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps}) if(LITE_BUILD_EXTRA) lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host) diff --git a/lite/kernels/host/activation_grad_compute.cc b/lite/kernels/host/activation_grad_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b837cfda4572fa106a1ba1d015ffd5163b08340 --- /dev/null +++ b/lite/kernels/host/activation_grad_compute.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
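+//
+// Host (CPU) gradient kernels for the square, relu and tanh activations.
+// These kernels belong to the `train` kernel set (see
+// lite/kernels/host/CMakeLists.txt).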
+
+#include "lite/kernels/host/activation_grad_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+void SquareGradCompute::Run() {
+  auto& param = this->Param<param_t>();
+  CHECK(param.X);
+  auto out_grad_dims = param.Out_grad->dims();
+  auto out_grad_data = param.Out_grad->data<float>();
+
+  auto x_data = param.X->data<float>();
+  auto x_grad_data = param.X_grad->mutable_data<float>();
+  // d(x^2)/dx = 2x, so X@GRAD = Out@GRAD * 2 * x.
+  for (int i = 0; i < out_grad_dims.production(); i++) {
+    x_grad_data[i] = out_grad_data[i] * 2.0 * x_data[i];
+  }
+}
+
+void ReluGradCompute::Run() {
+  auto& param = this->Param<param_t>();
+  CHECK(param.X);
+  auto out_grad_dims = param.Out_grad->dims();
+  auto out_grad_data = param.Out_grad->data<float>();
+
+  auto x_data = param.X->data<float>();
+  auto x_grad_data = param.X_grad->mutable_data<float>();
+  // The gradient passes through only where the forward input was positive.
+  for (int i = 0; i < out_grad_dims.production(); i++) {
+    x_grad_data[i] = x_data[i] > 0 ? out_grad_data[i] : 0.0;
+  }
+}
+
+void TanhGradCompute::Run() {
+  auto& param = this->Param<param_t>();
+  CHECK(param.Out);
+  auto out_grad_dims = param.Out_grad->dims();
+  auto out_grad_data = param.Out_grad->data<float>();
+
+  auto out_data = param.Out->data<float>();
+  auto x_grad_data = param.X_grad->mutable_data<float>();
+  // d(tanh(x))/dx = 1 - tanh(x)^2, computed from the forward output.
+  for (int i = 0; i < out_grad_dims.production(); i++) {
+    x_grad_data[i] = out_grad_data[i] *
+                     (static_cast<float>(1.0) - out_data[i] * out_data[i]);
+  }
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(square_grad,
+                     kHost,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::host::SquareGradCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(relu_grad,
+                     kHost,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::host::ReluGradCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(tanh_grad,
+                     kHost,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::host::TanhGradCompute,
+                     def)
+    .BindInput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
diff --git a/lite/kernels/host/activation_grad_compute.h b/lite/kernels/host/activation_grad_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..d942b901c448ee87410a2030ea0f9f10aca0e493
--- /dev/null
+++ b/lite/kernels/host/activation_grad_compute.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
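+//
+// Declarations of the host activation-gradient kernels implemented and
+// registered in activation_grad_compute.cc.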
+ +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class SquareGradCompute : public KernelLite { + public: + using param_t = operators::ActivationGradParam; + + void Run() override; + + virtual ~SquareGradCompute() = default; +}; + +class ReluGradCompute : public KernelLite { + public: + using param_t = operators::ActivationGradParam; + + void Run() override; + + virtual ~ReluGradCompute() = default; +}; + +class TanhGradCompute : public KernelLite { + public: + using param_t = operators::ActivationGradParam; + + void Run() override; + + virtual ~TanhGradCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/mlu/bridges/conv_op.cc b/lite/kernels/mlu/bridges/conv_op.cc index 84c5bd5638585a5b5e1e22308c9ddf3c06acd9e9..6d10605e2c4060cbd8b30d358ac15f2e78f13ca5 100644 --- a/lite/kernels/mlu/bridges/conv_op.cc +++ b/lite/kernels/mlu/bridges/conv_op.cc @@ -107,8 +107,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { CNML_FILTER, CNML_NCHW, graph->FPType()); - const auto weight_scale = - op_info->GetAttr>("weight_scale"); + const auto weight_scale = op_info->GetInputScale(filter_var_name); if (filter->precision() == PrecisionType::kUnk || filter->precision() == PrecisionType::kInt8) { @@ -162,7 +161,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) { graph->BindConstData(bias_var_name, bias); } - const auto input_scale = op_info->GetAttr("input_scale"); + const auto input_scale = op_info->GetInputScale(input_var_name)[0]; bool use_first_conv = false; if (lite::TargetWrapperMlu::UseFirstConv() && input_dims[1] == 3) { diff --git a/lite/kernels/mlu/bridges/conv_op_test.cc b/lite/kernels/mlu/bridges/conv_op_test.cc index ddaf5b321ffd2af1fbd91af6cf15b5c7789cbba3..e23f7c68ab0048b8cc04ffdae33ea94fcabbcf65 100644 --- a/lite/kernels/mlu/bridges/conv_op_test.cc +++ b/lite/kernels/mlu/bridges/conv_op_test.cc @@ -224,8 +224,10 @@ void test_conv(int bs, opdesc_mlu.SetAttr("groups", groups); opdesc_mlu.SetAttr("fuse_relu", static_cast(fuse_relu)); - opdesc_mlu.SetAttr("weight_scale", std::vector(oc, filter_scale)); - opdesc_mlu.SetAttr("input_scale", input_scale); + OpInfo op_info(opdesc_mlu); + op_info.SetInputScale(filter_int_var_name, + std::vector(oc, filter_scale)); + op_info.SetInputScale(input_var_name, {input_scale}); if (has_bias) { if (is_channel_bias) { @@ -234,7 +236,7 @@ void test_conv(int bs, bias->Resize({output_shape}); } FillTensor(bias); - opdesc_mlu.SetInput("Bias", {bias_var_name}); + op_info.SetInput("Bias", {bias_var_name}); } for (int i = 0; i < bs; i++) { @@ -248,7 +250,7 @@ void test_conv(int bs, } // create and convert op to MLU model, then run it on MLU - auto op = CreateOp(opdesc_mlu, &scope); + auto op = CreateOp(op_info, &scope); LaunchOp(op, {input_var_name}, {output_var_name}); // compare results auto* output_data = output->mutable_data(); diff --git a/lite/kernels/mlu/bridges/fc_op.cc b/lite/kernels/mlu/bridges/fc_op.cc index ed9ef7edd002ad0476efb84b34239ce07641538a..e820fc7abca89a573cfbd7efd7ecca1640905e6a 100644 --- a/lite/kernels/mlu/bridges/fc_op.cc +++ b/lite/kernels/mlu/bridges/fc_op.cc @@ -68,7 +68,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto w_tensor = graph->AddNode( w_var_name, cnml_w_shape, CNML_FILTER, CNML_NCHW, graph->FPType()); - auto input_scale = op_info->GetAttr("input_scale"); + auto 
input_scale = op_info->GetInputScale(x_var_name)[0]; auto output_tensor = graph->AddNode(output_var_name, output->dims().Vectorize(), @@ -101,7 +101,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) { bias_tensor ? bias_tensor->mlu_tensor() : nullptr)); graph->SetComputingDataType( fc_op, graph->GetNode(x_var_name)->mlu_tensor(), 1 / input_scale); - auto weight_scale = op_info->GetAttr>("weight_scale"); + auto weight_scale = op_info->GetInputScale(w_var_name); // LOG(INFO) << "W precision " << int(w->precision()); if (w->precision() == PrecisionType::kUnk || diff --git a/lite/kernels/mlu/bridges/fc_op_test.cc b/lite/kernels/mlu/bridges/fc_op_test.cc index af856a55a2ddc563d210af3b4ef0e669b32f5a57..b7c576581b7bab4b5dd3f2538350a65f94d62c62 100644 --- a/lite/kernels/mlu/bridges/fc_op_test.cc +++ b/lite/kernels/mlu/bridges/fc_op_test.cc @@ -131,14 +131,15 @@ void test_fc(const std::vector& input_shape, fc_op_desc_mlu.SetOutput("Out", {out_var_name}); fc_op_desc_mlu.SetAttr("in_num_col_dims", static_cast(in_num_col_dims)); - fc_op_desc_mlu.SetAttr("weight_scale", - std::vector(w_shape[1], w_scale)); - fc_op_desc_mlu.SetAttr("input_scale", input_scale); + OpInfo op_info(fc_op_desc_mlu); + op_info.SetInputScale(w_int_var_name, + std::vector(w_shape[1], w_scale)); + op_info.SetInputScale(input_var_name, {input_scale}); if (has_bias) { - fc_op_desc_mlu.SetInput("Bias", {bias_var_name}); + op_info.SetInput("Bias", {bias_var_name}); } - auto fc_op_mlu = CreateOp(fc_op_desc_mlu, &scope); + auto fc_op_mlu = CreateOp(op_info, &scope); Tensor input_tmp, out_tmp; input_tmp.Resize(input_shape); diff --git a/lite/kernels/mlu/bridges/lrn_op.cc b/lite/kernels/mlu/bridges/lrn_op.cc index 657f0dd6781590e1a9ca90bf25e4efcf789863dd..ff428ab10cef170983de788b9af517558e1fd7f5 100644 --- a/lite/kernels/mlu/bridges/lrn_op.cc +++ b/lite/kernels/mlu/bridges/lrn_op.cc @@ -49,8 +49,7 @@ int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) { << "Unsuport WithinChannel"; } auto local_size = op_info->GetAttr("n"); - CHECK(op_info->HasAttr("input_scale")); - auto input_scale = op_info->GetAttr("input_scale"); + auto input_scale = op_info->GetInputScale(x_var_name)[0]; VLOG(5) << "lrn input scale: " << input_scale; cnmlLrnOpParam_t param; diff --git a/lite/kernels/mlu/bridges/lrn_op_test.cc b/lite/kernels/mlu/bridges/lrn_op_test.cc index 21f7e816baeac264bf1b43b7520d464afa38c395..266446d6d3353bffa4398385703cd4cb64b4f53b 100644 --- a/lite/kernels/mlu/bridges/lrn_op_test.cc +++ b/lite/kernels/mlu/bridges/lrn_op_test.cc @@ -178,9 +178,10 @@ void test_lrn(float alpha, opdesc.SetAttr("k", k); opdesc.SetAttr("n", local_size); opdesc.SetAttr("norm_region", norm_region); - opdesc.SetAttr("input_scale", (*dmax - *dmin) / 255.f); + OpInfo op_info(opdesc); + op_info.SetInputScale(x_var_name, {(*dmax - *dmin) / 255.f}); - auto op = CreateOp(opdesc, &scope); + auto op = CreateOp(op_info, &scope); // baseline lrn_compute_ref(op); @@ -213,7 +214,7 @@ void test_lrn(float alpha, auto output_data = output_trans.mutable_data(); auto* output_ref_data = out_ref->mutable_data(); for (size_t i = 0; i < out->data_size(); i++) { - EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4); + EXPECT_NEAR(output_data[i], output_ref_data[i], 5e-4); } } diff --git a/lite/kernels/mlu/subgraph_compute.h b/lite/kernels/mlu/subgraph_compute.h index ec9b69ebcb4f32f587de070aa8ab63ae1cedda13..044827dbf98c561b0d424a1c93b0da650ef58796 100644 --- a/lite/kernels/mlu/subgraph_compute.h +++ b/lite/kernels/mlu/subgraph_compute.h @@ -54,40 +54,15 @@ class 
SubgraphEngine : public subgraph::Engine { VLOG(4) << "[MLU] PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL is " << GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL"); VLOG(4) << "[MLU] PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE is " - << GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE"); + << GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE", + true); VLOG(4) << "[MLU] LITE_DISABLE_MLU_CAST is " << GetBoolFromEnv("LITE_DISABLE_MLU_CAST"); - if (GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE")) { + if (GetBoolFromEnv("PADDLE_LITE_MLU_DISABLE_BATCH_SIZE_CHANGEABLE", true)) { disable_batch_size_changeable_ = true; } } - int Build() { - // In order to attach all of the ops of the block desc, we need to build - // the original program firstly. - BuildOriginProgram(); - // Run InferShape() of all of ops, and convert Paddle ops to MLU IR graph - build_device_program_status_ = BuildDeviceProgram(); - return build_device_program_status_; - } - - int Launch() { - // Rebuild device program when the shapes of input tensors have been - // changed. - if (subgraph::CHECK_SUCCESS(build_device_program_status_) && - subgraph::CHECK_REBUILD_WHEN_SHAPE_CHANGED( - build_device_program_status_) && - InputShapeChanged()) { - Build(); - } - if (subgraph::CHECK_FAILED(build_device_program_status_)) { - LaunchOriginProgram(); - } else { - LaunchDeviceProgram(); - } - return 0; - } - bool InputShapeChanged() { std::vector> new_shape; // used in batch changable situation @@ -127,7 +102,10 @@ class SubgraphEngine : public subgraph::Engine { } protected: - int BuildDeviceProgram() override { + bool BuildDeviceProgram() override { + if (origin_program_.empty()) { + BuildOriginProgram(); + } if (!error_compile_batch_size_changeable_ && !disable_batch_size_changeable_) { int status = BuildDeviceProgramImpl(); @@ -142,7 +120,7 @@ class SubgraphEngine : public subgraph::Engine { return BuildDeviceProgramImpl(); } - int BuildDeviceProgramImpl() { + bool BuildDeviceProgramImpl() { int status = 0; auto graph = std::make_shared(); graph->SetFPType(fp_type_); @@ -197,13 +175,16 @@ class SubgraphEngine : public subgraph::Engine { status |= subgraph::FAILED; VLOG(4) << "[MLU] found unsupported batch_size changeable op type: " << op_type; - return status; + if (subgraph::CHECK_FAILED(status)) { + return false; + } + return true; } op->CheckShape(); const_cast(op)->InferShape(); if (!bridges.Exists(op_type, TARGET(kMLU))) { LOG(INFO) << "MLU bridges doesn't support op_type: " << op_type; - return subgraph::FAILED; + return false; } auto kernel = inst.kernel(); status |= bridges.Select(op_type, TARGET(kMLU))( @@ -211,7 +192,7 @@ class SubgraphEngine : public subgraph::Engine { const_cast(op), const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { - return subgraph::FAILED; + return false; } } // Obtain the output nodes of the MLU IR graph and build the graph to MLU @@ -242,7 +223,7 @@ class SubgraphEngine : public subgraph::Engine { if (GetBoolFromEnv("PADDLE_LITE_MLU_SAVE_OFFLINE_MODEL")) { graph->GenOfflineModel(GetOfflineModName()); } - return status; + return true; } std::string TrimStrings(const std::string& origin_str) { @@ -329,7 +310,7 @@ class SubgraphEngine : public subgraph::Engine { } } - int LaunchDeviceProgram() override { + bool LaunchDeviceProgram() override { // prepare input and output memory auto& mlu_context = this->ctx_->template As(); auto exec_queue = mlu_context.exec_queue(); @@ -453,7 +434,7 @@ class SubgraphEngine : public subgraph::Engine { // =========== DUMP END 
================
    }
-    return 0;
+    return true;
  }

  paddle::lite_api::PrecisionType fp_type_;
@@ -501,12 +482,11 @@ class SubgraphCompute
                                         param.scope,
                                         this->precision()));
    CHECK(engine_);
-    engine_->Build();
  }

  void Run() override {
    CHECK(engine_);
-    engine_->Launch();
+    engine_->Run();
  }

  virtual ~SubgraphCompute() = default;
diff --git a/lite/kernels/opencl/conv_image_compute.cc b/lite/kernels/opencl/conv_image_compute.cc
index 5b9e3b220a4134d466d6a4b5cbdfa2f42d12bdf6..083f72134eba8afc7db696f68d64098b9c59a0f9 100644
--- a/lite/kernels/opencl/conv_image_compute.cc
+++ b/lite/kernels/opencl/conv_image_compute.cc
@@ -30,92 +30,81 @@ namespace kernels {
 namespace opencl {

 void ConvImageCompute::PrepareForRun() {
-  const auto& param = this->Param<param_t>();
-  auto x_dims = param.x->dims();
-  auto filter_dims = param.filter->dims();
-  auto output_dims = param.output->dims();
+  ReInitWhenNeeded();
+
+  auto filter_dims = conv_param_->filter->dims();
+  filter_tensor_n_ = filter_dims[0];
+  filter_tensor_c_ = filter_dims[1];
+  filter_tensor_h_ = filter_dims[2];
+  filter_tensor_w_ = filter_dims[3];

-  float* filter_cpu = param.filter->mutable_data<float>();
   auto& context = ctx_->As<OpenCLContext>();
   CHECK(context.cl_context() != nullptr);
   const bool is_mali = context.cl_context()->IsArmMali();

-  filter_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
-  tensor_hold_filter_image_ = std::unique_ptr<Tensor>(new Tensor);
-  tensor_hold_bias_image_ = std::unique_ptr<Tensor>(new Tensor);
-  int bs = x_dims[0];
-  int c_in = x_dims[1];
-  int h_out = output_dims[2];
-  int w_out = output_dims[3];
-  int kernel_h = filter_dims[2];  // oihw
-  int kernel_w = filter_dims[3];
-  auto paddings = *param.paddings;
-  auto dilations = *param.dilations;
-  int stride_h = param.strides[0];
-  int stride_w = param.strides[1];
-  int pad_h = paddings[0];
-  int pad_w = paddings[2];
-  int groups = param.groups;
-  bool relu_fused = param.fuse_relu;
-  bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
-  bool zero_pad = (pad_h == 0) && (pad_w == 0);
-
-  bool pad_equal =
-      ((paddings[0] == paddings[1]) && (paddings[1] == paddings[2]) &&
-       (paddings[2] == paddings[3]));
-  bool stride_equal = stride_h == stride_w;
-  bool dilation_equal = dilations[0] == dilations[1];
+
+  auto paddings = *conv_param_->paddings;
+  pad_up_ = paddings[0];
+  pad_down_ = paddings[1];
+  pad_left_ = paddings[2];
+  pad_right_ = paddings[3];
+
+  auto dilations = *conv_param_->dilations;
+  dilation_h_ = dilations[0];
+  dilation_w_ = dilations[1];
+
+  stride_h_ = conv_param_->strides[0];
+  stride_w_ = conv_param_->strides[1];
+
+  groups_ = conv_param_->groups;
+  relu_fused_ = conv_param_->fuse_relu;
+  has_bias_ = (conv_param_->bias) != nullptr;
+  offset_ = filter_tensor_h_ / 2 - pad_up_;
+
+  bool pad_equal = ((pad_left_ == pad_up_) && (pad_up_ == pad_down_) &&
+                    (pad_down_ == pad_right_));
+  bool stride_equal = stride_h_ == stride_w_;
+  bool dilation_equal = dilation_h_ == dilation_w_;

   VLOG(3) << "Is arm mali / " << (is_mali ? "Yes" : "No");
-  VLOG(3) << "Is relu fused? / " << (relu_fused ? "Yes" : "No");
-  VLOG(3) << "groups:" << groups << " stride_h:" << stride_h
-          << " stride_w:" << stride_w << " pad_h:" << pad_h
-          << " pad_w:" << pad_w << " kernel_h:" << kernel_h
-          << " kernel_h:" << kernel_h;
-  VLOG(3) << "x_dims:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2]
-          << " " << x_dims[3];
-  VLOG(3) << "dialtion:" << dilations[0] << " " << dilations[1];
-  VLOG(3) << "output_dims:" << output_dims[0] << " " << output_dims[1] << " "
-          << output_dims[2] << " " << output_dims[3];
-  VLOG(3) << "filter_dims:" << filter_dims[0] << " " << filter_dims[1] << " "
-          << filter_dims[2] << " " << filter_dims[3];
+  VLOG(3) << "Is relu fused? / " << (relu_fused_ ? "Yes" : "No");
+  VLOG(3) << "groups:" << groups_ << " stride_h_:" << stride_h_
+          << " stride_w_:" << stride_w_ << " pad_left_:" << pad_left_
+          << " pad_up_:" << pad_up_ << " filter_tensor_h_:" << filter_tensor_h_
+          << " filter_tensor_w_:" << filter_tensor_w_;
+  VLOG(3) << "input_tensor_nchw:" << input_tensor_n_ << " " << input_tensor_c_
+          << " " << input_tensor_h_ << " " << input_tensor_w_;
+  VLOG(3) << "dilation:" << dilation_h_ << " " << dilation_w_;
+  VLOG(3) << "output_dims:" << output_tensor_n_ << " " << output_tensor_c_
+          << " " << output_tensor_h_ << " " << output_tensor_w_;
+  VLOG(3) << "filter_dims:" << filter_tensor_n_ << " " << filter_tensor_c_
+          << " " << filter_tensor_h_ << " " << filter_tensor_w_;
   VLOG(3) << "pad_equal:" << pad_equal;
   VLOG(3) << "stride_equal:" << stride_equal;
   VLOG(3) << "dilation_equal:" << dilation_equal;
-  VLOG(3) << "padding :" << paddings[0] << " " << paddings[1] << " "
-          << paddings[2] << " " << paddings[3];
+  VLOG(3) << "padding :" << pad_up_ << " " << pad_down_ << " " << pad_left_
+          << " " << pad_right_;

   CHECK(pad_equal && stride_equal && dilation_equal);
+  CHECK_GE(conv_param_->dilations->size(), 2);
+  CHECK(dilation_h_ == dilation_w_);
+  CHECK_GE(conv_param_->paddings->size(), 2);
+  CHECK(pad_left_ == pad_up_);
+  CHECK_GE(conv_param_->strides.size(), 2);
+  CHECK(stride_h_ == stride_w_);

   if (!is_mali) {
-    use_turn_ = false;
+    use_tune_ = false;
   }
-  // general gws..
-  auto out_image_shape = InitImageDimInfoWith(output_dims);
-
-  const std::vector<size_t>& default_work_size =
-      DefaultWorkSize(output_dims,
-                      DDim(std::vector<DDim::value_type>{
-                          static_cast<int64_t>(out_image_shape["width"]),
-                          static_cast<int64_t>(out_image_shape["height"])}));
-  default_c_blk_ = default_work_size[0];
-  default_w_blk_ = default_work_size[1];
-  default_nh_blk_ = default_work_size[2];
-  c_blk_ = default_c_blk_;
-  w_blk_ = default_w_blk_;
-  nh_blk_ = default_nh_blk_;
-  global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
-                                  static_cast<size_t>(w_blk_),
-                                  static_cast<size_t>(nh_blk_)};
-
-  if (kernel_h == 1 && kernel_w == 1) {
-    // conv2d_1x1
-    // if (param.x->dims()[1] % 4 == 0) {
-    //   kernel_func_names_.push_back("conv2d_1x1_simple");
-    // } else {
-    //   kernel_func_names_.push_back("conv2d_1x1_opt");
-    // }
+  /*********************************************
+   * Upload filter, bias to opencl device
+   *********************************************/
+  float* filter_cpu = conv_param_->filter->mutable_data<float>();
+  filter_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
+  tensor_hold_filter_image_ = std::unique_ptr<Tensor>(new Tensor);
+  tensor_hold_bias_image_ = std::unique_ptr<Tensor>(new Tensor);

-    if (param.x->dims()[1] % 4 == 0) {
+  if (filter_tensor_h_ == 1 && filter_tensor_w_ == 1) {
+    if (input_tensor_c_ % 4 == 0) {
       kernel_func_names_.push_back("conv2d_1x1_simple");
     } else {
       kernel_func_names_.push_back("conv2d_1x1_opt");
@@ -124,89 +113,49 @@ void ConvImageCompute::PrepareForRun() {
     CLImageConverterNWBlock converter;
     const DDim& filter_image_dims =
         converter.InitImageDimInfoWith(filter_dims);
-    // std::vector<half_t> filter_image_v(filter_image_dims[0] *
-    //                                    filter_image_dims[1] * 4);  // 4 :
-    //                                    RGBA
-    tensor_hold_filter_image_->Resize(
-        {1, filter_image_dims[0], filter_image_dims[1], 4});
-
+    filter_image_h_ = filter_image_dims[1];
+    filter_image_w_ = filter_image_dims[0];
+    tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4});
     half_t* filter_image_data =
         tensor_hold_filter_image_->mutable_data<half_t>();
     converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims);
     filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
-        filter_image_dims[0], filter_image_dims[1], filter_image_data);
+        filter_image_w_, filter_image_h_, filter_image_data);

     impl_ = &ConvImageCompute::Conv2d1x1opt;
-    {
-      // calc 1x1 gws
-      w_blk_ = maptofactor(default_w_blk_, 4);
-      c_blk_ = default_c_blk_;
-      nh_blk_ = default_nh_blk_;
-      global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
-                                      static_cast<size_t>(w_blk_),
-                                      static_cast<size_t>(nh_blk_)};
-    }
 #define DEPTH_CONV_USE_SPL
 #ifdef DEPTH_CONV_USE_SPL
-  } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] &&
-             kernel_h == 3 && kernel_w == 3 && groups > 1) {
+  } else if (filter_tensor_c_ == 1 && input_tensor_c_ == output_tensor_c_ &&
+             filter_tensor_h_ == 3 && filter_tensor_w_ == 3 && groups_ > 1) {
     // depth_conv2d_3x3s1, depth_conv2d_3x3
-    if (stride_h == 1 && dilations[0] == 1) {
+    if (stride_h_ == 1 && dilation_h_ == 1) {
       kernel_func_names_.push_back("depth_conv2d_3x3s1");
       impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1;
-      {
-        // depthwise spl gws s1
-        int c_block = (output_dims[1] + 3) / 4;
-        int w = output_dims[3];
-        int nh = output_dims[0] * output_dims[2];
-        int w_blk_size = 2;
-        int w_blk = (w + w_blk_size - 1) / w_blk_size;
-
-        c_blk_ = c_block;
-        w_blk_ = w_blk;
-        nh_blk_ = nh;
-        global_work_size_ = cl::NDRange{static_cast<size_t>(c_blk_),
-                                        static_cast<size_t>(w_blk_),
-                                        static_cast<size_t>(nh_blk_)};
-      }
     } else {
       kernel_func_names_.push_back("depth_conv2d_3x3");
       impl_ = &ConvImageCompute::DepthwiseConv2d3x3;
-      {
-        // depthwise spl gws
-        int c_block =
(output_dims[1] + 3) / 4; - int w = output_dims[3]; - int nh = output_dims[0] * output_dims[2]; - - c_blk_ = c_block; - w_blk_ = w; - nh_blk_ = nh; - - global_work_size_ = cl::NDRange{static_cast(c_blk_), - static_cast(w_blk_), - static_cast(nh_blk_)}; - } } kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl"); CLImageConverterNWBlock converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); #endif - } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] + } else if (filter_tensor_c_ == 1 && input_tensor_c_ == output_tensor_c_ #ifdef DEPTH_CONV_USE_SPL && - kernel_h != 3 + filter_tensor_h_ != 3 #endif #undef DEPTH_CONV_USE_SPL ) { @@ -216,75 +165,61 @@ void ConvImageCompute::PrepareForRun() { CLImageConverterNWBlock converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::DepthwiseConv2d; - } else if (kernel_w == 3 && kernel_h == 3) { + } else if (filter_tensor_h_ == 3 && filter_tensor_w_ == 3) { // #define CONV3x3OPT_FALL_BACK #ifndef CONV3x3OPT_FALL_BACK // conv2d_3x3 - kernel_func_names_.push_back(bs > 1 ? "conv2d_3x3_multi_batch" - : "conv2d_3x3_opt"); + kernel_func_names_.push_back(input_tensor_n_ > 1 ? 
"conv2d_3x3_multi_batch" + : "conv2d_3x3_opt"); kernel_func_paths_.push_back("image/conv2d_3x3_opt_kernel.cl"); CLImageConverterFolder converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::Conv2d3x3opt; - - { - int w_blk_size = 5; - int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size; - - int h_blk_size = 1; - int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size; - - c_blk_ = default_c_blk_; - w_blk_ = w_blk; - nh_blk_ = h_blk; - - global_work_size_ = cl::NDRange{static_cast(c_blk_), - static_cast(w_blk_), - static_cast(nh_blk_)}; - } #else kernel_func_names_.push_back("conv2d_3x3"); kernel_func_paths_.push_back("image/conv2d_3x3_kernel.cl"); CLImageConverterFolder converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::Conv2d3x3; - #endif #undef CONV3x3OPT_FALL_BACK - } else if (kernel_h == 5 && kernel_w == 5) { + } else if (filter_tensor_h_ == 5 && filter_tensor_w_ == 5) { #define CONV_5x5_OPT #ifndef CONV_5x5_OPT // conv2d_5x5 @@ -293,55 +228,42 @@ void ConvImageCompute::PrepareForRun() { CLImageConverterFolder converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::Conv2d5x5; #else // conv2d_5x5_opt - kernel_func_names_.push_back(bs > 1 ? "conv2d_5x5_multi_batch" - : "conv2d_5x5_opt"); + kernel_func_names_.push_back(input_tensor_n_ > 1 ? 
"conv2d_5x5_multi_batch" + : "conv2d_5x5_opt"); kernel_func_paths_.push_back("image/conv2d_5x5_opt_kernel.cl"); CLImageConverterFolder converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::Conv2d5x5opt; - { - int w_blk_size = 5; - int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size; - - int h_blk_size = 1; - int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size; - - c_blk_ = default_c_blk_; - w_blk_ = w_blk; - nh_blk_ = h_blk; - - global_work_size_ = cl::NDRange{static_cast(c_blk_), - static_cast(w_blk_), - static_cast(nh_blk_)}; - } #endif #undef CONV_5x5_OPT - } else if (kernel_h == 7 && kernel_w == 7) { + } else if (filter_tensor_h_ == 7 && filter_tensor_w_ == 7) { #define CONV_7x7_OPT #ifndef CONV_7x7_OPT // conv2d_7x7 @@ -350,52 +272,39 @@ void ConvImageCompute::PrepareForRun() { CLImageConverterFolder converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::Conv2d7x7; #else // conv2d_7x7 - kernel_func_names_.push_back(bs > 1 ? "conv2d_7x7_multi_batch" - : "conv2d_7x7_opt"); + kernel_func_names_.push_back(input_tensor_n_ > 1 ? 
"conv2d_7x7_multi_batch" + : "conv2d_7x7_opt"); kernel_func_paths_.push_back("image/conv2d_7x7_opt_kernel.cl"); CLImageConverterFolder converter; const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); - tensor_hold_filter_image_->Resize( - {1, filter_image_dims[0], filter_image_dims[1], 4}); + filter_image_h_ = filter_image_dims[1]; + filter_image_w_ = filter_image_dims[0]; + tensor_hold_filter_image_->Resize({1, filter_image_w_, filter_image_h_, 4}); half_t* filter_image_data = tensor_hold_filter_image_->mutable_data(); converter.NCHWToImage(filter_cpu, filter_image_data, filter_dims); filter_gpu_image_->mutable_data( - filter_image_dims[0], filter_image_dims[1], filter_image_data); + filter_image_w_, filter_image_h_, filter_image_data); impl_ = &ConvImageCompute::Conv2d7x7opt; - { - int w_blk_size = 5; - int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size; - - int h_blk_size = 1; - int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size; - - c_blk_ = default_c_blk_; - w_blk_ = w_blk; - nh_blk_ = h_blk; - - global_work_size_ = cl::NDRange{static_cast(c_blk_), - static_cast(w_blk_), - static_cast(nh_blk_)}; - } #endif #undef CONV_7x7_OPT } else { @@ -407,30 +316,30 @@ void ConvImageCompute::PrepareForRun() { // build options std::string build_options_single(" -DCL_DTYPE_half"); // relu options - VLOG(3) << "relu_fused:" << relu_fused - << " param.activation_param.active_type:" - << static_cast(param.activation_param.active_type) - << " param.activation_param.has_active:" - << param.activation_param.has_active; - if (param.activation_param.has_active) { - if (param.activation_param.active_type == - lite_api::ActivationType::kRelu) { // Note: judge using `relu_fused` + VLOG(3) << "relu_fused_:" << relu_fused_ + << " conv_param_->activation_param.active_type:" + << static_cast(conv_param_->activation_param.active_type) + << " conv_param_->activation_param.has_active:" + << conv_param_->activation_param.has_active; + if (conv_param_->activation_param.has_active) { + if (conv_param_->activation_param.active_type == + lite_api::ActivationType::kRelu) { // Note: judge using `relu_fused_` // also is ok build_options_single += " -DRELU"; - } else if (param.activation_param.active_type == + } else if (conv_param_->activation_param.active_type == lite_api::ActivationType::kRelu6) { build_options_single += " -DRELU6"; } else { LOG(FATAL) << "Unsupported activation type:" - << static_cast(param.activation_param.active_type); + << static_cast(conv_param_->activation_param.active_type); } } + GetGlobalWorkSize(); // bias options - const bool has_bias = param.bias != nullptr; const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - if (has_bias) { + has_bias_ && conv_param_->output->dims() == conv_param_->bias->dims(); + if (has_bias_) { bias_gpu_image_ = std::unique_ptr(new Tensor); build_options_single += is_element_wise_bias ? 
" -DBIASE_ELE" : " -DBIASE_CH"; @@ -438,21 +347,36 @@ void ConvImageCompute::PrepareForRun() { // convert cpu buffer bias --> gpu image CLImageConverterFolder bias_converter; const DDim& bias_image_dims = - bias_converter.InitImageDimInfoWith(param.bias->dims()); - + bias_converter.InitImageDimInfoWith(conv_param_->bias->dims()); + bias_image_h_ = bias_image_dims[1]; + bias_image_w_ = bias_image_dims[0]; tensor_hold_bias_image_->Resize( {1, bias_image_dims[0], bias_image_dims[1], 4}); half_t* bias_image_data = tensor_hold_bias_image_->mutable_data(); - float* bias_cpu_data = param.bias->mutable_data(); + float* bias_cpu_data = conv_param_->bias->mutable_data(); bias_converter.NCHWToImage( - bias_cpu_data, bias_image_data, param.bias->dims()); + bias_cpu_data, bias_image_data, conv_param_->bias->dims()); this->bias_gpu_image_->mutable_data( bias_image_dims[0], bias_image_dims[1], bias_image_data); // convert cpu buffer bias --> gpu image --- end ---- + } else { + bias_gpu_image_ = std::unique_ptr(new Tensor); + CLImageConverterFolder bias_converter; + tensor_hold_bias_image_->Resize({1, 1, 1, 4}); + half_t* bias_image_data = tensor_hold_bias_image_->mutable_data(); + this->bias_gpu_image_->mutable_data( + 1, 1, bias_image_data); } + // define image pointer for filter, bias + input_image_p_ = conv_param_->x->data(); + filter_image_p_ = filter_gpu_image_->data(); + bias_image_p_ = bias_gpu_image_->data(); + output_image_p_ = conv_param_->output->mutable_data( + output_image_w_, output_image_h_); + build_options_.push_back(build_options_single); for (size_t i = 0; i < kernel_func_names_.size(); i++) { @@ -478,55 +402,55 @@ void ConvImageCompute::PrepareForRun() { VLOG(4) << "max_work_group_size: " << max_work_group_size; if (max_work_group_size > 0 && use_lws_) { - double min_turn_time = DBL_MAX; + double min_tune_time = DBL_MAX; cl::NDRange best_local_work_size = context.cl_context()->LocalWorkSize( global_work_size_, max_work_group_size); VLOG(3) << "origin :local_work_size_ : " << best_local_work_size[0] << " " << best_local_work_size[1] << " " << best_local_work_size[2]; cl::NDRange last_local_work_size = cl::NDRange{ static_cast(0), static_cast(0), static_cast(0)}; - if (use_turn_) { + if (use_tune_) { for (size_t i = 1; i < 15; i++) { - if (kernel_h == 1 && kernel_w == 1) { + if (filter_tensor_h_ == 1 && filter_tensor_w_ == 1) { // todo use diff logics - local_work_size_ = context.cl_context()->LocalWorkSizeTurn( + local_work_size_ = context.cl_context()->LocalWorkSizeTune( global_work_size_, max_work_group_size, i); } else { - local_work_size_ = context.cl_context()->LocalWorkSizeTurn( + local_work_size_ = context.cl_context()->LocalWorkSizeTune( global_work_size_, max_work_group_size, i); } if (last_local_work_size[0] == local_work_size_[0] && last_local_work_size[1] == local_work_size_[1] && last_local_work_size[2] == local_work_size_[2]) { - // skiped turned lws + // skiped tuneed lws continue; } - auto turn_time = this->Turn(10); - if (min_turn_time > turn_time) { - min_turn_time = turn_time; + auto tune_time = this->Tune(10); + if (min_tune_time > tune_time) { + min_tune_time = tune_time; best_local_work_size = local_work_size_; } last_local_work_size = local_work_size_; } // reverse for (size_t i = 1; i < 15; i++) { - if (kernel_h == 1 && kernel_w == 1) { + if (filter_tensor_h_ == 1 && filter_tensor_w_ == 1) { // todo use diff logics - local_work_size_ = context.cl_context()->LocalWorkSizeTurnReverse( + local_work_size_ = context.cl_context()->LocalWorkSizeTuneReverse( 
global_work_size_, max_work_group_size, i); } else { - local_work_size_ = context.cl_context()->LocalWorkSizeTurnReverse( + local_work_size_ = context.cl_context()->LocalWorkSizeTuneReverse( global_work_size_, max_work_group_size, i); } if (last_local_work_size[0] == local_work_size_[0] && last_local_work_size[1] == local_work_size_[1] && last_local_work_size[2] == local_work_size_[2]) { - // skiped turned lws + // skiped tuneed lws continue; } - auto turn_time = this->Turn(10); - if (min_turn_time > turn_time) { - min_turn_time = turn_time; + auto tune_time = this->Tune(10); + if (min_tune_time > tune_time) { + min_tune_time = tune_time; best_local_work_size = local_work_size_; } last_local_work_size = local_work_size_; @@ -540,548 +464,316 @@ void ConvImageCompute::PrepareForRun() { } } -void ConvImageCompute::Conv2d1x1opt(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - +void ConvImageCompute::ReInitWhenNeeded() { + conv_param_ = param_.get_mutable(); + auto x_dims = conv_param_->x->dims(); #ifdef LITE_WITH_LOG - // VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," - << global_work_size_[1] << "," << global_work_size_[2] << "}"; -#endif -#ifdef LITE_WITH_LOG - VLOG(4) << "============ conv2d_1x1 params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; -// VLOG(4) << "default work size{c_block, w, nh}: " -// << "{" << c_block << ", " << w << ", " << nh << "" -// << "}"; + LOG(INFO) << "is_first_epoch_for_run_:" << is_first_epoch_for_run_ + << ", last_input_dims_:" << last_input_dims_ + << ", x_dims:" << 
x_dims; #endif - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - // handle bias use buffer for channel wise , use image for element wise - const cl::Buffer* bias_buf = nullptr; - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } - - auto kernel = kernel_; - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, default_w_blk_); - CL_CHECK_FATAL(status); - - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_); - CL_CHECK_FATAL(status); - if (is_turn) { - CLRuntime::Global()->command_queue().finish(); - } -} -void ConvImageCompute::Conv2d3x3(bool is_turn) { - auto kernel = kernel_; - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int input_channel = input_dims[1]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int output_channel = output_dims[1]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - int filter_channel = filter_dims[1]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - // re-calc group - int new_groups{param.groups}; - if (filter_dims[0] == output_dims[1] && filter_dims[1] == input_dims[1]) 
{ - new_groups = 1; - } else if (!(filter_dims[0] == input_dims[1] && filter_dims[1] == 1)) { - new_groups = input_channel / filter_channel; - } - /* TODO(ysh329): mobile has no case below - else { - LOG(FATAL) << "Not support conv3x3 case with" - << " input_dims:" << input_dims << " output_dims:" << - output_dims - << " filter_dims:" << filter_dims; + if (is_first_epoch_for_run_ || last_input_dims_ != x_dims) { + is_first_epoch_for_run_ = false; + last_input_dims_ = x_dims; + + input_tensor_n_ = x_dims[0]; + input_tensor_c_ = x_dims[1]; + input_tensor_h_ = x_dims[2]; + input_tensor_w_ = x_dims[3]; + auto x_image_shape = InitImageDimInfoWith(x_dims); + input_image_h_ = x_image_shape["height"]; + input_image_w_ = x_image_shape["width"]; + + auto output_dims = conv_param_->output->dims(); + output_tensor_n_ = output_dims[0]; + output_tensor_c_ = output_dims[1]; + output_tensor_h_ = output_dims[2]; + output_tensor_w_ = output_dims[3]; + auto output_image_shape = InitImageDimInfoWith(output_dims); + output_image_h_ = output_image_shape["height"]; + output_image_w_ = output_image_shape["width"]; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + CHECK_GE(conv_param_->x->dims().size(), 4); + CHECK_GE(conv_param_->output->dims().size(), 4); + if (kernel_func_names_.size() > 0 && + kernel_func_names_[0] == "conv2d_3x3") { + groups_ = conv_param_->groups; + if (filter_tensor_n_ == output_tensor_c_ && + filter_tensor_c_ == input_tensor_c_) { + groups_ = 1; + } else if (!(filter_tensor_n_ == input_tensor_c_ && + filter_tensor_c_ == 1)) { + groups_ = input_tensor_c_ / filter_tensor_c_; + } } - */ - - // const std::vector& default_work_size = - // DefaultWorkSize(output_dims, - // DDim(std::vector{ - // static_cast(out_image_shape["width"]), - // static_cast(out_image_shape["height"])})); - - // int c_block = default_work_size[0]; - // int w = default_work_size[1]; - // int nh = default_work_size[2]; - - // VLOG(4) << "============ conv2d params ============"; - // VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - // << input_image_shape["height"]; - // VLOG(4) << "input_c_block: " << input_c_block; - // VLOG(4) << "input_c: " << input_c; - // VLOG(4) << "input_image: " << input_image; - // VLOG(4) << "input_dims: " << input_dims; - // VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - // VLOG(4) << "output_dims: " << output_dims; - // VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - // << out_image_shape["height"]; - // VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - // VLOG(4) << "has bias: " << has_bias; - // VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - // VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - // VLOG(4) << "offset: " << offset; - // VLOG(4) << "dilations.size : " << dilations.size(); - // VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - // VLOG(4) << "param.groups(groups):" << param.groups; - // VLOG(4) << "new_groups:" << new_groups; - // VLOG(4) << "default work size{c_block, w, nh}: " - // << "{" << c_block << ", " << w << ", " << nh << "" - // << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } - auto& context 
= ctx_->As(); - CHECK(context.cl_context() != nullptr); - // STL::stringstream kernel_key; - // kernel_key << kernel_func_names_[0] << build_options_[0]; - // auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - // VLOG(4) << "kernel_key: " << kernel_key.str(); - // VLOG(4) << "kernel ready ... " << kernel_key.str(); - // VLOG(4) << "w: " << w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); + // define image pointer for input, output + input_image_p_ = conv_param_->x->data(); + output_image_p_ = conv_param_->output->mutable_data( + output_image_w_, output_image_h_); + + GetGlobalWorkSize(); } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, new_groups); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(input_dims[1])); - CL_CHECK_FATAL(status); - - // auto global_work_size = - // cl::NDRange{static_cast(default_work_size.data()[0]), - // static_cast(default_work_size.data()[1]), - // static_cast(default_work_size.data()[2])}; - - // VLOG(4) << "out_image: " << out_image; - // VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - // << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - cl::NullRange, - nullptr, - event_); - CL_CHECK_FATAL(status); } -void ConvImageCompute::Conv2d3x3opt(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int input_channel = input_dims[1]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int output_channel = output_dims[1]; - CHECK_EQ(input_dims[0], output_dims[0]); - int 
batch = input_dims[0]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); -#ifdef LITE_WITH_LOG - VLOG(4) << "============ conv2d params ============"; - // VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - // << input_image_shape["height"]; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; -#endif +void ConvImageCompute::GetGlobalWorkSize() { + if (kernel_func_names_.size() <= 0) return; + // general input_c_block + input_c_block_ = static_cast(input_image_w_ / input_tensor_w_); - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } + // general gws + auto output_dims = conv_param_->output->dims(); + const std::vector& default_work_size = + DefaultWorkSize(output_dims, + DDim(std::vector{ + static_cast(output_image_w_), + static_cast(output_image_h_)})); + default_c_blk_ = default_work_size[0]; + default_w_blk_ = default_work_size[1]; + default_nh_blk_ = default_work_size[2]; + c_blk_ = default_c_blk_; + w_blk_ = default_w_blk_; + nh_blk_ = default_nh_blk_; + global_work_size_ = cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; - auto kernel = kernel_; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { -#ifdef LITE_WITH_LOG - VLOG(4) << "set bias_image: "; -#endif - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); + if (kernel_func_names_[0] == "conv2d_1x1_simple" || + kernel_func_names_[0] == "conv2d_1x1_opt") { + w_blk_ = maptofactor(default_w_blk_, 4); + c_blk_ = default_c_blk_; + nh_blk_ = default_nh_blk_; + global_work_size_ = cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; + + } else if (kernel_func_names_[0] == "depth_conv2d_3x3s1") { + // depthwise spl gws s1 + int c_block = (output_tensor_c_ + 3) / 4; + int w = output_tensor_w_; + int nh = output_tensor_n_ * output_tensor_h_; + int w_blk_size = 2; + int w_blk = (w + w_blk_size - 1) / w_blk_size; + + c_blk_ = c_block; + w_blk_ = w_blk; + nh_blk_ = nh; + global_work_size_ = 
cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; + } else if (kernel_func_names_[0] == "depth_conv2d_3x3") { + // depthwise spl gws + int c_block = (output_tensor_c_ + 3) / 4; + int w = output_tensor_w_; + int nh = output_tensor_n_ * output_tensor_h_; + + c_blk_ = c_block; + w_blk_ = w; + nh_blk_ = nh; + global_work_size_ = cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; + input_c_block_ = static_cast((input_tensor_c_ + 3) / 4); + } else if (kernel_func_names_[0] == "conv2d_3x3_multi_batch" || + kernel_func_names_[0] == "conv2d_3x3_opt") { + int w_blk_size = 5; + int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size; + + int h_blk_size = 1; + int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size; + + c_blk_ = default_c_blk_; + w_blk_ = w_blk; + nh_blk_ = h_blk; + + global_work_size_ = cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; + } else if (kernel_func_names_[0] == "conv2d_5x5_multi_batch" || + kernel_func_names_[0] == "conv2d_5x5_opt") { + int w_blk_size = 5; + int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size; + + int h_blk_size = 1; + int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size; + + c_blk_ = default_c_blk_; + w_blk_ = w_blk; + nh_blk_ = h_blk; + global_work_size_ = cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; + } else if (kernel_func_names_[0] == "conv2d_7x7_multi_batch" || + kernel_func_names_[0] == "conv2d_7x7_opt") { + int w_blk_size = 5; + int w_blk = (default_w_blk_ + w_blk_size - 1) / w_blk_size; + + int h_blk_size = 1; + int h_blk = (default_nh_blk_ + h_blk_size - 1) / h_blk_size; + + c_blk_ = default_c_blk_; + w_blk_ = w_blk; + nh_blk_ = h_blk; + global_work_size_ = cl::NDRange{static_cast(c_blk_), + static_cast(w_blk_), + static_cast(nh_blk_)}; } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, paddings[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, batch); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); +} +void ConvImageCompute::Conv2d1x1opt(bool enable_tune) { #ifdef LITE_WITH_LOG - // VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," - << global_work_size_[1] << "," << global_work_size_[2] << "}"; + PrintConvInfo(); #endif + auto& context = ctx_->As(); - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_); - CL_CHECK_FATAL(status); - if (is_turn) { + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = 
kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, offset_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, input_c_block_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, input_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(15, output_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(16, default_w_blk_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + local_work_size_, + nullptr, + event_); + CL_CHECK_FATAL(status_); + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::Conv2d5x5(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - +void ConvImageCompute::Conv2d3x3(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "============ conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; + PrintConvInfo(); #endif + auto& context = ctx_->As(); - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 
2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, offset_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, input_c_block_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(15, output_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(16, filter_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(17, filter_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(18, filter_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(19, groups_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(20, input_tensor_c_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status_); +} - auto kernel = kernel_; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { +void ConvImageCompute::Conv2d3x3opt(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "set bias_image: "; + PrintConvInfo(); #endif - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); + auto& context = ctx_->As(); + + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + 
CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, pad_left_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, input_tensor_n_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(15, output_tensor_h_); + CL_CHECK_FATAL(status_); #ifdef LITE_WITH_LOG // VLOG(4) << "out_image: " << out_image; @@ -1089,697 +781,406 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) { << global_work_size_[1] << "," << global_work_size_[2] << "}"; #endif - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - cl::NullRange, - nullptr, - event_); - CL_CHECK_FATAL(status); - if (is_turn) { + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + local_work_size_, + nullptr, + event_); + CL_CHECK_FATAL(status_); + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::Conv2d5x5opt(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int input_channel = input_dims[1]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int output_channel = output_dims[1]; - CHECK_EQ(input_dims[0], output_dims[0]); - int batch = input_dims[0]; - - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - -// default_work_size[2] = h_blk; +void ConvImageCompute::Conv2d5x5(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "============ conv2d params ============"; - // VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - // << input_image_shape["height"]; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << 
dilations[0] << ", " << dilations[1]; + PrintConvInfo(); #endif - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } - - auto kernel = kernel_; - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, paddings[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, batch); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - - // VLOG(4) << "out_image: " << out_image; + auto& context = ctx_->As(); - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_); - CL_CHECK_FATAL(status); - if (is_turn) { + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, offset_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, input_c_block_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status_); + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::Conv2d7x7(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* 
input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - +void ConvImageCompute::Conv2d5x5opt(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "============ conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; + PrintConvInfo(); #endif + auto& context = ctx_->As(); - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, pad_left_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, input_tensor_n_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = 
kernel_.setArg(15, output_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + local_work_size_, + nullptr, + event_); + CL_CHECK_FATAL(status_); + if (enable_tune) { + CLRuntime::Global()->command_queue().finish(); } +} - auto kernel = kernel_; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { +void ConvImageCompute::Conv2d7x7(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "set bias_image: "; + PrintConvInfo(); #endif - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); + auto& context = ctx_->As(); + + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, offset_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, input_c_block_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status_); + if (enable_tune) { + CLRuntime::Global()->command_queue().finish(); } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); +} +void ConvImageCompute::Conv2d7x7opt(bool enable_tune) { #ifdef LITE_WITH_LOG - // VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," - << global_work_size_[1] << "," << global_work_size_[2] << "}"; + PrintConvInfo(); #endif + auto& context = ctx_->As(); - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - cl::NullRange, - nullptr, - event_); - CL_CHECK_FATAL(status); - - if (is_turn) { + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); 
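+ // args 1-15 below mirror Conv2d5x5opt: block counts, the four images, stride, pad, dilation, then input/output tensor dims.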
+ status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, pad_left_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, input_tensor_n_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(15, output_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + local_work_size_, + nullptr, + event_); + CL_CHECK_FATAL(status_); + + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::Conv2d7x7opt(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int input_channel = input_dims[1]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int output_channel = output_dims[1]; - CHECK_EQ(input_dims[0], output_dims[0]); - int batch = input_dims[0]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); +void ConvImageCompute::DepthwiseConv2d3x3s1(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "============ conv2d 7x7 params ============"; - // VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - // << input_image_shape["height"]; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "input_dims: " << input_dims; - VLOG(4) << "filter_dims: " << filter_dims; - // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; + PrintConvInfo(); #endif - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == 
paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } + auto& context = ctx_->As(); - auto kernel = kernel_; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, paddings[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, batch); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_channel); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_); - CL_CHECK_FATAL(status); - - if (is_turn) { + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, pad_left_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, input_tensor_c_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + local_work_size_, + nullptr, + event_); + CL_CHECK_FATAL(status_); + + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - - auto* input_img = param.x->data(); - auto* filter_img = filter_gpu_image_->data(); - - const cl::Image2D* bias_img = nullptr; - if 
(param.bias) { - bias_img = bias_gpu_image_->data(); - } - - auto image_shape = InitImageDimInfoWith(output_dims); - - auto* output_img = param.output->mutable_data( - image_shape["width"], image_shape["height"]); - - auto kernel = kernel_; - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_img); - CL_CHECK_FATAL(status); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); +void ConvImageCompute::DepthwiseConv2d3x3(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "set bias_image: "; + PrintConvInfo(); #endif - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *output_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(strides[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(paddings[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(dilations[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[1])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); - CL_CHECK_FATAL(status); - - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - local_work_size_, - nullptr, - event_); - CL_CHECK_FATAL(status); - - if (is_turn) { + auto& context = ctx_->As(); + + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, offset_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, input_c_block_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status_); + + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != 
nullptr); - const auto& param = *param_.get_mutable(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - int offset = filter_dims[2] / 2 - paddings[0]; - int input_c_block = (x_dims[1] + 3) / 4; - - auto* input_img = param.x->data(); - auto* filter_img = filter_gpu_image_->data(); - - const cl::Image2D* bias_img = nullptr; - if (param.bias) { - bias_img = bias_gpu_image_->data(); - } - - auto image_shape = InitImageDimInfoWith(output_dims); - - auto* output_img = param.output->mutable_data( - image_shape["width"], image_shape["height"]); - - auto kernel = kernel_; - +void ConvImageCompute::DepthwiseConv2d(bool enable_tune) { #ifdef LITE_WITH_LOG - VLOG(4) << "setArg"; - VLOG(4) << "strides = " << strides[0]; - VLOG(4) << "offset = " << offset; - VLOG(4) << "dilations = " << dilations[0]; - VLOG(4) << "input_c_block = " << input_c_block; - VLOG(4) << "x_dims[3] = " << x_dims[3]; - VLOG(4) << "x_dims[2] = " << x_dims[2]; - VLOG(4) << "output_dims[3] = " << output_dims[3]; - VLOG(4) << "output_dims[2] = " << output_dims[2]; + PrintConvInfo(); #endif + auto& context = ctx_->As(); - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_img); - CL_CHECK_FATAL(status); - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); -#ifdef LITE_WITH_LOG - VLOG(4) << "set bias_image: "; -#endif - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *output_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(strides[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(offset)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(dilations[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(input_c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); - CL_CHECK_FATAL(status); - - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - cl::NullRange, - nullptr, - event_); - CL_CHECK_FATAL(status); - - if (is_turn) { + status_ = kernel_.setArg(0, c_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(1, w_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(2, nh_blk_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(3, *input_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(4, *filter_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(5, *bias_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(6, *output_image_p_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(7, 
stride_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(8, offset_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(9, input_c_block_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(10, dilation_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(11, input_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(12, input_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(13, output_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(14, output_tensor_h_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(15, filter_tensor_w_); + CL_CHECK_FATAL(status_); + status_ = kernel_.setArg(16, filter_tensor_h_); + CL_CHECK_FATAL(status_); + + status_ = EnqueueNDRangeKernel(context, + kernel_, + cl::NullRange, + global_work_size_, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status_); + + if (enable_tune) { CLRuntime::Global()->command_queue().finish(); } } -void ConvImageCompute::DepthwiseConv2d(bool is_turn) { - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = filter_gpu_image_->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); +void ConvImageCompute::Run() { (this->*impl_)(false); } - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; +void ConvImageCompute::PrintConvInfo() { + const bool is_element_wise_bias = + has_bias_ && conv_param_->output->dims() == conv_param_->bias->dims(); -#ifdef LITE_WITH_LOG - VLOG(4) << "============ depthwise conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - // VLOG(4) << "input_image: " << input_image; - VLOG(4) << "filter_dims: " << filter_dims; + VLOG(4) << "input_image_shape: " << input_image_w_ << "," << input_image_h_; + // VLOG(4) << "input_image: " << input_image_p_; + VLOG(4) << "input_dims: " << conv_param_->x->dims(); + VLOG(4) << "filter_dims: " << conv_param_->filter->dims(); // VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; + VLOG(4) << "output_dims: " << conv_param_->output->dims(); + VLOG(4) << "out_image_shape: " << output_image_w_ << ", " << output_image_h_; + VLOG(4) << "paddings: " << pad_left_ << "," << pad_up_; 
+ VLOG(4) << "has bias: " << has_bias_; VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; -#endif - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - // handle bias use buffer for channel wise , use image for element wise - const cl::Buffer* bias_buf = nullptr; - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = bias_gpu_image_->data(); - } - - auto kernel = kernel_; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh_blk_); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { -#ifdef LITE_WITH_LOG - VLOG(4) << "set bias_image: "; -#endif - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_height); - CL_CHECK_FATAL(status); - -#ifdef LITE_WITH_LOG + VLOG(4) << "strides: " << stride_h_ << "," << stride_w_; + VLOG(4) << "offset: "; + VLOG(4) << "dilations.size : " << conv_param_->dilations->size(); + VLOG(4) << "dilations: " << dilation_h_ << ", " << dilation_w_; VLOG(4) << "global_work_size_[3D]: {" << global_work_size_[0] << "," << global_work_size_[1] << "," << global_work_size_[2] << "}"; -#endif - - status = EnqueueNDRangeKernel(context, - kernel, - cl::NullRange, - global_work_size_, - cl::NullRange, - nullptr, - event_); - CL_CHECK_FATAL(status); } -void ConvImageCompute::Run() { (this->*impl_)(false); } - -double ConvImageCompute::Turn(int times) { +double ConvImageCompute::Tune(int times) { auto GetCurrentUS = []() -> double { struct timeval time; gettimeofday(&time, NULL); diff --git a/lite/kernels/opencl/conv_image_compute.h b/lite/kernels/opencl/conv_image_compute.h index 64276a5721cb20718604d91d3cfac31e583ddbf1..4eab7be1f1ac6459250c6df984160f0f6060ea1c 100644 --- a/lite/kernels/opencl/conv_image_compute.h +++ b/lite/kernels/opencl/conv_image_compute.h @@ -33,6 +33,7 @@ namespace paddle { namespace lite { namespace kernels { namespace opencl { + class ConvImageCompute : public KernelLite { @@ -42,8 +43,11 @@ class ConvImageCompute : public KernelLite kernel_func_names_{}; @@ -79,19 +87,72 @@ class ConvImageCompute 
: public KernelLite tensor_hold_bias_image_{nullptr}; cl::NDRange global_work_size_ = cl::NDRange{ static_cast(1), static_cast(1), static_cast(1)}; + + // opencl kernel args int c_blk_ = 1; int w_blk_ = 1; int nh_blk_ = 1; + const cl::Image2D* input_image_p_{nullptr}; + const cl::Image2D* filter_image_p_{nullptr}; + const cl::Image2D* bias_image_p_{nullptr}; + const cl::Image2D* output_image_p_{nullptr}; + + int stride_h_{-1}; + int stride_w_{-1}; + + int dilation_h_{-1}; + int dilation_w_{-1}; + + int pad_up_{-1}; + int pad_down_{-1}; + int pad_left_{-1}; + int pad_right_{-1}; + + int offset_{-1}; + int groups_{-1}; + bool relu_fused_{false}; + bool has_bias_{false}; + + int input_tensor_n_{-1}; + int input_tensor_c_{-1}; + int input_tensor_h_{-1}; + int input_tensor_w_{-1}; + int input_image_h_{-1}; + int input_image_w_{-1}; + int input_c_block_{-1}; + + int output_tensor_n_{-1}; + int output_tensor_c_{-1}; + int output_tensor_h_{-1}; + int output_tensor_w_{-1}; + int output_image_h_{-1}; + int output_image_w_{-1}; + + int filter_tensor_n_{-1}; + int filter_tensor_c_{-1}; + int filter_tensor_h_{-1}; + int filter_tensor_w_{-1}; + int filter_image_h_{-1}; + int filter_image_w_{-1}; + + int bias_image_h_{-1}; + int bias_image_w_{-1}; + int default_c_blk_ = 1; int default_w_blk_ = 1; int default_nh_blk_ = 1; + // ================= + + DDim last_input_dims_{}; + bool is_first_epoch_for_run_{true}; cl::Kernel kernel_; + cl_int status_; cl::NDRange local_work_size_ = cl::NDRange{ static_cast(1), static_cast(1), static_cast(1)}; bool use_lws_{true}; - bool use_turn_{false}; + bool use_tune_{false}; }; } // namespace opencl diff --git a/lite/kernels/rknpu/subgraph_compute.cc b/lite/kernels/rknpu/subgraph_compute.cc index e0b63205705609b6899918ce8e254ccdf6cbad47..a50505c38c0740f762256cd71e006caf9249838e 100644 --- a/lite/kernels/rknpu/subgraph_compute.cc +++ b/lite/kernels/rknpu/subgraph_compute.cc @@ -28,13 +28,36 @@ namespace lite { namespace kernels { namespace rknpu { -int SubgraphEngine::BuildDeviceProgram() { +bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() { + // Obtain the origin input tensors, and create the origin output + // tensors(Don't try to access them before launch the device program or the + // origin program) + PrepareWorkspaceForOriginProgram(); + // Create the device input and output tensors, but don't initialize them + // with the dimensions + device_itensors_.resize(input_names_.size()); + for (int i = 0; i < input_names_.size(); i++) { + device_itensors_[i].reset(new hiai::AiTensor); + CHECK(device_itensors_[i]); + } + device_otensors_.resize(output_names_.size()); + for (int i = 0; i < output_names_.size(); i++) { + device_otensors_[i].reset(new hiai::AiTensor); + CHECK(device_otensors_[i]); + } + return true; +} + +bool SubgraphEngine::BuildDeviceProgram() { LOG(INFO) << "[RKNPU]:BuildDeviceProgram"; int status = 0; // Convert all of ops and their input vars and weights and added into the NPU // RKNPU IR graph subgraph::rknpu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); + if (origin_program_.empty()) { + BuildOriginProgram(); + } for (auto& inst : origin_program_) { auto op = const_cast(inst.op()); CHECK(op); @@ -42,13 +65,13 @@ int SubgraphEngine::BuildDeviceProgram() { op->InferShape(); std::string op_type = op->op_info()->Type(); if (!bridges.Exists(op_type, TARGET(kRKNPU))) { - return subgraph::FAILED; + return false; } auto kernel = inst.kernel(); status |= bridges.Select(op_type, TARGET(kRKNPU))( reinterpret_cast(&graph), op, 
const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { - return subgraph::FAILED; + return false; } } // Collect the valid input and output nodes in the RKNPU IR graph and update @@ -91,7 +114,7 @@ int SubgraphEngine::BuildDeviceProgram() { model_name_, graph.GetHandle(), device_itensors_, device_otensors_); if (device_program_ == nullptr) { LOG(WARNING) << "[RKNPU] Build model failed!"; - return subgraph::FAILED; + return false; } // input @@ -165,10 +188,10 @@ int SubgraphEngine::BuildDeviceProgram() { break; } } - return status; + return true; } -int SubgraphEngine::LaunchDeviceProgram() { +bool SubgraphEngine::LaunchDeviceProgram() { LOG(INFO) << "[RKNPU]:LaunchDeviceProgram"; std::vector inputs; std::vector outputs; @@ -195,7 +218,7 @@ int SubgraphEngine::LaunchDeviceProgram() { device_program_->SetInputs(inputs); device_program_->Run(); device_program_->GetOutputs(outputs); - return 0; + return true; } void SubgraphCompute::PrepareForRun() { @@ -208,13 +231,12 @@ void SubgraphCompute::PrepareForRun() { param.output_data_names, param.scope)); CHECK(engine_); - engine_->Build(); } void SubgraphCompute::Run() { LOG(INFO) << "[RKNPU]:Run"; CHECK(engine_); - engine_->Launch(); + engine_->Run(); } } // namespace rknpu diff --git a/lite/kernels/rknpu/subgraph_compute.h b/lite/kernels/rknpu/subgraph_compute.h index 863e6aef39ad54f0e9d94d4b507c6fca4128ebb8..a4bdadc658a81decd8107072f7b5948613d0c68a 100644 --- a/lite/kernels/rknpu/subgraph_compute.h +++ b/lite/kernels/rknpu/subgraph_compute.h @@ -42,14 +42,15 @@ class SubgraphEngine : public subgraph::Engine { ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: - int BuildDeviceProgram() override; - int LaunchDeviceProgram() override; + bool PrepareWorkspaceForDeviceProgram() override; + bool BuildDeviceProgram() override; + bool LaunchDeviceProgram() override; std::string model_name_; std::vector device_inames_; std::vector device_onames_; - std::vector> device_itensors_; - std::vector> device_otensors_; + std::vector> device_itensors_{}; + std::vector> device_otensors_{}; std::unique_ptr device_program_{nullptr}; }; diff --git a/lite/kernels/x86/slice_compute.h b/lite/kernels/x86/slice_compute.h index ad30215691cde66ab1c7c8c57930fc6d58de7cd5..d32327668bac389e42ff9411be50ce3df42e39ff 100644 --- a/lite/kernels/x86/slice_compute.h +++ b/lite/kernels/x86/slice_compute.h @@ -157,7 +157,7 @@ void slice_compute(const lite::Tensor* in, } } - out->mutable_data(lite::TargetType::kX86); + out->mutable_data(); auto new_out_dims = out->dims(); auto offsets = Eigen::array(); diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt index 7ded008387b7d7c92fb2ce6b18e73e1c1e51f29d..fdb485df02f366f7f4868965b1f20c6861b03d43 100644 --- a/lite/kernels/xpu/CMakeLists.txt +++ b/lite/kernels/xpu/CMakeLists.txt @@ -6,6 +6,7 @@ if(LITE_WITH_XTCL) add_subdirectory(bridges) add_kernel(subgraph_compute_xpu XPU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} device_xpu subgraph_bridge_engine ${xpu_subgraph_bridges}) else() + # basic add_kernel(conv_compute_xpu XPU basic SRCS conv_compute.cc DEPS ${lite_kernel_deps}) add_kernel(io_copy_compute_xpu XPU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} target_wrapper_xpu) add_kernel(batch_norm_compute_xpu XPU basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) @@ -15,15 +16,32 @@ else() add_kernel(mul_compute_xpu XPU basic SRCS mul_compute.cc DEPS ${lite_kernel_deps}) add_kernel(softmax_compute_xpu XPU basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps}) 
add_kernel(scale_compute_xpu XPU basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) - add_kernel(lookup_table_compute_xpu XPU basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) - add_kernel(layer_norm_compute_xpu XPU basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps}) add_kernel(dropout_compute_xpu XPU basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps}) add_kernel(matmul_compute_xpu XPU basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps}) add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc DEPS ${lite_kernel_deps}) add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc DEPS ${lite_kernel_deps}) add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(sequence_topk_avg_pooling_compute_xpu XPU basic SRCS sequence_topk_avg_pooling_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(concat_compute_xpu XPU basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(search_fc_compute_xpu XPU basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps}) + + # extra + add_kernel(lookup_table_compute_xpu XPU extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(layer_norm_compute_xpu XPU extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(sequence_reverse_compute_xpu XPU extra SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(sequence_concat_compute_xpu XPU extra SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(sequence_arithmetic_compute_xpu XPU extra SRCS sequence_arithmetic_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(sequence_pool_compute_xpu XPU extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(match_matrix_tensor_compute_xpu XPU extra SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps}) + + # extra(fused kernel) add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(__xpu__resnet_cbam_compute_xpu XPU extra SRCS __xpu__resnet_cbam_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps}) add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(__xpu__search_attention_compute_xpu XPU extra SRCS __xpu__search_attention_compute.cc DEPS ${lite_kernel_deps}) + add_kernel(__xpu__mmdnn_compute_xpu XPU extra SRCS __xpu__mmdnn_compute.cc DEPS ${lite_kernel_deps}) endif() diff --git a/lite/kernels/xpu/__xpu__mmdnn_compute.cc b/lite/kernels/xpu/__xpu__mmdnn_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..39ddecb1139073cb1a0bd8e3c7afc89f1d739da8 --- /dev/null +++ b/lite/kernels/xpu/__xpu__mmdnn_compute.cc @@ -0,0 +1,1386 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +namespace { + +void FillMax(float max, float* xpu_ptr) { + float maxs[4] = {max, 0.0f, 0.0f, 0.0f}; + xpu_memcpy( + xpu_ptr, maxs, 4 * sizeof(float), XPUMemcpyKind::XPU_HOST_TO_DEVICE); +} + +void GrnnLayout(int batch, + const std::vector& offset, + std::vector* new_offset_ptr, + std::vector* idx_sorted_ptr) { + auto& new_offset = *new_offset_ptr; + auto& idx_sorted = *idx_sorted_ptr; + + std::vector width; + width.resize(batch); + new_offset.clear(); + idx_sorted.clear(); + + idx_sorted.resize(batch); + for (int i = 0; i < batch; i++) { + width[i] = offset[i + 1] - offset[i]; + idx_sorted[i] = i; + } + std::sort(idx_sorted.data(), + idx_sorted.data() + batch, + [&width](int a, int b) { return width[a] > width[b]; }); + int max_width = width[idx_sorted[0]]; + new_offset.resize(max_width + 1); + new_offset[0] = 0; + int j = batch - 1; + int last_width = 0; + int sub_row = 0; + int sub_col = 0; + + for (int i = 1; i <= max_width;) { + for (int k = j; k >= 0; --k) { + if (width[idx_sorted[k]] > last_width) { + sub_row = width[idx_sorted[k]] - last_width; + sub_col = k + 1; + for (int s = 0; s < sub_row; s++) { + new_offset[i] = new_offset[i - 1] + sub_col; + i++; + } + // move on + last_width = width[idx_sorted[k]]; + j = k - 1; + break; + } + } + } +} + +} // anonymous namespace + +class MMDNNIdInfo { + XPUScratchPadGuard l3_buffer_guard_; + char* l3_buffer_{nullptr}; + std::unique_ptr cpu_buffer_guard_; + char* cpu_buffer_{nullptr}; + + public: + const int64_t* id0_64{nullptr}; + const int64_t* id1_64{nullptr}; + int64_t* lod_64{nullptr}; + int* lod_32{nullptr}; + int* new_offset_32{nullptr}; + int* idx_sorted_32{nullptr}; + + std::vector lod; + std::vector new_offset; + std::vector idx_sorted; + int batch; + int seqlen_max; + int seqlen_sum; + int seqlen_square_sum; + + void Init(int upper_bound_batch, int upper_bound_seqlen) { + int ub_lod_64_size = (upper_bound_batch + 1) * sizeof(int64_t); + int ub_lod_32_size = (upper_bound_batch + 1) * sizeof(int); + int ub_new_offset_32_size = (upper_bound_seqlen + 1) * sizeof(int); + int ub_idx_sorted_32_size = (upper_bound_batch + 1) * sizeof(int); + int total_size = ub_lod_64_size + ub_lod_32_size + ub_new_offset_32_size + + ub_idx_sorted_32_size; + + // TODO(miaotianxiang): use l3? 
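+    // Worked example (illustrative) of the layout GrnnLayout() computes
+    // for Update() below: with lod offsets {0, 3, 5, 6}, i.e. sequences
+    // of lengths 3, 2 and 1, it yields idx_sorted = {0, 1, 2} and
+    // new_offset = {0, 3, 5, 6}:
+    //   step 0 -> rows [0, 3)  (all three sequences alive)
+    //   step 1 -> rows [3, 5)  (the length-3 and length-2 sequences)
+    //   step 2 -> rows [5, 6)  (the length-3 sequence only)
+    // Update() then packs lod_64 | lod_32 | new_offset_32 | idx_sorted_32
+    // into this single scratch pad and uploads it with one xpu_memcpy.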
+ l3_buffer_guard_ = TargetWrapperXPU::MallocScratchPad(total_size, false); + l3_buffer_ = reinterpret_cast(l3_buffer_guard_->addr_); + cpu_buffer_guard_.reset(new char[total_size]); + cpu_buffer_ = cpu_buffer_guard_.get(); + } + + void Update(lite::Tensor* id0, lite::Tensor* id1) { + auto& id0_lod = id0->lod()[0]; + lod.clear(); + for (auto e : id0_lod) { + lod.push_back(e); + } + + seqlen_max = 0; + seqlen_sum = 0; + seqlen_square_sum = 0; + batch = lod.size() - 1; + for (int i = 0; i < batch; i++) { + int seqlen = lod[i + 1] - lod[i]; + seqlen_max = std::max(seqlen_max, seqlen); + seqlen_sum = seqlen_sum + seqlen; + seqlen_square_sum = seqlen_square_sum + seqlen * seqlen; + } + GrnnLayout(batch, lod, &new_offset, &idx_sorted); + + id0_64 = id0->data(); + id1_64 = id1->data(); + + int offset = 0; + lod_64 = reinterpret_cast(l3_buffer_ + offset); + memcpy( + cpu_buffer_ + offset, id0_lod.data(), id0_lod.size() * sizeof(int64_t)); + offset += id0_lod.size() * sizeof(int64_t); + lod_32 = reinterpret_cast(l3_buffer_ + offset); + memcpy(cpu_buffer_ + offset, lod.data(), lod.size() * sizeof(int)); + offset += lod.size() * sizeof(int); + new_offset_32 = reinterpret_cast(l3_buffer_ + offset); + memcpy(cpu_buffer_ + offset, + new_offset.data(), + new_offset.size() * sizeof(int)); + offset += new_offset.size() * sizeof(int); + idx_sorted_32 = reinterpret_cast(l3_buffer_ + offset); + memcpy(cpu_buffer_ + offset, + idx_sorted.data(), + idx_sorted.size() * sizeof(int)); + offset += idx_sorted.size() * sizeof(int); + xpu_memcpy( + l3_buffer_, cpu_buffer_, offset, XPUMemcpyKind::XPU_HOST_TO_DEVICE); + } +}; + +class MMDNNFcOp { + const int16_t* weight_{nullptr}; + XPUScratchPadGuard weight_max_guard_; + float* weight_max_{nullptr}; + const float* bias_{nullptr}; + XPUScratchPadGuard in_max_guard_; + float* in_max_{nullptr}; + int n_; + int k_; + xdnn::Activation_t::act_enum act_type_; + XPUScratchPadGuard out_max_guard_; + + public: + float* out_max{nullptr}; + + void Init(const int16_t* weight, + float weight_max, + const float* bias, + int n, + int k, + xdnn::Activation_t::act_enum act_type) { + n_ = n; + k_ = k; + act_type_ = act_type; + + weight_ = weight; + weight_max_guard_ = + TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false); + weight_max_ = reinterpret_cast(weight_max_guard_->addr_); + FillMax(weight_max, weight_max_); + + bias_ = bias; + + in_max_guard_ = + TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false); + out_max_guard_ = + TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false); + in_max_ = reinterpret_cast(in_max_guard_->addr_); + out_max = reinterpret_cast(in_max_guard_->addr_); + } + + void Init(lite::Tensor* weight, + float weight_max, + lite::Tensor* bias, + int n, + int k, + xdnn::Activation_t::act_enum act_type) { + Init(weight->data(), + weight_max, + bias ? 
bias->data() : nullptr, + n, + k, + act_type); + } + + void Infer(xdnn::Context* ctx, + const float* in, + int m, + float* out, + const float* in_max_by_caller = nullptr) { + if (in_max_by_caller == nullptr) { + xdnn::findmax(ctx, in, m * k_, in_max_); + in_max_by_caller = in_max_; + } + xdnn::gemm_int16_maxptr(ctx, + false, + true, + m, + n_, + k_, + 1.0f, + in, + k_, + weight_, + k_, + 0.0f, + out, + n_, + bias_, + act_type_, + in_max_by_caller, + weight_max_, + out_max); + } +}; + +class MMDNNGrnnOp { + MMDNNFcOp fc_e2h0_; + MMDNNFcOp fc_e2h1_; + MMDNNFcOp fc_e2h2_; + const int16_t* dense_h2h_{nullptr}; + float dense_h2h_max_[3]; + XPUScratchPadGuard input_max_guard_; + float* input_max_{nullptr}; + XPUScratchPadGuard hbm_buffer_guard_; + float* hbm_buffer_{nullptr}; + // require: cap_l * max(cap_e_, cap_h_) * 5 + // seq2batch_out: [cap_l, cap_e_] + // fc_e2h_out: [3, cap_l, cap_h_] + // gru_out: [cap_l, cap_h_] + int cap_e_; + int cap_h_; + int max_cap_l_; + + public: + void Init(lite::Tensor* wh, + const std::vector& wh_maxs, + lite::Tensor* wi, + const std::vector& wi_maxs, + int cap_e, + int cap_h, + int max_cap_l) { + cap_e_ = cap_e; + cap_h_ = cap_h; + max_cap_l_ = max_cap_l; + + // weight + auto* dense_e2h = wi->data(); + fc_e2h0_.Init(dense_e2h, + wi_maxs[0], + nullptr, + cap_h_, + cap_e_, + xdnn::Activation_t::LINEAR); + fc_e2h1_.Init(dense_e2h + cap_e_ * cap_h_, + wi_maxs[1], + nullptr, + cap_h_, + cap_e_, + xdnn::Activation_t::LINEAR); + fc_e2h2_.Init(dense_e2h + cap_e_ * cap_h_ * 2, + wi_maxs[2], + nullptr, + cap_h_, + cap_e_, + xdnn::Activation_t::LINEAR); + + dense_h2h_ = wh->data(); + dense_h2h_max_[0] = wh_maxs[0]; + dense_h2h_max_[1] = wh_maxs[1]; + dense_h2h_max_[2] = wh_maxs[2]; + + input_max_guard_ = + TargetWrapperXPU::MallocScratchPad(4 * sizeof(float), false); + input_max_ = reinterpret_cast(input_max_guard_->addr_); + hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( + 5 * std::max(cap_e_, cap_h_) * max_cap_l_ * sizeof(float), false); + hbm_buffer_ = reinterpret_cast(hbm_buffer_guard_->addr_); + } + + void Infer(xdnn::Context* ctx, + const MMDNNIdInfo& sentense, + const float* in, + float* out, + float* l3_buffer = nullptr, + int l3_size = 0) { + int batch = sentense.batch; + int cap_l = sentense.seqlen_sum; + int max_width = sentense.seqlen_max; + + int slot_size = cap_l * std::max(cap_e_, cap_h_); + float* seq2batch_out = hbm_buffer_; + float* fc_e2h_out = hbm_buffer_ + 1 * slot_size; + float* gru_out = hbm_buffer_ + 4 * slot_size; + if (l3_size > 0 && l3_size >= 5 * slot_size * sizeof(float)) { + seq2batch_out = l3_buffer; + fc_e2h_out = l3_buffer + 1 * slot_size; + gru_out = l3_buffer + 4 * slot_size; + } + + xdnn::search_seq2batch(ctx, + batch, + max_width, + cap_e_, + sentense.idx_sorted_32, + sentense.lod_32, + sentense.new_offset_32, + in, + seq2batch_out); + + xdnn::findmax(ctx, in, cap_l * cap_e_, input_max_); + fc_e2h0_.Infer(ctx, seq2batch_out, cap_l, fc_e2h_out, input_max_); + fc_e2h1_.Infer( + ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_, input_max_); + fc_e2h2_.Infer( + ctx, seq2batch_out, cap_l, fc_e2h_out + cap_l * cap_h_ * 2, input_max_); + xdnn::search_grnn(ctx, + cap_l, + cap_h_, + cap_e_, + max_width, + sentense.new_offset_32, + fc_e2h_out, + dense_h2h_, + gru_out, + dense_h2h_max_[0], + dense_h2h_max_[1], + dense_h2h_max_[2]); + + xdnn::search_batch2seq(ctx, + batch, + max_width, + cap_h_, + sentense.idx_sorted_32, + sentense.lod_32, + sentense.new_offset_32, + gru_out, + out); + } +}; + +class MMDNNAttentionOp { + int dim_; 
+ float alpha0_; + float alpha1_; + MMDNNFcOp seqfc_; + XPUScratchPadGuard hbm_buffer_guard_; + float* hbm_buffer_{nullptr}; + // require: cap_l * dim_ + seqlen_square_sum + // seqfc_out: [cap_l, dim_] + // batchgemm0_out: [seqlen_square_sum] + // seq_softmax_out: [seqlen_square_sum], reuse of batchgemm0_out + // batchgemm1_out: [cap_l, dim_], reuse of seqfc_out + + public: + void Init(lite::Tensor* att_fc_w, + float att_fc_w_max, + lite::Tensor* att_fc_b, + int dim, + int upper_bound_batch, + int upper_bound_seqlen) { + dim_ = dim; + alpha0_ = 0.0883883461356163f; // TODO(miaotianxiang): + alpha1_ = 1.0f; + + seqfc_.Init(att_fc_w, + att_fc_w_max, + att_fc_b, + dim_, + dim_, + xdnn::Activation_t::LINEAR); + hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch * (upper_bound_seqlen * dim_ + + upper_bound_seqlen * upper_bound_seqlen)) * + sizeof(float), + false); + hbm_buffer_ = reinterpret_cast(hbm_buffer_guard_->addr_); + } + + void Infer(xdnn::Context* ctx, + const MMDNNIdInfo& sentense, + const float* input, + float* pool_out, + float* l3_buffer = nullptr, + int l3_size = 0) { + int batch = sentense.batch; + int cap_l = sentense.seqlen_sum; + int max_width = sentense.seqlen_max; + int* lod_32 = sentense.lod_32; + + float* seqfc_out = hbm_buffer_; + float* batchgemm0_out = hbm_buffer_ + cap_l * dim_; + float* seq_softmax_out = batchgemm0_out; + float* batchgemm1_out = seqfc_out; + if (l3_size > 0 && + l3_size >= + (cap_l * dim_ + sentense.seqlen_square_sum) * sizeof(float)) { + seqfc_out = l3_buffer; + batchgemm0_out = l3_buffer + cap_l * dim_; + seq_softmax_out = batchgemm0_out; + batchgemm1_out = seqfc_out; + } + + seqfc_.Infer(ctx, input, cap_l, seqfc_out); + xdnn::search_noaligned_mat_mul(ctx, + 0, + 1, + batch, + lod_32, + max_width, + dim_, + alpha0_, + input, + seqfc_out, + batchgemm0_out); + xdnn::search_seq_softmax( + ctx, batchgemm0_out, seq_softmax_out, lod_32, batch, max_width); + xdnn::search_noaligned_mat_mul(ctx, + 0, + 0, + batch, + lod_32, + max_width, + dim_, + alpha1_, + seq_softmax_out, + input, + batchgemm1_out); + xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::MAX_WITHOUT_INDEX, + batch, + lod_32, + dim_, + batchgemm1_out, + nullptr, + pool_out); + } +}; + +class MMDNNMatchConvTopk { + std::vector topks_; + int dim_t_; + int dim_in_; + int out_channel_; + + MMDNNFcOp xw_fc_; + const int16_t* conv_weight_{nullptr}; + float conv_weight_max_; + XPUScratchPadGuard hbm_buffer_guard_; + float* hbm_buffer_{nullptr}; + // xw_out: [sum(left_len), dim_t_ * dim_in_] + // xwy_out: [sum(left_len * right_len) * dim_t_] + // conv_out: [sum(left_len * right_len) * out_channel_] + // seq_concat_out: [sum(left_len * right_len) * (dim_t_ + out_channel_)] + + XPUScratchPadGuard left_lod_32_guard_; + int* left_lod_32_{nullptr}; + XPUScratchPadGuard right_lod_32_guard_; + int* right_lod_32_{nullptr}; + XPUScratchPadGuard match_lod_32_guard_; + int* match_lod_32_{nullptr}; + XPUScratchPadGuard conv_lod_32_guard_; + int* conv_lod_32_{nullptr}; + XPUScratchPadGuard topk_offset_32_guard_; + int* topk_offset_32_{nullptr}; + XPUScratchPadGuard topks_xpu_guard_; + int* topks_xpu_{nullptr}; + XPUScratchPadGuard useless_topk_pos_guard_; + int* useless_topk_pos_{nullptr}; + + public: + float* seq_avg_topk_out{nullptr}; + + void Init(lite::Tensor* input_w, + float input_w_max, + lite::Tensor* conv_w, + float conv_w_max, + int dim_t, + int dim_in, + int upper_bound_batch, + int upper_bound_seqlen, + const std::vector& topks) { + dim_t_ = dim_t; + dim_in_ = dim_in; + 
out_channel_ = 5; // TODO(miaotianxiang): + topks_ = topks; + + xw_fc_.Init(input_w, + input_w_max, + nullptr, + dim_t_ * dim_in_, + dim_in_, + xdnn::Activation_t::LINEAR); + conv_weight_ = conv_w->data(); + conv_weight_max_ = conv_w_max; + + hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch * upper_bound_seqlen * dim_t_ * dim_in_ + + upper_bound_batch * upper_bound_seqlen * upper_bound_seqlen * + (dim_t_ + out_channel_) * 2) * + sizeof(float), + false); + hbm_buffer_ = reinterpret_cast(hbm_buffer_guard_->addr_); + + left_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch + 1) * sizeof(int), false); + left_lod_32_ = reinterpret_cast(left_lod_32_guard_->addr_); + right_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch + 1) * sizeof(int), false); + right_lod_32_ = reinterpret_cast(right_lod_32_guard_->addr_); + match_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch + 1) * sizeof(int), false); + match_lod_32_ = reinterpret_cast(match_lod_32_guard_->addr_); + conv_lod_32_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch + 1) * sizeof(int), false); + conv_lod_32_ = reinterpret_cast(conv_lod_32_guard_->addr_); + topk_offset_32_guard_ = TargetWrapperXPU::MallocScratchPad( + (upper_bound_batch + 1) * sizeof(int), false); + topk_offset_32_ = reinterpret_cast(topk_offset_32_guard_->addr_); + topks_xpu_guard_ = + TargetWrapperXPU::MallocScratchPad(topks_.size() * sizeof(int), false); + topks_xpu_ = reinterpret_cast(topks_xpu_guard_->addr_); + xpu_memcpy(topks_xpu_, + topks_.data(), + topks_.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + useless_topk_pos_guard_ = + TargetWrapperXPU::MallocScratchPad(4 * sizeof(int), false); + useless_topk_pos_ = reinterpret_cast(useless_topk_pos_guard_->addr_); + } + + void Infer(xdnn::Context* ctx, + lite::Tensor* left, + lite::Tensor* right, + lite::Tensor* out, + float* l3_buffer = nullptr, + int l3_size = 0) { + auto left_lod = left->lod()[0]; + auto right_lod = right->lod()[0]; + int batch = left_lod.size() - 1; + + std::vector left_lod_32_cpu; + for (auto e : left_lod) { + left_lod_32_cpu.push_back(e); + } + xpu_memcpy(left_lod_32_, + left_lod_32_cpu.data(), + left_lod_32_cpu.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + std::vector right_lod_32_cpu; + for (auto e : right_lod) { + right_lod_32_cpu.push_back(e); + } + xpu_memcpy(right_lod_32_, + right_lod_32_cpu.data(), + right_lod_32_cpu.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + std::vector lod_match = {0}; + std::vector lod_conv = {0}; + std::vector lod_topk = {0}; + int x_mul_y_sum = 0; + int left_seqlen_sum = 0; + int left_seqlen_max = 0; + int right_seqlen_sum = 0; + int right_seqlen_max = 0; + for (int i = 0; i < batch; i++) { + int len_x = left_lod[i + 1] - left_lod[i]; + int len_y = right_lod[i + 1] - right_lod[i]; + int imgsize = len_x * len_y; + x_mul_y_sum = x_mul_y_sum + imgsize; + lod_match.push_back(lod_match.back() + imgsize * dim_t_); + lod_conv.push_back(lod_conv.back() + imgsize * out_channel_); + lod_topk.push_back(lod_topk.back() + imgsize * (dim_t_ + out_channel_)); + + left_seqlen_max = std::max(left_seqlen_max, len_x); + right_seqlen_max = std::max(right_seqlen_max, len_y); + left_seqlen_sum += len_x; + right_seqlen_sum += len_y; + } + xpu_memcpy(match_lod_32_, + lod_match.data(), + lod_match.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(conv_lod_32_, + lod_conv.data(), + lod_conv.size() * sizeof(int), 
+ XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(topk_offset_32_, + lod_topk.data(), + lod_topk.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + float* xwy_out = hbm_buffer_; + float* conv_out = hbm_buffer_ + x_mul_y_sum * dim_t_; + float* seq_concat_out = hbm_buffer_ + x_mul_y_sum * (dim_t_ + out_channel_); + float* xw_out = hbm_buffer_ + x_mul_y_sum * (dim_t_ + out_channel_) * 2; + int total_len = x_mul_y_sum * (dim_t_ + out_channel_) * 2 + + left_seqlen_sum * dim_t_ * dim_in_; + if (l3_size > 0 && l3_size >= total_len * sizeof(float)) { + xwy_out = l3_buffer; + conv_out = l3_buffer + x_mul_y_sum * dim_t_; + seq_concat_out = l3_buffer + x_mul_y_sum * (dim_t_ + out_channel_); + xw_out = l3_buffer + x_mul_y_sum * (dim_t_ + out_channel_) * 2; + } + seq_avg_topk_out = out->mutable_data(TARGET(kXPU)); + + int max_width = std::max(left_seqlen_max, right_seqlen_max); + xw_fc_.Infer(ctx, left->data(), left_seqlen_sum, xw_out); + xdnn::match_matrix_tensor(ctx, + batch, + xw_out, + right->data(), + left_lod_32_, + right_lod_32_, + dim_t_, + dim_in_, + xwy_out, + xw_fc_.out_max, + xdnn::Activation_t::RELU, + max_width); + xdnn::search_varconv( + ctx, + batch, + dim_t_, + out_channel_, + 5, + 5, + 1, + 1, + xwy_out, + conv_weight_, + right_lod_32_, + left_lod_32_, + conv_out, + conv_weight_max_, + xdnn::Activation_t::RELU); // TODO(miaotianxiang): + xdnn::sequence_concat(ctx, + xwy_out, + match_lod_32_, + conv_out, + conv_lod_32_, + seq_concat_out, + batch); + xdnn::sequence_topk_avg_pooling(ctx, + seq_concat_out, + seq_avg_topk_out, + useless_topk_pos_, + batch, + dim_t_ + out_channel_, + topk_offset_32_, + left_lod_32_, + right_lod_32_, + topks_xpu_, + topks_.size()); + } +}; + +class MMDNNBidEmbGrnnAtt { + const float* table_{nullptr}; + int table_len_; + int emb_dim_; + int cap_h_; + MMDNNGrnnOp bi_fw_; + MMDNNGrnnOp bi_rv_; + MMDNNAttentionOp att_; + XPUScratchPadGuard hbm_buffer_guard_; + float* hbm_buffer_{nullptr}; + // require at least: 4 * cap_l * emb_dim_ + // emb_rv: [cap_l, emb_dim_] + // grnn_fw: [cap_l, emb_dim_] + // grnn_rv: [cap_l, emb_dim_] + // grnn_rv_rv: [cap_l, emb_dim_] + // concat_2in: [cap_l, 2 * emb_dim_] + // L3.bi_fw: 5 * cap_l * emb_dim_ + // L3.bi_rv: 5 * cap_l * emb_dim_ + // L3.att: cap_l * 2 * emb_dim_ + seqlen_square_sum + + // execution-plan: + // 1. bid_emb_ew, alloc(emb_rv) + // 2. bi_rv, alloc(grnn_rv) + // 3. free(emb_rv) + // 4. sequence_reverse, alloc(grnn_rv_rv) + // 5. sequence_pooling(grnn_rv) + // 6. free(grnn_rv) + // 7. bi_fw alloc(grnn_fw) + // 8. sequence_pooling(grnn_fw) + // 9. concat_2 alloc(concat_2in) + // 10. concat_3 + // 11. 
att + + // alloc-plan: + // [0]: emb_rv, grnn_rv_rv + // [1]: grnn_rv, grnn_fw + // [2, 3]: concat_2in + // [2, 3, 4, 5, 6]: L3.bi_fw, L3.bi_rv + // [4, 5, ..., ?]: L3.att + + public: + float* emb_fw{nullptr}; + float* concat_3in{nullptr}; + float* pool_fw{nullptr}; + float* pool_rv{nullptr}; + float* att_out{nullptr}; + + void Init(lite::Tensor* table, + lite::Tensor* fw_wh, + const std::vector& fw_wh_maxs, + lite::Tensor* fw_wi, + const std::vector& fw_wi_maxs, + lite::Tensor* rv_wh, + const std::vector& rv_wh_maxs, + lite::Tensor* rv_wi, + const std::vector& rv_wi_maxs, + lite::Tensor* att_fc_w, + float att_fc_w_max, + lite::Tensor* att_fc_b, + int upper_bound_batch, + int upper_bound_seqlen) { + table_ = table->data(); + table_len_ = table->dims()[0]; + emb_dim_ = table->dims()[1]; + cap_h_ = emb_dim_; + int max_cap_l = upper_bound_batch * upper_bound_seqlen; + + bi_fw_.Init( + fw_wh, fw_wh_maxs, fw_wi, fw_wi_maxs, emb_dim_, cap_h_, max_cap_l); + bi_rv_.Init( + rv_wh, rv_wh_maxs, rv_wi, rv_wi_maxs, emb_dim_, cap_h_, max_cap_l); + att_.Init(att_fc_w, + att_fc_w_max, + att_fc_b, + 2 * cap_h_, + upper_bound_batch, + upper_bound_seqlen); + + hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( + 4 * max_cap_l * cap_h_ * sizeof(float), false); + hbm_buffer_ = reinterpret_cast(hbm_buffer_guard_->addr_); + } + + void Infer(xdnn::Context* ctx, + int batch, + const MMDNNIdInfo& sentense, + lite::Tensor* grnn_fw_pool_out, + lite::Tensor* grnn_rv_pool_out, + lite::Tensor* att_pool_out, + lite::Tensor* concat_3in1_out, + lite::Tensor* emb_fw_out, + float* l3_buffer = nullptr, + int l3_size = 0) { + int cap_l = sentense.seqlen_sum; + int slot_len = cap_l * cap_h_; + + float* emb_rv = hbm_buffer_; + float* grnn_fw = hbm_buffer_ + slot_len; + float* grnn_rv = hbm_buffer_ + slot_len; + float* grnn_rv_rv = hbm_buffer_; + float* concat_2in = hbm_buffer_ + 2 * slot_len; + if (l3_size > 0 && l3_size >= 4 * slot_len * sizeof(float)) { + emb_rv = l3_buffer; + grnn_fw = l3_buffer + slot_len; + grnn_rv = l3_buffer + slot_len; + grnn_rv_rv = l3_buffer; + } + emb_fw = emb_fw_out->mutable_data(TARGET(kXPU)); + concat_3in = concat_3in1_out->mutable_data(TARGET(kXPU)); + pool_fw = grnn_fw_pool_out->mutable_data(TARGET(kXPU)); + pool_rv = grnn_rv_pool_out->mutable_data(TARGET(kXPU)); + att_out = att_pool_out->mutable_data(TARGET(kXPU)); + + xdnn::search_bid_emb_ew(ctx, + batch, + sentense.lod_64, + sentense.id0_64, + sentense.id1_64, + table_, + table_len_, + emb_dim_, + emb_fw, + emb_rv, + table_len_ - 2, + 1); + bi_rv_.Infer(ctx, + sentense, + emb_rv, + grnn_rv, + l3_buffer + 2 * slot_len, + l3_size - 2 * slot_len * sizeof(float)); + xdnn::sequence_reverse( + ctx, batch, sentense.lod_32, cap_h_, grnn_rv, grnn_rv_rv); + xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_rv, + nullptr, + pool_rv); + + bi_fw_.Infer(ctx, + sentense, + emb_fw, + grnn_fw, + l3_buffer + 2 * slot_len, + l3_size - 2 * slot_len * sizeof(float)); + xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_fw, + nullptr, + pool_fw); + const int concat_widths[] = {cap_h_, cap_h_, cap_h_}; + const float* concat_ptrs[] = {emb_fw, grnn_fw, grnn_rv_rv}; + xdnn::concat( + ctx, cap_l, concat_widths + 1, 2, concat_ptrs + 1, concat_2in); + xdnn::concat(ctx, cap_l, concat_widths, 3, concat_ptrs, concat_3in); + att_.Infer(ctx, + sentense, + concat_2in, + att_out, + l3_buffer + 4 * slot_len, + l3_size - 4 * slot_len * sizeof(float)); + } +}; + 
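Every MMDNN op above follows the same scratch-memory idiom: `Init()` reserves a worst-case HBM pad, and `Infer()` switches to the caller's L3 slice only when it can hold the whole working set, then carves sub-buffers at fixed offsets per the alloc-plan. A minimal sketch of the selection step, with a hypothetical helper name:

```cpp
#include <cstddef>

// Prefer the shared on-chip L3 workspace when it fits; otherwise fall
// back to the worst-case HBM pad reserved once in Init(). Units follow
// the Infer() signatures above: l3_size is in bytes, needed in floats.
inline float* PickScratch(float* hbm_buffer,
                          float* l3_buffer,
                          int l3_size,
                          size_t needed_floats) {
  if (l3_size > 0 &&
      static_cast<size_t>(l3_size) >= needed_floats * sizeof(float)) {
    return l3_buffer;
  }
  return hbm_buffer;
}
```

Slots are then recycled once their last reader has run, which is why `grnn_fw` and `grnn_rv` can share one offset and `grnn_rv_rv` can reuse `emb_rv`'s slot in the plan above.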
+class MMDNNEmbAtt { + const float* table_{nullptr}; + int table_len_; + int emb_dim_; + MMDNNAttentionOp att_; + + public: + float* emb_fw{nullptr}; + float* att_out{nullptr}; + + void Init(lite::Tensor* table, + lite::Tensor* att_fc_w, + float att_fc_w_max, + lite::Tensor* att_fc_b, + int upper_bound_batch, + int upper_bound_seqlen) { + table_ = table->data(); + table_len_ = table->dims()[0]; + emb_dim_ = table->dims()[1]; + att_.Init(att_fc_w, + att_fc_w_max, + att_fc_b, + emb_dim_, + upper_bound_batch, + upper_bound_seqlen); + } + + void Infer(xdnn::Context* ctx, + int batch, + const MMDNNIdInfo& sentense, + lite::Tensor* att_pool_out, + lite::Tensor* emb_fw_out, + float* l3_buffer = nullptr, + int l3_size = 0) { + emb_fw = emb_fw_out->mutable_data(TARGET(kXPU)); + att_out = att_pool_out->mutable_data(TARGET(kXPU)); + + int cap_l = sentense.lod.back(); + const float* emb_tables[] = {table_, table_}; + const int64_t* emb_indices[] = {sentense.id0_64, sentense.id1_64}; + xdnn::embedding_with_ewadd(ctx, + emb_dim_, + cap_l, + 2, + table_len_ - 2, + emb_tables, + emb_indices, + nullptr, + nullptr, + emb_fw); + att_.Infer(ctx, sentense, emb_fw, att_out, l3_buffer, l3_size); + } +}; + +class MMDNNMergeAll { + MMDNNGrnnOp coverage_fw_; + MMDNNGrnnOp coverage_rv_; + int cap_e_; + int cap_h_; + + // TODO(miaotianxiang): + const int fc0_k_ = 1152; + const int fc0_n_ = 512; + const int fc1_k_ = 640; + const int fc1_n_ = 320; + const int fc2_k_ = 320; + const int fc2_n_ = 1; + MMDNNFcOp fc0_; + MMDNNFcOp fc1_; + MMDNNFcOp fc2_; + + XPUScratchPadGuard hbm_buffer_guard_; + float* hbm_buffer_{nullptr}; + // topk_concat_out_fw: [cap_l, cap_e_] <= [cap_l, cap_h_] + // topk_concat_out_rv: [cap_l, cap_e_] <= [cap_l, cap_h_] + // grnn_fw: [cap_l, cap_h_] + // grnn_rv: [cap_l, cap_h_] + // pool_fw: [batch, cap_h_] + // pool_rv: [batch, cap_h_] + // fc0_in: [batch, fc0_k_] + // fc0_out: [batch, fc0_n_] + // fc1_in: [batch, fc1_k_] + // fc1_out: [batch, fc1_n_] + // fc2_out: [batch, fc2_n_] + + public: + void Init(lite::Tensor* grnn_fw_wh, + std::vector grnn_fw_wh_maxs, + lite::Tensor* grnn_fw_wi, + std::vector grnn_fw_wi_maxs, + lite::Tensor* grnn_rv_wh, + std::vector grnn_rv_wh_maxs, + lite::Tensor* grnn_rv_wi, + std::vector grnn_rv_wi_maxs, + lite::Tensor* fc0_w, + float fc0_w_max, + lite::Tensor* fc0_b, + lite::Tensor* fc1_w, + float fc1_w_max, + lite::Tensor* fc1_b, + lite::Tensor* fc2_w, + float fc2_w_max, + lite::Tensor* fc2_b, + int upper_bound_batch, + int upper_bound_seqlen) { + int max_cap_l = upper_bound_batch * upper_bound_seqlen; + cap_e_ = grnn_fw_wi->dims()[2]; + cap_h_ = grnn_fw_wi->dims()[1]; + + coverage_fw_.Init(grnn_fw_wh, + grnn_fw_wh_maxs, + grnn_fw_wi, + grnn_fw_wi_maxs, + cap_e_, + cap_h_, + max_cap_l); + coverage_rv_.Init(grnn_rv_wh, + grnn_rv_wh_maxs, + grnn_rv_wi, + grnn_rv_wi_maxs, + cap_e_, + cap_h_, + max_cap_l); + + fc0_.Init( + fc0_w, fc0_w_max, fc0_b, fc0_n_, fc0_k_, xdnn::Activation_t::RELU); + fc1_.Init( + fc1_w, fc1_w_max, fc1_b, fc1_n_, fc1_k_, xdnn::Activation_t::RELU); + fc2_.Init( + fc2_w, fc2_w_max, fc2_b, fc2_n_, fc2_k_, xdnn::Activation_t::LINEAR); + + int hbm_total_len = max_cap_l * cap_h_ * 4 + + upper_bound_batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + + fc1_k_ + fc1_n_ + fc2_n_); + hbm_buffer_guard_ = TargetWrapperXPU::MallocScratchPad( + hbm_total_len * sizeof(float), false); + hbm_buffer_ = reinterpret_cast(hbm_buffer_guard_->addr_); + } + + void Infer(xdnn::Context* ctx, + const MMDNNIdInfo& sentense, + const std::vector concat_2in1_x, + const std::vector 
concat_7in1_x, + lite::Tensor* out, + float* l3_buffer = nullptr, + int l3_size = 0) { + int batch = sentense.batch; + int cap_l = sentense.seqlen_sum; + + float* topk_concat_out_fw = hbm_buffer_; + int hbm_total_len = + cap_l * cap_h_ * 4 + + batch * (2 * cap_h_ + fc0_k_ + fc0_n_ + fc1_k_ + fc1_n_ + fc2_n_); + if (l3_size > 0 && l3_size >= hbm_total_len * sizeof(float)) { + topk_concat_out_fw = l3_buffer; + } + float* topk_concat_out_rv = topk_concat_out_fw + cap_l * cap_h_; + float* grnn_fw = topk_concat_out_rv + cap_l * cap_h_; + float* grnn_rv = grnn_fw + cap_l * cap_h_; + float* pool_fw = grnn_rv + cap_l * cap_h_; + float* pool_rv = pool_fw + batch * cap_h_; + float* fc0_in = pool_fw + batch * cap_h_ * 2; + float* fc0_out = fc0_in + batch * fc0_k_; + float* fc1_in = fc0_out + batch * fc0_n_; + float* fc1_out = fc1_in + batch * fc1_k_; + // float* fc2_out = fc1_out + batch * fc1_n_; + float* fc2_out = out->mutable_data(TARGET(kXPU)); + + const int concat_widths[] = {static_cast(concat_2in1_x[0]->dims()[1]), + static_cast(concat_2in1_x[1]->dims()[1])}; + const float* concat_ptrs[] = {concat_2in1_x[0]->data(), + concat_2in1_x[1]->data()}; + xdnn::concat( + ctx, cap_l, concat_widths, 2, concat_ptrs, topk_concat_out_fw); + xdnn::sequence_reverse(ctx, + batch, + sentense.lod_32, + cap_e_, + topk_concat_out_fw, + topk_concat_out_rv); + coverage_fw_.Infer(ctx, + sentense, + topk_concat_out_fw, + grnn_fw, + l3_buffer + hbm_total_len, + l3_size - hbm_total_len * sizeof(float)); + coverage_rv_.Infer(ctx, + sentense, + topk_concat_out_rv, + grnn_rv, + l3_buffer + hbm_total_len, + l3_size - hbm_total_len * sizeof(float)); + xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_fw, + nullptr, + pool_fw); + xdnn::sequence_pooling_forward(ctx, + xdnn::Pooling_t::LAST, + batch, + sentense.lod_32, + cap_h_, + grnn_rv, + nullptr, + pool_rv); + + const int concat_widths_fc0[] = { + static_cast(concat_7in1_x[0]->dims()[1]), + static_cast(concat_7in1_x[1]->dims()[1]), + static_cast(concat_7in1_x[2]->dims()[1]), + static_cast(concat_7in1_x[3]->dims()[1]), + static_cast(concat_7in1_x[4]->dims()[1]), + static_cast(concat_7in1_x[5]->dims()[1]), + static_cast(concat_7in1_x[6]->dims()[1]), + }; + const float* concat_ptrs_fc0[] = { + concat_7in1_x[0]->data(), + concat_7in1_x[1]->data(), + concat_7in1_x[2]->data(), + concat_7in1_x[3]->data(), + concat_7in1_x[4]->data(), + concat_7in1_x[5]->data(), + concat_7in1_x[6]->data(), + }; + const int concat_widths_fc1[] = {cap_h_, cap_h_, fc0_n_}; + const float* concat_ptrs_fc1[] = {pool_fw, pool_rv, fc0_out}; + + xdnn::concat( + ctx, batch, concat_widths_fc0, 7, concat_ptrs_fc0, fc0_in); + fc0_.Infer(ctx, fc0_in, batch, fc0_out); + xdnn::concat( + ctx, batch, concat_widths_fc1, 3, concat_ptrs_fc1, fc1_in); + fc1_.Infer(ctx, fc1_in, batch, fc1_out); + fc2_.Infer(ctx, fc1_out, batch, fc2_out); + } +}; + +class XPUMmdnnBidEmbGrnnAttCompute + : public KernelLite { + public: + using param_t = operators::XPUMmdnnBidEmbGrnnAttParam; + + void PrepareForRun() override; + + void Run() override; + + private: + MMDNNIdInfo id_; + MMDNNBidEmbGrnnAtt compound_; + int upper_bound_batch_ = 40; + int upper_bound_seqlen_ = 512; +}; + +void XPUMmdnnBidEmbGrnnAttCompute::PrepareForRun() { + auto& param = this->Param(); + + id_.Init(upper_bound_batch_, upper_bound_seqlen_); + compound_.Init(param.emb_tbl, + param.grnn_fw_wh, + param.grnn_fw_wh_maxs, + param.grnn_fw_wi, + param.grnn_fw_wi_maxs, + param.grnn_rv_wh, + param.grnn_rv_wh_maxs, + 
param.grnn_rv_wi, + param.grnn_rv_wi_maxs, + param.att_fc_w, + param.att_fc_w_max, + param.att_fc_b, + upper_bound_batch_, + upper_bound_seqlen_); +} + +void XPUMmdnnBidEmbGrnnAttCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* xpu_ctx = ctx.GetRawContext(); + + int batch = param.id0->lod()[0].size() - 1; + id_.Update(param.id0, param.id1); + compound_.Infer(ctx.GetRawContext(), + batch, + id_, + param.grnn_fw_pool_out, + param.grnn_rv_pool_out, + param.att_pool_out, + param.concat_3in1_out, + param.emb_fw_out, + reinterpret_cast( + reinterpret_cast(xpu_ctx->workspace_l3_ptr) + + xpu_ctx->used_l3_size), + xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); +} + +class XPUMmdnnBidEmbAttCompute + : public KernelLite { + public: + using param_t = operators::XPUMmdnnBidEmbAttParam; + + void PrepareForRun() override; + + void Run() override; + + private: + MMDNNIdInfo id_; + MMDNNEmbAtt compound_; + int upper_bound_batch_ = 40; + int upper_bound_seqlen_ = 512; +}; + +void XPUMmdnnBidEmbAttCompute::PrepareForRun() { + auto& param = this->Param(); + + id_.Init(upper_bound_batch_, upper_bound_seqlen_); + compound_.Init(param.emb_tbl, + param.att_fc_w, + param.att_fc_w_max, + param.att_fc_b, + upper_bound_batch_, + upper_bound_seqlen_); +} + +void XPUMmdnnBidEmbAttCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* xpu_ctx = ctx.GetRawContext(); + + int batch = param.id0->lod()[0].size() - 1; + id_.Update(param.id0, param.id1); + compound_.Infer(ctx.GetRawContext(), + batch, + id_, + param.att_pool_out, + param.emb_fw_out, + reinterpret_cast( + reinterpret_cast(xpu_ctx->workspace_l3_ptr) + + xpu_ctx->used_l3_size), + xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); +} + +class XPUMmdnnMatchConvTopkCompute + : public KernelLite { + public: + using param_t = operators::XPUMmdnnMatchConvTopkParam; + + void PrepareForRun() override; + + void Run() override; + + private: + MMDNNMatchConvTopk compound_; + int upper_bound_batch_ = 40; + int upper_bound_seqlen_ = 512; +}; + +void XPUMmdnnMatchConvTopkCompute::PrepareForRun() { + auto& param = this->Param(); + + compound_.Init(param.input_w, + param.input_w_max, + param.conv_w, + param.conv_w_max, + param.dim_t, + param.input_w->dims()[0], + upper_bound_batch_, + upper_bound_seqlen_, + param.topks); +} + +void XPUMmdnnMatchConvTopkCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* xpu_ctx = ctx.GetRawContext(); + + compound_.Infer(ctx.GetRawContext(), + param.input_x, + param.input_y, + param.topk_out, + reinterpret_cast( + reinterpret_cast(xpu_ctx->workspace_l3_ptr) + + xpu_ctx->used_l3_size), + xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); +} + +class XPUMmdnnMergeAllCompute + : public KernelLite { + public: + using param_t = operators::XPUMmdnnMergeAllParam; + + void PrepareForRun() override; + + void Run() override; + + private: + MMDNNIdInfo id_; + MMDNNMergeAll compound_; + int upper_bound_batch_ = 40; + int upper_bound_seqlen_ = 512; +}; + +void XPUMmdnnMergeAllCompute::PrepareForRun() { + auto& param = this->Param(); + + id_.Init(upper_bound_batch_, upper_bound_seqlen_); + compound_.Init(param.grnn_fw_wh, + param.grnn_fw_wh_maxs, + param.grnn_fw_wi, + param.grnn_fw_wi_maxs, + param.grnn_rv_wh, + param.grnn_rv_wh_maxs, + param.grnn_rv_wi, + param.grnn_rv_wi_maxs, + param.fc0_w, + param.fc0_w_max, + param.fc0_b, + param.fc1_w, + param.fc1_w_max, + param.fc1_b, + param.fc2_w, + param.fc2_w_max, + param.fc2_b, + 
upper_bound_batch_, + upper_bound_seqlen_); +} + +void XPUMmdnnMergeAllCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* xpu_ctx = ctx.GetRawContext(); + + id_.Update(param.concat_2in1_x[0], param.concat_2in1_x[1]); + compound_.Infer(ctx.GetRawContext(), + id_, + param.concat_2in1_x, + param.concat_7in1_x, + param.out, + reinterpret_cast( + reinterpret_cast(xpu_ctx->workspace_l3_ptr) + + xpu_ctx->used_l3_size), + xpu_ctx->workspace_l3_size - xpu_ctx->used_l3_size); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_grnn_att, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUMmdnnBidEmbGrnnAttCompute, + def) + .BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("grnn_fw_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("grnn_rv_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); + +REGISTER_LITE_KERNEL(__xpu__mmdnn_bid_emb_att, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUMmdnnBidEmbAttCompute, + def) + .BindInput("id0", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("id1", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindInput("emb_tbl", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("att_fc_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("att_fc_b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("att_pool_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("concat_3in1_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("emb_fw_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); + +REGISTER_LITE_KERNEL(__xpu__mmdnn_match_conv_topk, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUMmdnnMatchConvTopkCompute, + def) + .BindInput("input_x", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("input_y", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("input_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("conv_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("topk_out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); + +REGISTER_LITE_KERNEL(__xpu__mmdnn_merge_all, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUMmdnnMergeAllCompute, + def) + .BindInput("concat_7in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("concat_2in1_x", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_fw_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_fw_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_rv_wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("grnn_rv_wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("fc0_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("fc0_b", 
{LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("fc1_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("fc1_b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("fc2_w", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("fc2_b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__resnet_cbam_compute.cc b/lite/kernels/xpu/__xpu__resnet_cbam_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d57445cd44953f504e292ad38d44d047daa3a7a --- /dev/null +++ b/lite/kernels/xpu/__xpu__resnet_cbam_compute.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/__xpu__resnet_cbam_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void XPUResNetCbamCompute::PrepareForRun() { + auto& param = this->Param(); + + for (auto* filter : param.filter) { + arg_filter_.push_back( + reinterpret_cast(filter->data())); + } + for (auto* bias : param.bias) { + if (bias == nullptr) { + arg_bias_.push_back(nullptr); + } else { + arg_bias_.push_back(bias->data()); + } + } + for (auto* max_filter : param.max_filter) { + arg_max_filter_.push_back(max_filter->data()); + } +} + +void XPUResNetCbamCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto input_dims = param.input->dims(); + int batch_size = input_dims[0]; + int height = input_dims[2]; + int width = input_dims[3]; + + int r = xdnn::conv2d_int16_resnet_cbam( + ctx.GetRawContext(), /* context */ + batch_size, /* num */ + height, /* height */ + width, /* width */ + param.input->data(), /* bottom */ + &arg_filter_[0], /* weight_list */ + param.output->mutable_data(TARGET(kXPU)), /* top */ + &arg_bias_[0], /* bias_list */ + &arg_max_filter_[0], /* max_filter_list */ + param.pool_p, /* pool_p */ + true, /* midtype_fp16 */ + false /* dynamic_shape */); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(__xpu__resnet_cbam, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUResNetCbamCompute, + def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("MaxFilter", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__resnet_cbam_compute.h b/lite/kernels/xpu/__xpu__resnet_cbam_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..b952bb088ea88399966c170cbeadebfa698889d8 --- /dev/null +++ b/lite/kernels/xpu/__xpu__resnet_cbam_compute.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class XPUResNetCbamCompute + : public KernelLite { + public: + using param_t = operators::XPUResNetCbamParam; + + virtual void PrepareForRun(); + + virtual void Run(); + + private: + std::vector arg_filter_; + std::vector arg_max_filter_; + std::vector arg_bias_; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/__xpu__search_attention_compute.cc b/lite/kernels/xpu/__xpu__search_attention_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..515be8935637d89d58db830f96f2ea439e7d7e68 --- /dev/null +++ b/lite/kernels/xpu/__xpu__search_attention_compute.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
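Like the MMDNN classes, `XPUResNetCbamCompute` above hoists all pointer marshalling into `PrepareForRun()`, which runs once per kernel instance, so `Run()` only issues the single fused `xdnn` call. A reduced sketch of that pattern (class name and types are illustrative, not the actual Lite interface):

```cpp
#include <cstdint>
#include <vector>

// Cache per-layer device pointers once, so every inference reuses the
// prepared array instead of re-casting tensor data inside Run().
class FusedConvKernel {
 public:
  void PrepareForRun(const std::vector<const void*>& filters) {
    arg_filter_.reserve(filters.size());
    for (const void* f : filters) {
      arg_filter_.push_back(reinterpret_cast<const int16_t*>(f));
    }
  }
  const int16_t* const* filter_list() const { return arg_filter_.data(); }

 private:
  std::vector<const int16_t*> arg_filter_;  // int16 device weights
};
```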
+ +#include "lite/kernels/xpu/__xpu__search_attention_compute.h" +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void XPUMmdnnSearchAttentionCompute::PrepareForRun() { + offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + pad_begin_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + w_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float)); + buffer_at_l3_guard_ = TargetWrapperXPU::MallocScratchPad( + 5 * L3_SLOT_SIZE * sizeof(float), false /* use_l3 */); + buffer_at_gm_guard_ = TargetWrapperXPU::MallocScratchPad( + 5 * GM_SLOT_SIZE * sizeof(float), false /* use_l3 */); + + offset_cpu.reset(new int[64]); + pad_begin_cpu.reset(new int[64]); +} + +void XPUMmdnnSearchAttentionCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* X = param.X; + auto* W = param.W; + auto* b = param.b; + float W_max = param.W_max; + float alpha0 = param.alpha0; + float alpha1 = param.alpha1; + float mask = param.mask; + + const int16_t* w_data = W->data(); + const float* b_data = b->data(); + + int batch = X->lod()[0].size() - 1; + int dim0 = X->dims()[0]; + int dim1 = X->dims()[1]; + const auto offset = X->lod()[0]; + int max_seq = 0; + + auto* top = param.Out; + LoD top_lod; + top_lod.push_back(X->lod()[0]); + top->set_lod(top_lod); + top->Resize({dim0, dim1}); + auto* top_data = top->mutable_data(TARGET(kXPU)); + + float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, W_max, 0.0f, 0.0f, 0.0f}; + for (int i = 0; i < batch; ++i) { + offset_cpu[i] = offset[i]; // type of offset is int64, not supported by xpu + pad_begin_cpu[i] = offset[i + 1] - offset[i]; + if (offset[i + 1] - offset[i] > max_seq) { + max_seq = offset[i + 1] - offset[i]; + } + } + offset_cpu[batch] = offset[batch]; + + xpu_memcpy(offset_xpu_guard_->addr_, + offset_cpu.get(), + offset.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(pad_begin_xpu_guard_->addr_, + pad_begin_cpu.get(), + batch * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(w_max_xpu_guard_->addr_, + maxs_cpu, + 8 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int* offset_xpu = reinterpret_cast(offset_xpu_guard_->addr_); + int* pad_begin_xpu = reinterpret_cast(pad_begin_xpu_guard_->addr_); + float* maxs_xpu = reinterpret_cast(w_max_xpu_guard_->addr_); + float* buffer_at_l3 = reinterpret_cast(buffer_at_l3_guard_->addr_); + float* buffer_at_gm = reinterpret_cast(buffer_at_gm_guard_->addr_); + + // when use l3, max_seq <= 128: + // group_padding: batch * max_seq * dim1; at (slot0, slot1) + // seq_fc: batch * max_seq * dim1; at (slot2, slot3) + // batchgemm0: batch * max_seq * max_seq; at slot4 + // attention_padding_mask: batch * max_seq * max_seq; at slot3 + // seq_softmax: batch * max_seq * max_seq; at slot4 + // batchgemm1: batch * max_seq * dim1; at (slot2, slot3) + float* group_padding_output = buffer_at_l3; + float* seq_fc_output = buffer_at_l3 + 2 * L3_SLOT_SIZE; + float* batchgemm0_output = buffer_at_l3 + 4 * L3_SLOT_SIZE; + float* attention_output = buffer_at_l3 + 3 * L3_SLOT_SIZE; + float* seq_softmax_output = buffer_at_l3 + 4 * L3_SLOT_SIZE; + float* batchgemm1_output = buffer_at_l3 + 2 * L3_SLOT_SIZE; + + if (max_seq > 128) { + group_padding_output = buffer_at_gm; + seq_fc_output = buffer_at_gm + 1 * GM_SLOT_SIZE; + batchgemm0_output = buffer_at_gm + 2 * GM_SLOT_SIZE; + attention_output = buffer_at_gm + 1 * GM_SLOT_SIZE; 
+ seq_softmax_output = buffer_at_gm + 3 * GM_SLOT_SIZE; + batchgemm1_output = buffer_at_gm + 4 * GM_SLOT_SIZE; + } + + const auto* bottom_data = X->data(); + xdnn::search_sequence_pad_depad(ctx.GetRawContext(), + const_cast(bottom_data), + group_padding_output, + offset_xpu, + max_seq, + batch, + dim1, + 0); // is_depad = 0 + // do-findmax + xdnn::findmax(ctx.GetRawContext(), + group_padding_output, + batch * max_seq * dim1, + maxs_xpu); + xdnn::gemm_int16_maxptr( + ctx.GetRawContext(), + false, + true, // trans_a, trans_b + batch * max_seq, + dim1, + dim1, // m, n, k + 1.0f, + group_padding_output, + dim1, // alpha, data_a, lda + w_data, + dim1, + 0.0f, // data_b, ldb, beta + seq_fc_output, + dim1, + b_data, // data_c, ldc, bias + xdnn::Activation_t::LINEAR, + maxs_xpu, + maxs_xpu + 4, + nullptr); // max_a, max_b, max_c + xdnn::search_aligned_mat_mul(ctx.GetRawContext(), + 0, + 1, + batch, + max_seq, + max_seq, + dim1, + alpha0, + group_padding_output, + dim1, + seq_fc_output, + dim1, + batchgemm0_output, + max_seq); + xdnn::search_pad_mask(ctx.GetRawContext(), + batchgemm0_output, + attention_output, + pad_begin_xpu, + batch, + max_seq, + max_seq, + batch, + mask); + xdnn::softmax2d_forward(ctx.GetRawContext(), + attention_output, + seq_softmax_output, + batch * max_seq, + max_seq, + true); + xdnn::search_aligned_mat_mul(ctx.GetRawContext(), + 0, + 0, + batch, + max_seq, + dim1, + max_seq, + alpha1, + seq_softmax_output, + max_seq, + group_padding_output, + dim1, + batchgemm1_output, + dim1); + xdnn::search_sequence_pad_depad(ctx.GetRawContext(), + top_data, + batchgemm1_output, + offset_xpu, + max_seq, + batch, + dim1, + 1); // is_depad = 1 +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(__xpu__mmdnn_search_attention, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUMmdnnSearchAttentionCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/__xpu__search_attention_compute.h b/lite/kernels/xpu/__xpu__search_attention_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..f9670dbab6247927acf6ac7d7b47f98a464a3489 --- /dev/null +++ b/lite/kernels/xpu/__xpu__search_attention_compute.h @@ -0,0 +1,52 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
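The search-attention kernel above keeps every padded intermediate in one of five fixed L3 slots; each slot is recycled only after its previous tenant's last reader has run, so the whole pad-fc-matmul-mask-softmax-matmul-depad pipeline fits in `5 * L3_SLOT_SIZE` floats whenever `max_seq <= 128`. The slot map, written out (offsets in floats; the two-slot allowance for padded activations suggests an implicit `dim1 <= 256` bound, which the patch does not state):

```cpp
// Slot offsets inside buffer_at_l3, matching the assignments above.
constexpr long kSlot = 40L * 128 * 128;    // == L3_SLOT_SIZE
constexpr long kGroupPadding = 0 * kSlot;  // slots 0-1, read until matmul #1
constexpr long kSeqFc = 2 * kSlot;         // slots 2-3, dead after matmul #0
constexpr long kBatchGemm0 = 4 * kSlot;    // slot 4, dead after the pad mask
constexpr long kAttention = 3 * kSlot;     // reuses seq_fc's second slot
constexpr long kSeqSoftmax = 4 * kSlot;    // reuses batchgemm0's slot
constexpr long kBatchGemm1 = 2 * kSlot;    // reuses slots 2-3 for the result
```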
+
+#pragma once
+
+#include <memory>
+#include "lite/backends/xpu/target_wrapper.h"  // XPUScratchPadGuard
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+class XPUMmdnnSearchAttentionCompute
+    : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::XPUMmdnnSearchAttentionParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+ private:
+  XPUScratchPadGuard offset_xpu_guard_;
+  XPUScratchPadGuard pad_begin_xpu_guard_;
+  XPUScratchPadGuard w_max_xpu_guard_;
+  XPUScratchPadGuard buffer_at_l3_guard_;
+  XPUScratchPadGuard buffer_at_gm_guard_;
+
+  std::unique_ptr<int[]> offset_cpu;
+  std::unique_ptr<int[]> pad_begin_cpu;
+
+  const int L3_SLOT_SIZE = 40 * 128 * 128;
+  const int GM_SLOT_SIZE = 40 * 512 * 512;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/xpu/concat_compute.cc b/lite/kernels/xpu/concat_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f088bb80f0c500c6f900726195bcb5903049d3fb
--- /dev/null
+++ b/lite/kernels/xpu/concat_compute.cc
@@ -0,0 +1,85 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/concat_compute.h"
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void ConcatCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->As<XPUContext>();
+
+  auto ins = param.x;
+  auto out = param.output;
+  int64_t axis = param.axis;
+
+  int n = ins.size();
+  int h = 1;
+  int w_except_axis = 1;
+  CHECK(n <= 8) << "XPU only supports at most 8 tensors for now";
+  for (int i = 0; i < axis; ++i) {
+    h *= (ins[0]->dims())[i];
+  }
+  for (int i = axis + 1; i < ins[0]->dims().size(); ++i) {
+    w_except_axis *= (ins[0]->dims())[i];
+  }
+  CHECK(axis >= 0) << "concat: axis should >= 0!";
+  CHECK(axis < ins[0]->dims().size()) << "concat: axis should < ins[0]->dims()!";
+  for (int i = 0; i < n; ++i) {
+    int hh = 1;
+    int ww = 1;
+    for (int j = 0; j < axis; ++j) {
+      hh *= (ins[i]->dims())[j];
+    }
+    for (int j = axis + 1; j < ins[i]->dims().size(); ++j) {
+      ww *= (ins[i]->dims())[j];
+    }
+    CHECK(hh == h) << "concat: h should be equal!";
+    CHECK(ww == w_except_axis) << "concat: w should be equal except for axis!";
+  }
+
+  int in_w_host[n];      // NOLINT
+  const float* ptrs[n];  // NOLINT
+
+  for (int i = 0; i < n; ++i) {
+    ptrs[i] = ins[i]->data<float>();
+    in_w_host[i] = w_except_axis * (ins[i]->dims())[axis];
+  }
+
+  int r = xdnn::concat<float>(ctx.GetRawContext(), /* ctx */
+                              h,                   /* height */
+                              in_w_host,           /* width_x */
+                              n,                   /* n */
+                              ptrs,                /* lm_ptrs */
+                              out->mutable_data<float>(TARGET(kXPU)) /*y*/);
+  CHECK_EQ(r, 0);
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    concat, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::ConcatCompute, def)
+    .BindInput("X",
{LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/concat_compute.h b/lite/kernels/xpu/concat_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..f29899a741194270272770d8b781cd9b0b54abc9 --- /dev/null +++ b/lite/kernels/xpu/concat_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class ConcatCompute : public KernelLite { + public: + using param_t = operators::ConcatParam; + + virtual void Run(); + + virtual ~ConcatCompute() = default; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/match_matrix_tensor_compute.cc b/lite/kernels/xpu/match_matrix_tensor_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c4e896d23add6df99a7b66a830dc526dc808e95 --- /dev/null +++ b/lite/kernels/xpu/match_matrix_tensor_compute.cc @@ -0,0 +1,179 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
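`ConcatCompute` above reduces an N-D concat along `axis` to a row-wise 2-D concat: all inputs share the height `h` (the product of the leading dims) and each contributes a width of its `dims[axis]` times the shared trailing dims. A small self-contained sketch of that flattening (function name hypothetical):

```cpp
#include <cstdint>
#include <vector>

// Flatten an N-D concat along `axis` into (height, per-input widths),
// mirroring the shape checks in ConcatCompute::Run() above.
int64_t FlattenForConcat(const std::vector<std::vector<int64_t>>& dims,
                         int axis,
                         std::vector<int>* widths) {
  int64_t h = 1;  // shared leading dims
  for (int i = 0; i < axis; ++i) h *= dims[0][i];
  int64_t w_tail = 1;  // shared trailing dims
  for (size_t i = axis + 1; i < dims[0].size(); ++i) w_tail *= dims[0][i];
  for (const auto& d : dims) {
    widths->push_back(static_cast<int>(d[axis] * w_tail));
  }
  return h;
}
// e.g. shapes {2,3,4,5} and {2,1,4,5} on axis 1 -> h = 2, widths {60, 20}.
```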
+ +#include "lite/kernels/xpu/match_matrix_tensor_compute.h" +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void MatchMatrixTensorCompute::PrepareForRun() { + wx_max_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + offset_l_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + offset_r_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + + offset_l_cpu.reset(new int[64]); + offset_r_cpu.reset(new int[64]); +} + +void MatchMatrixTensorCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->As(); + + auto* x = param.x; + auto* y = param.y; + auto* w = param.w; + auto* out = param.out; + auto* tmp = param.tmp; + int dim_t = param.dim_t; + float w_max = param.__xpu__w_max; + bool fuse_relu = param.fuse_relu; + bool float_to_fix = param.__xpu__float_to_fix; + CHECK(float_to_fix) << "W should be fixed point"; + + xdnn::Activation_t act = xdnn::Activation_t::LINEAR; + if (fuse_relu) { + act = xdnn::Activation_t::RELU; + } + + int dim_in = x->dims()[1]; + const auto& offset_l = x->lod()[0]; + const auto& offset_r = y->lod()[0]; + + std::vector top_offset; + int top_size = 0; + top_offset.push_back(top_size); + for (size_t b = 0; b < x->lod()[0].size() - 1; b++) { + int len_l = offset_l[b + 1] - offset_l[b]; + int len_r = offset_r[b + 1] - offset_r[b]; + top_size += dim_t * len_l * len_r; + top_offset.push_back(top_size); + } + auto* bottom_l_data = x->data(); + auto* bottom_r_data = y->data(); + auto* w_data = w->data(); + auto* out_data = out->mutable_data(TARGET(kXPU)); + auto* bottom_l_trans_data = tmp->mutable_data(TARGET(kXPU)); + int batch_size = x->lod()[0].size() - 1; + + float* wx_max = reinterpret_cast(wx_max_xpu_guard_->addr_); + int* offset_l_xpu = reinterpret_cast(offset_l_xpu_guard_->addr_); + int* offset_r_xpu = reinterpret_cast(offset_r_xpu_guard_->addr_); + + int r = xdnn::gemm_int16_tmp_api( + ctx.GetRawContext(), /* ctx */ + false, + false, /* trans_a, trans_b */ + x->dims()[0], + dim_t * dim_in, + dim_in, /* m, n, k */ + 1.0f, + bottom_l_data, + dim_in, /* alpha, data_a, lda */ + w_data, + dim_t * dim_in, + 0.0f, /* data_b, ldb, beta */ + bottom_l_trans_data, + dim_t * dim_in, /* data_c, ldc */ + nullptr, /* bias */ + xdnn::Activation_t::LINEAR, + 0.0f, + w_max, + wx_max /* max_a, max_b, max_c */); + CHECK_EQ(r, 0); + + int max_width = 0; + for (int i = 0; i < offset_l.size(); ++i) { + offset_l_cpu[i] = offset_l[i]; + if (i != 0 && (offset_l_cpu[i] - offset_l_cpu[i - 1] > max_width)) { + max_width = offset_l_cpu[i] - offset_l_cpu[i - 1]; + } + } + for (int i = 0; i < offset_r.size(); ++i) { + offset_r_cpu[i] = offset_r[i]; + if (i != 0 && (offset_r_cpu[i] - offset_r_cpu[i - 1] > max_width)) { + max_width = offset_r_cpu[i] - offset_r_cpu[i - 1]; + } + } + xpu_memcpy(offset_l_xpu, + offset_l_cpu.get(), + offset_l.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(offset_r_xpu, + offset_r_cpu.get(), + offset_r.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + r = xdnn::match_matrix_tensor(ctx.GetRawContext(), + batch_size, + bottom_l_trans_data, + bottom_r_data, + offset_l_xpu, + offset_r_xpu, + dim_t, + dim_in, + out_data, + wx_max, + act, + max_width); + CHECK_EQ(r, 0); + + int lod_lv1_size = batch_size * dim_t; + int lod_lv2_size = x->lod()[0].back() * dim_t; + std::vector out_lod0(batch_size + 1, 0); + std::vector out_lod1(lod_lv1_size + 1, 0); + 
std::vector out_lod2(lod_lv2_size + 1, 0); + for (int i = 0; i < batch_size; i++) { + out_lod0[i + 1] = out_lod0[i] + dim_t; + int len_l = offset_l[i + 1] - offset_l[i]; + + for (int j = 0; j < dim_t; j++) { + out_lod1[i * dim_t + j + 1] = out_lod1[i * dim_t + j] + len_l; + int len_r = offset_r[i + 1] - offset_r[i]; + + for (int k = 0; k < len_l; k++) { + out_lod2[offset_l[i] * dim_t + j * len_l + k + 1] = + out_lod2[offset_l[i] * dim_t + j * len_l + k] + len_r; + } + } + } + + paddle::lite::LoD out_lod; + out_lod.push_back(top_offset); + out_lod.push_back(offset_l); + out_lod.push_back(offset_r); + out->set_lod(out_lod); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(match_matrix_tensor, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::MatchMatrixTensorCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Tmp", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/match_matrix_tensor_compute.h b/lite/kernels/xpu/match_matrix_tensor_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..3bd0b622db1fce178ea66604d89dc50d6477a105 --- /dev/null +++ b/lite/kernels/xpu/match_matrix_tensor_compute.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class MatchMatrixTensorCompute + : public KernelLite { + public: + using param_t = operators::MatchMatrixTensorParam; + + virtual void PrepareForRun(); + + virtual void Run(); + + private: + XPUScratchPadGuard wx_max_xpu_guard_; + XPUScratchPadGuard offset_l_xpu_guard_; + XPUScratchPadGuard offset_r_xpu_guard_; + + std::unique_ptr offset_l_cpu; + std::unique_ptr offset_r_cpu; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/search_fc_compute.cc b/lite/kernels/xpu/search_fc_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..79f4c2d0d809ea9848fb383863d0f9dd2ec5a2ae --- /dev/null +++ b/lite/kernels/xpu/search_fc_compute.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/search_fc_compute.h" +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void SearchFcCompute::PrepareForRun() { + maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(float)); +} + +void SearchFcCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto* bottom = param.X; + auto* w = param.W; + auto* b = param.b; + auto* top = param.Out; + float w_max = param.__xpu__w_max; + int out_size = param.out_size; + bool fuse_relu = param.fuse_relu; + bool float_to_fix = param.__xpu__float_to_fix; + CHECK(float_to_fix) << "W should be fixed point"; + + int batch = bottom->dims()[0]; + int _out = w->dims()[0]; + int _in = w->dims()[1]; + + xdnn::Activation_t act = xdnn::Activation_t::LINEAR; + if (fuse_relu) { + act = xdnn::Activation_t::RELU; + } + + std::vector top_dims{bottom->dims()[0], out_size}; + top->Resize(top_dims); + + const auto* bottom_data = bottom->data(); + const auto* weights = w->data(); + const auto* bias_data = b->data(); + auto* top_data = top->mutable_data(TARGET(kXPU)); + + float* maxs_xpu = reinterpret_cast(maxs_xpu_guard_->addr_); + float maxs_cpu[8] = {0.0f, 0.0f, 0.0f, 0.0f, w_max, 0.0f, 0.0f, 0.0f}; + xpu_memcpy(maxs_xpu, + &maxs_cpu[0], + 8 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int r = xdnn::findmax( + ctx.GetRawContext(), bottom_data, batch * _in, maxs_xpu); + CHECK_EQ(r, 0); + r = xdnn::gemm_int16_maxptr( + ctx.GetRawContext(), /* ctx */ + false, + true, /*trans_a, trans_b*/ + batch, + _out, + _in, /*m, n, k*/ + 1.0f, + bottom_data, + _in, /*alpha, data_a, lda*/ + weights, + _in, + 0.0f, /*data_b, ldb, beta*/ + top_data, + _out, + bias_data, /* data_c, ldc, bias*/ + act, + maxs_xpu, + maxs_xpu + 4, + nullptr /*act, max_a, max_b, max_c*/); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(search_fc, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::SearchFcCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("b", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/search_fc_compute.h b/lite/kernels/xpu/search_fc_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..c7ee06abd957187c18c1306f40a77735f40558e7 --- /dev/null +++ b/lite/kernels/xpu/search_fc_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class SearchFcCompute : public KernelLite { + public: + using param_t = operators::SearchFcParam; + + void PrepareForRun() override; + + void Run() override; + + private: + XPUScratchPadGuard maxs_xpu_guard_; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/search_grnn_compute.cc b/lite/kernels/xpu/search_grnn_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c19f58da1b5deaa3d74791561494f13b681cf3a --- /dev/null +++ b/lite/kernels/xpu/search_grnn_compute.cc @@ -0,0 +1,282 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
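The `search_fc` kernel above pairs `xdnn::findmax` with an int16 GEMM whose max pointers come from the `maxs_xpu` scratch buffer. The plain-CPU sketch below shows the fixed-point convention this pattern implies (scale by each tensor's absolute maximum into int16, multiply, rescale on the way out). It is an assumption about the convention for illustration, not the XPU code path; all data is invented, and real kernels also have to manage accumulator width:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x{0.5f, -1.5f, 2.0f};    // "activations", max found at runtime
  std::vector<float> w{0.25f, 0.75f, -1.0f};  // "weights", w_max known offline
  float x_max = 0.f, w_max = 0.f;
  for (float v : x) x_max = std::max(x_max, std::fabs(v));  // findmax
  for (float v : w) w_max = std::max(w_max, std::fabs(v));

  // Quantize to int16 with scale max/32767, accumulate in int32.
  int32_t acc = 0;
  for (size_t i = 0; i < x.size(); ++i) {
    int16_t xi = (int16_t)std::lround(x[i] / x_max * 32767.f);
    int16_t wi = (int16_t)std::lround(w[i] / w_max * 32767.f);
    acc += (int32_t)xi * (int32_t)wi;
  }
  // Dequantize: undo both scales.
  float dot = acc * (x_max / 32767.f) * (w_max / 32767.f);
  std::printf("int16 dot=%f, float dot=%f\n", dot,
              x[0] * w[0] + x[1] * w[1] + x[2] * w[2]);  // ~-3.0 vs -3.0
  return 0;
}
```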
+ +#include "lite/kernels/xpu/search_grnn_compute.h" +#include +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void SearchGrnnCompute::PrepareForRun() { + offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + new_offset_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); + maxs_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(16 * sizeof(float)); + + idx_sorted_by_width_data_cpu.reset(new int[64]); + offset_cpu.reset(new int[64]); + new_offset_cpu.reset(new int[256]); +} + +void SearchGrnnCompute::prepare_layout(const operators::SearchGrnnParam& param, + const paddle::lite::Tensor* bottom) { + auto* idx_sorted_by_width = param.idx_sorted_by_width; + auto* layout_input = param.layout_input; + + int dim0 = bottom->dims()[0]; + int dim1 = 1; + if (bottom->dims().size() > 1) { + dim1 = bottom->dims()[1]; + } + int batch = bottom->lod()[0].size() - 1; + auto& offset = bottom->lod()[0]; + + idx_sorted_by_width->Resize({batch}); + std::vector width; + width.resize(batch); + + // sort sequences by width (descending) and find the largest width in the + // batch + for (int i = 0; i < batch; i++) { + width[i] = offset[i + 1] - offset[i]; + idx_sorted_by_width_data_cpu[i] = i; + } + std::sort(idx_sorted_by_width_data_cpu.get(), + idx_sorted_by_width_data_cpu.get() + batch, + [&width](int a, int b) { return width[a] > width[b]; }); + int max_width = width[idx_sorted_by_width_data_cpu[0]]; + + // start of reorganizing the input + std::vector new_offset; + new_offset.resize(max_width + 1); + new_offset[0] = 0; + int j = batch - 1; + int last_width = 0; + int sub_row = 0; + int sub_col = 0; + + for (int i = 1; i <= max_width;) { + for (int k = j; k >= 0; --k) { + if (width[idx_sorted_by_width_data_cpu[k]] > last_width) { + sub_row = width[idx_sorted_by_width_data_cpu[k]] - last_width; + sub_col = k + 1; + for (int s = 0; s < sub_row; s++) { + new_offset[i] = new_offset[i - 1] + sub_col; + i++; + } + // move on + last_width = width[idx_sorted_by_width_data_cpu[k]]; + j = k - 1; + break; + } + } + } + + // copying to the reorganized buffer + if (bottom->dims().size() == 1) { + } else { + LoD new_lod; + new_lod.push_back(new_offset); + layout_input->set_lod(new_lod); + layout_input->Resize({dim0, dim1}); + } + + xpu_memcpy(idx_sorted_by_width->mutable_data(TARGET(kXPU)), + idx_sorted_by_width_data_cpu.get(), + idx_sorted_by_width->numel() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); +} + +void SearchGrnnCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto* bottom = param.x; + auto* wi = param.wi; + auto* wh = param.wh; + auto* top = param.out; + auto* tmp_buffer = param.tmp_buffer; + auto* idx_sorted_by_width = param.idx_sorted_by_width; + auto* layout_input = param.layout_input; + int cap_h = param.num_hidden; + int cap_e = param.num_input; + int cap_l = bottom->dims()[0]; + auto wi_max = param.__xpu__wi_max; + auto wh_max = param.__xpu__wh_max; + bool float_to_fix = param.__xpu__float_to_fix; + CHECK(float_to_fix) << "W should be fixed point"; + + int dim = 1; + if (bottom->dims().size() > 1) { + dim = bottom->dims()[1]; + } + + const auto& offset = bottom->lod()[0]; + LoD top_lod; + top_lod.push_back(offset); + top->set_lod(top_lod); + std::vector top_dims_vec{cap_l, cap_h}; + top->Resize(top_dims_vec); + auto* top_hidden = top->mutable_data(TARGET(kXPU)); + const auto* dense_e2h = 
wi->data(); + const auto* dense_h2h = wh->data(); + + // Prepare idx_sorted_by_width + prepare_layout(param, bottom); + int batch = bottom->lod()[0].size() - 1; + int max_width = layout_input->lod()[0].size() - 1; + const auto& new_offset = layout_input->lod()[0]; + auto* new_emb = layout_input->mutable_data(TARGET(kXPU)); + + // Prepare offset and new_offset + int* offset_xpu = reinterpret_cast(offset_xpu_guard_->addr_); + int* new_offset_xpu = reinterpret_cast(new_offset_xpu_guard_->addr_); + float* maxs_xpu = reinterpret_cast(maxs_xpu_guard_->addr_); + CHECK_LE(offset.size(), 64); + CHECK_LE(new_offset.size(), 256); + + for (size_t i = 0; i < offset.size(); ++i) { + offset_cpu[i] = offset[i]; + } + for (size_t i = 0; i < new_offset.size(); ++i) { + new_offset_cpu[i] = new_offset[i]; + } + xpu_memcpy(offset_xpu, + offset_cpu.get(), + offset.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(new_offset_xpu, + new_offset_cpu.get(), + new_offset.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int r = xdnn::search_seq2batch(ctx.GetRawContext(), + batch, + max_width, + dim, + idx_sorted_by_width->data(), + offset_xpu, + new_offset_xpu, + bottom->data(), + new_emb); + CHECK_EQ(r, 0); + + // this buffer is used for book keeping info which will be used in bp + // buffer also needed in bp, so make it larger + tmp_buffer->Resize({20, cap_l, cap_h}); + auto* buffer_data = tmp_buffer->mutable_data(TARGET(kXPU)); + // the internal hidden + auto* hidden = buffer_data + 19 * cap_l * cap_h; + + // do-findmax + float maxs_cpu[16] = {0.0f, + 0.0f, + 0.0f, + 0.0f, + wi_max[0], + 0.0f, + 0.0f, + 0.0f, + wi_max[1], + 0.0f, + 0.0f, + 0.0f, + wi_max[2], + 0.0f, + 0.0f, + 0.0f}; + xpu_memcpy(maxs_xpu, + maxs_cpu, + 16 * sizeof(float), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + r = xdnn::findmax( + ctx.GetRawContext(), new_emb, cap_l * cap_e, maxs_xpu); + CHECK_EQ(r, 0); + + // precompute embedding to hidden + for (int i = 0; i < 3; ++i) { + const int16_t* data_b = dense_e2h + i * cap_e * cap_h; // e2h, e2hr, e2hz + float* data_c = buffer_data + i * cap_l * cap_h; // w_x_e, wr_x_e, wz_x_e + int r = xdnn::gemm_int16_maxptr( + ctx.GetRawContext(), + false, + true, // trans_a, trans_b + cap_l, + cap_h, + cap_e, // m, n, k + 1.0f, + new_emb, + cap_e, // alpha, data_a, lda + data_b, + cap_e, + 0.0f, // data_b, ldb, beta + data_c, + cap_h, // data_c, ldc + nullptr, + xdnn::Activation_t::LINEAR, // bias, act + maxs_xpu, + maxs_xpu + 4 * (i + 1)); // max_a, max_b + CHECK_EQ(r, 0); + } + + r = xdnn::search_grnn(ctx.GetRawContext(), + cap_l, + cap_h, + cap_e, + max_width, + new_offset_xpu, + buffer_data, + dense_h2h, + hidden, + wh_max[0], + wh_max[1], + wh_max[2]); + CHECK_EQ(r, 0); + + r = xdnn::search_batch2seq(ctx.GetRawContext(), + batch, + max_width, + cap_h, + idx_sorted_by_width->data(), + offset_xpu, + new_offset_xpu, + hidden, + top_hidden); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(search_grnn, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::SearchGrnnCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Wi", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("Wh", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("tmp_buffer", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("idx_sorted_by_width", + {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))}) + 
.BindOutput("layout_input", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/search_grnn_compute.h b/lite/kernels/xpu/search_grnn_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..7208e782474d39eabb41b4bc969d27a1d7d5f797 --- /dev/null +++ b/lite/kernels/xpu/search_grnn_compute.h @@ -0,0 +1,49 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class SearchGrnnCompute : public KernelLite { + public: + using param_t = operators::SearchGrnnParam; + + void PrepareForRun() override; + + void prepare_layout(const operators::SearchGrnnParam& param, + const paddle::lite::Tensor* bottom); + void Run() override; + + private: + XPUScratchPadGuard offset_xpu_guard_; + XPUScratchPadGuard new_offset_xpu_guard_; + XPUScratchPadGuard maxs_xpu_guard_; + + std::unique_ptr idx_sorted_by_width_data_cpu; + std::unique_ptr offset_cpu; + std::unique_ptr new_offset_cpu; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/sequence_arithmetic_compute.cc b/lite/kernels/xpu/sequence_arithmetic_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..226c615dba57ae381ed2457e588c5df32f25e04b --- /dev/null +++ b/lite/kernels/xpu/sequence_arithmetic_compute.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
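The `search_grnn` kernel's `prepare_layout` above reorders a padded-free batch: sequences are sorted by length (descending) and rebucketed so that time step `t` of every still-active sequence is contiguous, which is the layout `search_seq2batch` expects. The sketch below (illustrative only, not part of this patch; LoD values are invented) computes the same `new_offset` table with a simpler, equivalent formulation — step `t` holds one row per sequence longer than `t`:

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // LoD offsets for 3 sequences of lengths 2, 4, 1.
  std::vector<int> offset{0, 2, 6, 7};
  int batch = offset.size() - 1;
  std::vector<int> width(batch), idx(batch);
  for (int i = 0; i < batch; ++i) width[i] = offset[i + 1] - offset[i];
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&](int a, int b) { return width[a] > width[b]; });
  int max_width = width[idx[0]];

  // new_offset[t] marks where time step t starts in the reorganized buffer.
  std::vector<int> new_offset(max_width + 1, 0);
  for (int t = 0; t < max_width; ++t) {
    int active = 0;
    for (int w : width) active += (w > t);  // sequences still alive at step t
    new_offset[t + 1] = new_offset[t] + active;
  }
  for (int t = 0; t <= max_width; ++t) std::printf("%d ", new_offset[t]);
  std::printf("\n");  // prints: 0 3 5 6 7
  return 0;
}
```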
+
+#include "lite/kernels/xpu/sequence_arithmetic_compute.h"
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void SequenceArithmeticCompute::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+
+  auto* bottom0 = param.X;
+  auto* bottom1 = param.Y;
+  auto* top = param.Out;
+
+  int op_type = param.op_type;
+
+  auto len1 = bottom0->numel();
+  auto len2 = bottom1->numel();
+  const auto* bottom_data0 = bottom0->data<float>();
+  const auto* bottom_data1 = bottom1->data<float>();
+  auto* top_data = top->mutable_data<float>(TARGET(kXPU));
+
+  switch (op_type) {
+    case 1:  // addition: top[0] = bottom[0] + bottom[1]
+      if (len1 > len2) {
+        xdnn::elementwise_add(
+            ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
+        xdnn::memcpy_device(ctx.GetRawContext(),
+                            &top_data[len2],
+                            &bottom_data0[len2],
+                            (len1 - len2) * sizeof(float));
+      } else {
+        xdnn::elementwise_add(
+            ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
+      }
+      break;
+    case 2:  // subtraction: top[0] = bottom[0] - bottom[1]
+      if (len1 > len2) {
+        xdnn::elementwise_sub(
+            ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
+        xdnn::memcpy_device(ctx.GetRawContext(),
+                            &top_data[len2],
+                            &bottom_data0[len2],
+                            (len1 - len2) * sizeof(float));
+      } else {
+        xdnn::elementwise_sub(
+            ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
+      }
+      break;
+    case 3:  // multiplication: top[0] = bottom[0] * bottom[1]
+      if (len1 > len2) {
+        xdnn::elementwise_mul(
+            ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len2);
+        xdnn::memcpy_device(ctx.GetRawContext(),
+                            &top_data[len2],
+                            &bottom_data0[len2],
+                            (len1 - len2) * sizeof(float));
+      } else {
+        xdnn::elementwise_mul(
+            ctx.GetRawContext(), bottom_data0, bottom_data1, top_data, len1);
+      }
+      break;
+    default:
+      break;
+  }
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(sequence_arithmetic,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::SequenceArithmeticCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(search_seq_arithmetic,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::SequenceArithmeticCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
diff --git a/lite/kernels/arm/activation_grad_compute.h b/lite/kernels/xpu/sequence_arithmetic_compute.h
similarity index 75%
rename from lite/kernels/arm/activation_grad_compute.h
rename to lite/kernels/xpu/sequence_arithmetic_compute.h
index ef03f58fa8cd499192aa6edfe3a7c51b49b14f65..9526587ac48cd5025022d646e31c24cac6b59a13 100644
--- a/lite/kernels/arm/activation_grad_compute.h
+++ b/lite/kernels/xpu/sequence_arithmetic_compute.h
@@ -13,25 +13,24 @@
 // limitations under the License.
 #pragma once
-#include
+
+#include
 #include "lite/core/kernel.h"
-#include "lite/core/op_registry.h"
 
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace xpu {
 
-class SquareGradCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+class SequenceArithmeticCompute
+    : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
  public:
-  using param_t = operators::ActivationGradParam;
+  using param_t = operators::SequenceArithmeticParam;
 
   void Run() override;
-
-  virtual ~SquareGradCompute() = default;
 };
 
-}  // namespace arm
+}  // namespace xpu
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/kernels/xpu/sequence_concat_compute.cc b/lite/kernels/xpu/sequence_concat_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fd7f5999a6ccb18efbcb0e96b50f2b31884fc21c
--- /dev/null
+++ b/lite/kernels/xpu/sequence_concat_compute.cc
@@ -0,0 +1,141 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/sequence_concat_compute.h"
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void SequenceConcatCompute::PrepareForRun() {
+  lod0_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
+  lod1_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int));
+
+  lod0_cpu.reset(new int[64]);
+  lod1_cpu.reset(new int[64]);
+}
+
+template <typename T>
+inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs,
+                     std::vector<lite::Tensor>* xs_in_order) {
+  std::vector<uint64_t> result;
+  result.resize(xs[0]->lod()[0].size());
+
+  for (size_t i = 1; i < result.size(); ++i) {
+    size_t sum = 0;
+    for (size_t j = 0; j < xs.size(); ++j) {
+      auto& x_lod = xs[j]->lod()[0];
+      if (x_lod[i - 1] < x_lod[i]) {
+        xs_in_order->emplace_back(xs[j]->Slice<T>(x_lod[i - 1], x_lod[i]));
+      }
+      sum += x_lod[i];
+    }
+    result[i] = sum;
+  }
+  LoD lod;
+  lod.emplace_back(result);
+  return lod;
+}
+
+void SequenceConcatCompute::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+
+  auto xs = param.X;
+  auto out = param.Out;
+
+  size_t lod_size = 0;
+  for (auto& x : xs) {
+    if (lod_size == 0) {
+      lod_size = x->lod()[0].size();
+    } else {
+      CHECK_EQ(lod_size, x->lod()[0].size())
+          << "The number of sequences must be the same for each input";
+    }
+  }
+  CHECK_NE(lod_size, 0) << "Each input must have sequence information";
+
+  // TODO(miaotianxiang):
+  int64_t dim0 = 0;
+  int64_t feature_size = 0;
+  std::vector<int64_t> out_dims;
+  for (const auto& tensor : param.X) {
+    const auto x_dims = tensor->dims();
+    if (out_dims.empty()) {
+      out_dims = x_dims.data();
+    }
+    dim0 += x_dims[0];
+    if (feature_size == 0) {
+      feature_size = x_dims.production() / x_dims[0];
+    } else {
+      CHECK_EQ(feature_size, x_dims.production() / x_dims[0])
+          << "Inputs of sequence concat must have same feature size";
+    }
+  }
+  out_dims[0] = dim0;
+  out->Resize(out_dims);
+  std::vector<lite::Tensor> x_in_order;
+  out->set_lod(ConcatLoD<float>(xs, &x_in_order));
+
+  CHECK(xs.size() == 2) << "XPU only supports sequence_concat for 2 tensors";
+
+  auto lod0 = xs[0]->lod()[0];
+  auto lod1 = xs[1]->lod()[0];
+  int batch_size = lod0.size() - 1;
+
+  int* lod0_xpu = reinterpret_cast<int*>(lod0_xpu_guard_->addr_);
+  int* lod1_xpu = reinterpret_cast<int*>(lod1_xpu_guard_->addr_);
+  for (int i = 0; i < lod0.size(); ++i) {
+    lod0_cpu[i] = lod0[i];
+  }
+  for (int i = 0; i < lod1.size(); ++i) {
+    lod1_cpu[i] = lod1[i];
+  }
+  xpu_memcpy(lod0_xpu,
+             lod0_cpu.get(),
+             lod0.size() * sizeof(int),
+             XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+  xpu_memcpy(lod1_xpu,
+             lod1_cpu.get(),
+             lod1.size() * sizeof(int),
+             XPUMemcpyKind::XPU_HOST_TO_DEVICE);
+
+  int r = xdnn::sequence_concat(ctx.GetRawContext(),
+                                xs[0]->data<float>(),
+                                lod0_xpu,
+                                xs[1]->data<float>(),
+                                lod1_xpu,
+                                out->mutable_data<float>(TARGET(kXPU)),
+                                batch_size);
+  CHECK_EQ(r, 0);
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(sequence_concat,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::SequenceConcatCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
diff --git a/lite/kernels/xpu/sequence_concat_compute.h b/lite/kernels/xpu/sequence_concat_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..5726671975d546d1e549ecbe95790c11faafba7b
--- /dev/null
+++ b/lite/kernels/xpu/sequence_concat_compute.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "lite/backends/xpu/target_wrapper.h"  // XPUScratchPadGuard
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+class SequenceConcatCompute
+    : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SequenceConcatParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+ private:
+  XPUScratchPadGuard lod0_xpu_guard_;
+  XPUScratchPadGuard lod1_xpu_guard_;
+
+  std::unique_ptr<int[]> lod0_cpu;
+  std::unique_ptr<int[]> lod1_cpu;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/xpu/sequence_pool_compute.cc b/lite/kernels/xpu/sequence_pool_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81d9b5873c3c42afe94acdd8eb5a292326b7a7b6
--- /dev/null
+++ b/lite/kernels/xpu/sequence_pool_compute.cc
@@ -0,0 +1,89 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/sequence_pool_compute.h" +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void XPUSequencePoolCompute::PrepareForRun() { + lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + lod_cpu.reset(new int[64]); +} + +void XPUSequencePoolCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto* in = param.X; + auto* out = param.Out; + std::string pool_type_str = param.pool_type; + + auto dims = in->dims(); + auto lod = in->lod(); + dims[0] = lod[0].size() - 1; + + xdnn::Pooling_t pool_type = xdnn::Pooling_t::MAX_WITHOUT_INDEX; + if (pool_type_str == "MAX") { + } else if (pool_type_str == "LAST") { + pool_type = xdnn::Pooling_t::LAST; + } else { + CHECK(false); + } + + int num_seq = out->dims()[0]; + int dim = out->numel() / num_seq; + + auto in_lod = in->lod()[0]; + for (size_t i = 0; i < in_lod.size(); ++i) { + lod_cpu[i] = in_lod[i]; + } + int* lod_xpu = reinterpret_cast(lod_xpu_guard_->addr_); + xpu_memcpy(lod_xpu, + lod_cpu.get(), + in_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int r = + xdnn::sequence_pooling_forward(ctx.GetRawContext(), + pool_type, + num_seq, + lod_xpu, + dim, + in->data(), + nullptr /* index */, + out->mutable_data(TARGET(kXPU))); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(sequence_pool, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::XPUSequencePoolCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/sequence_pool_compute.h b/lite/kernels/xpu/sequence_pool_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..232634de0e387e764eccdeeda4cb8fd2d5dce598 --- /dev/null +++ b/lite/kernels/xpu/sequence_pool_compute.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
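The sequence_pool kernel above supports two pooling modes over each LoD sequence, "MAX" and "LAST". The CPU reference below (illustrative only, not part of this patch; the LoD and data are invented) spells out the semantics the XPU call implements, with `lod` holding row offsets and `dim` the feature width:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> lod{0, 2, 5};  // 2 sequences: rows [0,2) and [2,5)
  int dim = 2;
  std::vector<float> in{1, 9, 4, 2, 3, 8, 5, 7, 6, 0};  // 5 rows x 2 cols
  int num_seq = lod.size() - 1;
  std::vector<float> max_out(num_seq * dim), last_out(num_seq * dim);
  for (int s = 0; s < num_seq; ++s) {
    for (int d = 0; d < dim; ++d) {
      float m = in[lod[s] * dim + d];
      for (int r = lod[s]; r < lod[s + 1]; ++r)
        m = std::max(m, in[r * dim + d]);           // pool_type == "MAX"
      max_out[s * dim + d] = m;
      last_out[s * dim + d] = in[(lod[s + 1] - 1) * dim + d];  // "LAST"
    }
  }
  std::printf("max[0]=(%g,%g) last[1]=(%g,%g)\n",
              max_out[0], max_out[1], last_out[2], last_out[3]);  // (4,9) (6,0)
  return 0;
}
```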
+ +#pragma once + +#include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class XPUSequencePoolCompute + : public KernelLite { + public: + using param_t = operators::SequencePoolParam; + + void PrepareForRun() override; + + void Run() override; + + private: + XPUScratchPadGuard lod_xpu_guard_; + + std::unique_ptr lod_cpu; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/sequence_reverse_compute.cc b/lite/kernels/xpu/sequence_reverse_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..11e4b80570c19fa90e7846d18a88f966f9a003b7 --- /dev/null +++ b/lite/kernels/xpu/sequence_reverse_compute.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/xpu/sequence_reverse_compute.h" +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +template +void SequenceReverseCompute::PrepareForRun() { + lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + lod_cpu.reset(new int[64]); +} + +template +void SequenceReverseCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto* x = param.X; + auto* y = param.Out; + + auto lod = x->lod()[0]; + size_t limit = x->numel(); + size_t ele_cnt_in_4_byte = limit / x->dims()[0]; + auto* x_data = x->template data(); + auto* y_data = y->template mutable_data(TARGET(kXPU)); + int batch_size = lod.size() - 1; + + if (std::is_same::value) { + ele_cnt_in_4_byte /= 4; + } else if (std::is_same::value) { + // remain the same + } else if (std::is_same::value) { + ele_cnt_in_4_byte *= 2; + } else if (std::is_same::value) { + // remain the same + } else if (std::is_same::value) { + ele_cnt_in_4_byte *= 2; + } + + for (size_t i = 0; i < lod.size(); ++i) { + lod_cpu[i] = lod[i]; + } + int* lod_xpu = reinterpret_cast(lod_xpu_guard_->addr_); + xpu_memcpy(lod_xpu, + lod_cpu.get(), + lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int r = xdnn::sequence_reverse(ctx.GetRawContext(), + batch_size, + lod_xpu, + ele_cnt_in_4_byte, + reinterpret_cast(x_data), + reinterpret_cast(y_data)); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +namespace xpu = paddle::lite::kernels::xpu; +using SequenceReverseFp32 = + xpu::SequenceReverseCompute; +using SequenceReverseInt64 = + xpu::SequenceReverseCompute; + +REGISTER_LITE_KERNEL( + sequence_reverse, kXPU, kFloat, kNCHW, SequenceReverseFp32, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + sequence_reverse, kXPU, kInt64, kNCHW, 
SequenceReverseInt64, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/xpu/sequence_reverse_compute.h b/lite/kernels/xpu/sequence_reverse_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..91b285de767c65f93352380df7877e53d61ccd53 --- /dev/null +++ b/lite/kernels/xpu/sequence_reverse_compute.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +template +class SequenceReverseCompute : public KernelLite { + public: + using param_t = operators::SequenceReverseParam; + + void PrepareForRun() override; + + void Run() override; + + private: + XPUScratchPadGuard lod_xpu_guard_; + std::unique_ptr lod_cpu; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc b/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..54c74211f9738995a8191c77e879a85762d71b3b --- /dev/null +++ b/lite/kernels/xpu/sequence_topk_avg_pooling_compute.cc @@ -0,0 +1,131 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
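The sequence_reverse kernel above reverses the order of rows within each LoD sequence while leaving each row's contents intact; because the device moves rows as blocks of 4-byte words, it rescales the row width for int64/double (x2) and char (/4) element types. A CPU reference for the row-mirroring semantics (illustrative only, not part of this patch; data is invented):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> lod{0, 3, 5};  // sequences: rows [0,3) and [3,5)
  int width = 2;                  // elements per row
  std::vector<float> x{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, y(x.size());
  for (size_t s = 0; s + 1 < lod.size(); ++s) {
    for (int r = lod[s]; r < lod[s + 1]; ++r) {
      int dst = lod[s] + (lod[s + 1] - 1 - r);  // mirrored row index
      std::copy(x.begin() + r * width, x.begin() + (r + 1) * width,
                y.begin() + dst * width);
    }
  }
  for (float v : y) std::printf("%g ", v);  // 4 5 2 3 0 1 8 9 6 7
  std::printf("\n");
  return 0;
}
```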
+ +#include "lite/kernels/xpu/sequence_topk_avg_pooling_compute.h" +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void SequenceTopkAvgPoolingCompute::PrepareForRun() { + lod_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(256 * sizeof(int)); + in_lod_cpu.reset(new int[64]); + row_lod_cpu.reset(new int[64]); + col_lod_cpu.reset(new int[64]); +} + +void SequenceTopkAvgPoolingCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto* in = param.X; + auto* row = param.ROW; + auto* col = param.COLUMN; + auto* out = param.Out; + auto* pos = param.pos; + + auto channel_num = param.channel_num; + auto topks = param.topks; + auto k_num = topks.size(); + auto max_k = topks[topks.size() - 1]; + auto in_lod = in->lod()[0]; + + auto row_lod = row->lod()[0]; + auto col_lod = col->lod()[0]; + int batch_size = row_lod.size() - 1; + int pos_total_size = row_lod[batch_size] * channel_num * max_k; + std::vector vec_pos_shape; + vec_pos_shape.push_back(pos_total_size); + pos->Resize(vec_pos_shape); + auto pos_data = pos->mutable_data(TARGET(kXPU)); + + int offset = 0; + std::vector vec_out_lod; + vec_out_lod.reserve(batch_size + 1); + for (int i = 0; i <= batch_size; ++i) { + offset = row_lod[i]; + vec_out_lod.push_back(offset); + } + LoD lod_temp; + lod_temp.push_back(vec_out_lod); + out->set_lod(lod_temp); + + auto in_data = in->data(); + auto out_data = out->mutable_data(TARGET(kXPU)); + + int* in_lod_xpu = reinterpret_cast(lod_xpu_guard_->addr_); + int* row_lod_xpu = in_lod_xpu + in_lod.size(); + int* col_lod_xpu = row_lod_xpu + row_lod.size(); + int* topks_xpu = col_lod_xpu + col_lod.size(); + for (int i = 0; i < in_lod.size(); ++i) { + in_lod_cpu[i] = in_lod[i]; + } + for (int i = 0; i < row_lod.size(); ++i) { + row_lod_cpu[i] = row_lod[i]; + } + for (int i = 0; i < col_lod.size(); ++i) { + col_lod_cpu[i] = col_lod[i]; + } + xpu_memcpy(in_lod_xpu, + in_lod_cpu.get(), + in_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(row_lod_xpu, + row_lod_cpu.get(), + row_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(col_lod_xpu, + col_lod_cpu.get(), + col_lod.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(topks_xpu, + topks.data(), + topks.size() * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int r = xdnn::sequence_topk_avg_pooling(ctx.GetRawContext(), + in_data, + out_data, + pos_data, + batch_size, + channel_num, + in_lod_xpu, + row_lod_xpu, + col_lod_xpu, + topks_xpu, + k_num); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(sequence_topk_avg_pooling, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::SequenceTopkAvgPoolingCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("ROW", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("pos", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/sequence_topk_avg_pooling_compute.h b/lite/kernels/xpu/sequence_topk_avg_pooling_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..7c54ca96225ee9ec37d6d0487a526347c19fdb2d --- /dev/null +++ b/lite/kernels/xpu/sequence_topk_avg_pooling_compute.h @@ -0,0 
+1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class SequenceTopkAvgPoolingCompute + : public KernelLite { + public: + using param_t = operators::SequenceTopkAvgPoolingParam; + + void PrepareForRun() override; + + void Run() override; + + private: + XPUScratchPadGuard lod_xpu_guard_; + std::unique_ptr in_lod_cpu; + std::unique_ptr row_lod_cpu; + std::unique_ptr col_lod_cpu; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/xpu/subgraph_compute.cc b/lite/kernels/xpu/subgraph_compute.cc index 9c2191331c85a7f99ffb5a2e9662ed5831cb1dda..981922f8eacab57da4638e1fdcdd3df72465b379 100644 --- a/lite/kernels/xpu/subgraph_compute.cc +++ b/lite/kernels/xpu/subgraph_compute.cc @@ -27,12 +27,35 @@ namespace lite { namespace kernels { namespace xpu { -int SubgraphEngine::BuildDeviceProgram() { +bool SubgraphEngine::PrepareWorkspaceForDeviceProgram() { + // Obtain the origin input tensors, and create the origin output + // tensors(Don't try to access them before launch the device program or the + // origin program) + PrepareWorkspaceForOriginProgram(); + // Create the device input and output tensors, but don't initialize them + // with the dimensions + device_itensors_.resize(input_names_.size()); + for (int i = 0; i < input_names_.size(); i++) { + device_itensors_[i].reset(new hiai::AiTensor); + CHECK(device_itensors_[i]); + } + device_otensors_.resize(output_names_.size()); + for (int i = 0; i < output_names_.size(); i++) { + device_otensors_[i].reset(new hiai::AiTensor); + CHECK(device_otensors_[i]); + } + return true; +} + +bool SubgraphEngine::BuildDeviceProgram() { int status = 0; // Convert all of ops and their input vars and weights and added into the XPU // IR graph subgraph::xpu::Graph graph; const auto& bridges = subgraph::Registry::Instance(); + if (origin_program_.empty()) { + BuildOriginProgram(); + } for (auto& inst : origin_program_) { auto op = const_cast(inst.op()); CHECK(op); @@ -40,13 +63,13 @@ int SubgraphEngine::BuildDeviceProgram() { op->InferShape(); std::string op_type = op->op_info()->Type(); if (!bridges.Exists(op_type, TARGET(kXPU))) { - return subgraph::FAILED; + return false; } auto kernel = inst.kernel(); status |= bridges.Select(op_type, TARGET(kXPU))( reinterpret_cast(&graph), op, const_cast(kernel)); if (subgraph::CHECK_FAILED(status)) { - return subgraph::FAILED; + return false; } } // Obtain the output nodes of the XPU IR graph and build the graph to the XPU @@ -86,7 +109,7 @@ int SubgraphEngine::BuildDeviceProgram() { &graph.builder_, &graph.params_, &device_onodes); if (device_program_ == nullptr) { LOG(WARNING) << "[XPU] Build model failed!"; - return subgraph::FAILED; + return false; } // Query and check the 
dimensions of input and output tensors @@ -166,10 +189,10 @@ int SubgraphEngine::BuildDeviceProgram() { device_otensors_[i].strides = nullptr; device_otensors_[i].byte_offset = 0; } - return status; + return true; } -int SubgraphEngine::LaunchDeviceProgram() { +bool SubgraphEngine::LaunchDeviceProgram() { for (size_t i = 0; i < device_itensors_.size(); i++) { // Update the data pointer of DLTensor to track the origin input tensors device_itensors_[i].data = @@ -191,7 +214,7 @@ int SubgraphEngine::LaunchDeviceProgram() { const_cast(origin_otensors_[i]->raw_data()); device_program_->CopyOutputTo(i, &device_otensors_[i]); } - return 0; + return true; } void SubgraphCompute::PrepareForRun() { @@ -203,12 +226,11 @@ void SubgraphCompute::PrepareForRun() { param.output_data_names, param.scope)); CHECK(engine_); - engine_->Build(); } void SubgraphCompute::Run() { CHECK(engine_); - engine_->Launch(); + engine_->Run(); } } // namespace xpu diff --git a/lite/kernels/xpu/subgraph_compute.h b/lite/kernels/xpu/subgraph_compute.h index 601c8821bc826e350c233573bf7eff89cdf5c1f5..f09a06a85d5382c72e9efb20cede8bea1922f2da 100644 --- a/lite/kernels/xpu/subgraph_compute.h +++ b/lite/kernels/xpu/subgraph_compute.h @@ -39,13 +39,14 @@ class SubgraphEngine : public subgraph::Engine { ctx, block_idx, block_desc, input_names, output_names, scope) {} protected: - int BuildDeviceProgram() override; - int LaunchDeviceProgram() override; + bool PrepareWorkspaceForDeviceProgram() override; + bool BuildDeviceProgram() override; + bool LaunchDeviceProgram() override; std::vector device_inames_; std::vector device_onames_; - std::vector device_itensors_; - std::vector device_otensors_; + std::vector device_itensors_{}; + std::vector device_otensors_{}; std::unique_ptr device_program_{nullptr}; }; diff --git a/lite/kernels/xpu/var_conv_2d_compute.cc b/lite/kernels/xpu/var_conv_2d_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..b573c810922db98e901c9f9a1953116f3fdfc657 --- /dev/null +++ b/lite/kernels/xpu/var_conv_2d_compute.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
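The var_conv_2d kernel that follows handles per-sample spatial extents taken from the input LoD: each sample gets its own output size, `ceil(extent / stride)` per dimension, accumulated into the `top_offset` level of the output LoD. A host-side sketch of that bookkeeping (illustrative only, not part of this patch; the per-sample extents and channel count are invented):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Per-sample widths/heights recovered from the offset_x / offset_y LoD levels.
  std::vector<int> widths{7, 1, 0}, heights{5, 3, 4};
  int stride_w = 2, stride_h = 2, output_channel = 8;
  std::vector<int> top_offset{0};
  for (size_t b = 0; b < widths.size(); ++b) {
    // (extent - 1) / stride + 1 == ceil(extent / stride); empty extent -> 0.
    int top_im_x = widths[b] ? (widths[b] - 1) / stride_w + 1 : 0;
    int top_im_y = heights[b] ? (heights[b] - 1) / stride_h + 1 : 0;
    top_offset.push_back(top_offset.back() +
                         output_channel * top_im_x * top_im_y);
  }
  // top_offset becomes level 0 of the output LoD, one entry per sample.
  for (int v : top_offset) std::printf("%d ", v);  // 0 96 112 112
  std::printf("\n");
  return 0;
}
```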
+ +#include "lite/kernels/xpu/var_conv_2d_compute.h" +#include +#include "lite/backends/xpu/xpu_header_sitter.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +void VarConv2DCompute::PrepareForRun() { + offset_x_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + offset_y_xpu_guard_ = TargetWrapperXPU::MallocScratchPad(64 * sizeof(int)); + offset_x_cpu.reset(new int[64]); + offset_y_cpu.reset(new int[64]); +} + +void VarConv2DCompute::Run() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); + + auto* bottom = param.X; + auto* w = param.W; + auto* top = param.Out; + + int output_channel = param.output_channel; + int input_channel = param.input_channel; + int kernel_h = param.kernel_h; + int kernel_w = param.kernel_w; + int stride_h = param.stride_h; + int stride_w = param.stride_w; + float w_max = param.__xpu__w_max; + bool fuse_relu = param.fuse_relu; + bool float_to_fix = param.__xpu__float_to_fix; + CHECK(float_to_fix) << "W should be fixed point"; + + xdnn::Activation_t act = xdnn::Activation_t::LINEAR; + if (fuse_relu) { + act = xdnn::Activation_t::RELU; + } + + int batch = bottom->lod()[0].size() - 1; + const auto& offset_x = bottom->lod()[2]; + const auto& offset_y = bottom->lod()[1]; + std::vector top_offset; + int top_size = 0; + top_offset.push_back(top_size); + for (int b = 0; b < batch; ++b) { + int width = offset_x[b + 1] - offset_x[b]; + int height = offset_y[b + 1] - offset_y[b]; + int top_im_x = 0; + int top_im_y = 0; + if (width != 0) { + top_im_x = (width - 1) / stride_w + 1; + } + if (height != 0) { + top_im_y = (height - 1) / stride_h + 1; + } + int top_im_size = top_im_y * top_im_x; + top_size += output_channel * top_im_size; + top_offset.push_back(top_size); + } + + LoD top_lod; + top_lod.push_back(top_offset); + top_lod.push_back(bottom->lod()[1]); + top_lod.push_back(bottom->lod()[2]); + top->set_lod(top_lod); + std::vector top_dims_vec{top_size}; + top_dims_vec.push_back(1); + top->Resize(top_dims_vec); + auto* top_data = top->mutable_data(TARGET(kXPU)); + + auto* bottom_data = bottom->data(); + auto* w_data = w->data(); + + int* offset_x_xpu = reinterpret_cast(offset_x_xpu_guard_->addr_); + int* offset_y_xpu = reinterpret_cast(offset_y_xpu_guard_->addr_); + for (int i = 0; i < (batch + 1); ++i) { + offset_x_cpu[i] = offset_x[i]; + offset_y_cpu[i] = offset_y[i]; + } + xpu_memcpy(offset_x_xpu, + offset_x_cpu.get(), + (batch + 1) * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + xpu_memcpy(offset_y_xpu, + offset_y_cpu.get(), + (batch + 1) * sizeof(int), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + + int r = xdnn::search_varconv(ctx.GetRawContext(), + batch, + input_channel, + output_channel, + kernel_h, + kernel_w, + stride_h, + stride_w, + bottom_data, + w_data, + offset_x_xpu, + offset_y_xpu, + top_data, + w_max, + act); + CHECK_EQ(r, 0); +} + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(var_conv_2d, + kXPU, + kFloat, + kNCHW, + paddle::lite::kernels::xpu::VarConv2DCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))}) + .BindOutput("Col", {LiteType::GetTensorTy(TARGET(kXPU))}) + .Finalize(); diff --git a/lite/kernels/xpu/var_conv_2d_compute.h b/lite/kernels/xpu/var_conv_2d_compute.h new file mode 100644 index 
0000000000000000000000000000000000000000..4d9f0ca7a9851a0c3071e72519c4ad1f40ea3483 --- /dev/null +++ b/lite/kernels/xpu/var_conv_2d_compute.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "lite/backends/xpu/target_wrapper.h" // XPUScratchPadGuard +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace xpu { + +class VarConv2DCompute : public KernelLite { + public: + using param_t = operators::VarConv2DParam; + + void PrepareForRun() override; + + void Run() override; + + private: + XPUScratchPadGuard offset_x_xpu_guard_; + XPUScratchPadGuard offset_y_xpu_guard_; + std::unique_ptr offset_x_cpu; + std::unique_ptr offset_y_cpu; +}; + +} // namespace xpu +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/model_parser/base/block_desc.h b/lite/model_parser/base/block_desc.h index 3fd7998aa392034173f7474bc6b4d106f9fbcbd4..b3d2e2452714d474e9d6bc9280cb2c5455fecc98 100644 --- a/lite/model_parser/base/block_desc.h +++ b/lite/model_parser/base/block_desc.h @@ -54,10 +54,16 @@ class BlockDescWriteAPI { virtual void SetForwardBlockIdx(int32_t idx) { NotImplemented(); } template - T* AddVar(); + T* AddVar() { + NotImplemented(); + return nullptr; + } template - T* AddOp(); + T* AddOp() { + NotImplemented(); + return nullptr; + } virtual ~BlockDescWriteAPI() = default; diff --git a/lite/model_parser/base/op_desc.h b/lite/model_parser/base/op_desc.h index 185f5917c46127de1e16e274d0be95073b1a37f6..534ff0feabd2234b4d7a72894383020a5f64d594 100644 --- a/lite/model_parser/base/op_desc.h +++ b/lite/model_parser/base/op_desc.h @@ -73,7 +73,9 @@ class OpDescWriteAPI { } template - void SetAttr(const std::string& name, const T& v); + void SetAttr(const std::string& name, const T& v) { + NotImplemented(); + } virtual ~OpDescWriteAPI() = default; diff --git a/lite/model_parser/base/program_desc.h b/lite/model_parser/base/program_desc.h index c4423f288d8ea90039ffad0db08342b594415fe6..9ca128bd0aa8ba39752247074e8d57c0d23513f3 100644 --- a/lite/model_parser/base/program_desc.h +++ b/lite/model_parser/base/program_desc.h @@ -40,7 +40,10 @@ class ProgramDescWriteAPI { virtual void SetVersion(int64_t version) { NotImplemented(); } template - T* AddBlock(); + T* AddBlock() { + NotImplemented(); + return nullptr; + } virtual ~ProgramDescWriteAPI() = default; diff --git a/lite/model_parser/base/vector_view.h b/lite/model_parser/base/vector_view.h index c6337faa403a2c9a2758b90a4c1f7d092554b0b2..adec1933a2f40face415f610c9ccf2e9f275020c 100644 --- a/lite/model_parser/base/vector_view.h +++ b/lite/model_parser/base/vector_view.h @@ -57,6 +57,7 @@ class VectorView { public: typedef vector_view::VectorTraits Traits; explicit VectorView(typename Traits::vector_type const* cvec) { + CHECK(cvec); cvec_ = cvec; } typename Traits::subscript_return_type operator[](size_t i) const { diff --git a/lite/model_parser/compatible_pb.cc 
diff --git a/lite/model_parser/compatible_pb.cc b/lite/model_parser/compatible_pb.cc
index 3d66a5234994036397e445744499696909a8ab3e..b8db89230d56e22a361cc4972382d74b8d6f08fd 100644
--- a/lite/model_parser/compatible_pb.cc
+++ b/lite/model_parser/compatible_pb.cc
@@ -277,7 +277,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
   template <>                                               \
   void TransformProgramDescCppToAny(const cpp::T &cpp_desc, \
                                     NT::T *any_desc) {      \
-    auto desc = cpp_desc;                                   \
+    auto &desc = cpp_desc;                                  \
     if (desc.HasVersion()) {                                \
       any_desc->SetVersion(desc.Version());                 \
     }                                                       \
diff --git a/lite/model_parser/flatbuffers/CMakeLists.txt b/lite/model_parser/flatbuffers/CMakeLists.txt
index 5ca669bfeb512de47f3a15eb7119f12487accc8a..b7ae9514efaa406d6b339c7917ad3dc2ad4a1f4f 100644
--- a/lite/model_parser/flatbuffers/CMakeLists.txt
+++ b/lite/model_parser/flatbuffers/CMakeLists.txt
@@ -8,9 +8,6 @@ endfunction()
 lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
 lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
 lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
-lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header)
-
-lite_cc_test(test_vector_view SRCS vector_view_test.cc)
-if (TARGET test_vector_view)
-  add_dependencies(test_vector_view framework_fbs_header)
-endif()
+lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
+lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc)
+lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
diff --git a/lite/model_parser/flatbuffers/block_desc.cc b/lite/model_parser/flatbuffers/block_desc.cc
index fc43af6d6273c845f00e2046ae846f044659fe57..64087bb0707a891cc94a2d1234bb582312c3c10a 100644
--- a/lite/model_parser/flatbuffers/block_desc.cc
+++ b/lite/model_parser/flatbuffers/block_desc.cc
@@ -19,15 +19,27 @@ namespace lite {
 namespace fbs {
 
 template <>
-proto::VarDesc* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) {
+proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const {
   CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
-  return const_cast<proto::VarDesc*>(desc_->vars()->Get(idx));
+  return desc_->vars()->Get(idx);
 }
 
 template <>
-proto::OpDesc* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) {
+proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const {
   CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
-  return const_cast<proto::OpDesc*>(desc_->ops()->Get(idx));
+  return desc_->ops()->Get(idx);
+}
+
+template <>
+VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
+  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
+  return &vars_[idx];
+}
+
+template <>
+OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
+  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
+  return &ops_[idx];
 }
 
 }  // namespace fbs
diff --git a/lite/model_parser/flatbuffers/block_desc.h b/lite/model_parser/flatbuffers/block_desc.h
index 0bfef5a452051c37e31f9d2c6ab2504e9addd800..dd99bdaa69020823ad6ca50438f21356eae41459 100644
--- a/lite/model_parser/flatbuffers/block_desc.h
+++ b/lite/model_parser/flatbuffers/block_desc.h
@@ -14,8 +14,11 @@
 
 #pragma once
 
+#include <vector>
 #include "lite/model_parser/base/block_desc.h"
 #include "lite/model_parser/flatbuffers/framework_generated.h"
+#include "lite/model_parser/flatbuffers/op_desc.h"
+#include "lite/model_parser/flatbuffers/var_desc.h"
 #include "lite/utils/all.h"
 
 namespace paddle {
@@ -24,7 +27,17 @@ namespace fbs {
 
 class BlockDesc : public BlockDescAPI {
  public:
-  explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); }
+  explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) {
+    CHECK(desc_);
+    vars_.reserve(VarsSize());
+    ops_.reserve(OpsSize());
+    for (size_t idx = 0; idx < VarsSize(); ++idx) {
+      vars_.push_back(VarDesc(desc_->vars()->Get(idx)));
+    }
+    for (size_t idx = 0; idx < OpsSize(); ++idx) {
+      ops_.push_back(OpDesc(desc_->ops()->Get(idx)));
+    }
+  }
 
   int32_t Idx() const override { return desc_->idx(); }
 
@@ -33,11 +46,12 @@ class BlockDesc : public BlockDescAPI {
   size_t VarsSize() const override { return desc_->vars()->size(); }
 
   template <typename T>
-  T* GetVar(int32_t idx);
+  T const* GetVar(int32_t idx) const;
 
   template <typename T>
-  T const* GetVar(int32_t idx) const {
-    return GetVar<T>(idx);
+  T* GetVar(int32_t idx) {
+    NotImplemented();
+    return nullptr;
   }
 
   size_t OpsSize() const override {
@@ -47,21 +61,32 @@ class BlockDesc : public BlockDescAPI {
   }
 
   template <typename T>
-  T* GetOp(int32_t idx);
+  T const* GetOp(int32_t idx) const;
 
   template <typename T>
-  T const* GetOp(int32_t idx) const {
-    return GetOp<T>(idx);
+  T* GetOp(int32_t idx) {
+    NotImplemented();
+    return nullptr;
   }
 
+  const std::vector<VarDesc>& GetVars() const { return vars_; }
+
   int32_t ForwardBlockIdx() const override {
     return desc_->forward_block_idx();
   }
 
-  BlockDesc() = delete;
+  BlockDesc() { NotImplemented(); }
 
  private:
-  proto::BlockDesc* desc_;  // not_own
+  proto::BlockDesc const* desc_;  // not_own
+  std::vector<VarDesc> vars_;
+  std::vector<OpDesc> ops_;
+
+ private:
+  void NotImplemented() const {
+    LOG(FATAL) << "The additional interfaces of BlockDesc are temporarily "
+                  "unavailable in read-only mode.";
+  }
 };
 
 }  // namespace fbs
diff --git a/lite/model_parser/flatbuffers/io.cc b/lite/model_parser/flatbuffers/io.cc
new file mode 100644
index 0000000000000000000000000000000000000000..28fa32398cfe76075c1a429f9f1d348842465dfc
--- /dev/null
+++ b/lite/model_parser/flatbuffers/io.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/model_parser/flatbuffers/io.h"
+#include <cstdio>
+#include <memory>
+
+namespace paddle {
+namespace lite {
+namespace fbs {
+
+void LoadModel(const std::string& path, ProgramDesc* prog) {
+  CHECK(prog);
+  FILE* file = fopen(path.c_str(), "rb");
+  CHECK(file) << "Unable to open model file: " << path;
+  fseek(file, 0, SEEK_END);
+  int64_t size = ftell(file);
+  rewind(file);
+  char* data = new char[size];
+  size = fread(data, 1, size, file);
+  fclose(file);
+  std::unique_ptr<char[]> buf(data);
+  prog->Init(std::move(buf), size);
+}
+
+}  // namespace fbs
+}  // namespace lite
+}  // namespace paddle
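A minimal usage sketch of the new loading path (the driver function and model path are illustrative, not part of this patch):

#include "lite/model_parser/flatbuffers/io.h"
#include "lite/utils/cp_logging.h"

// Hypothetical driver: read a flatbuffers-serialized program from disk and
// report its shape. LoadModel CHECK-fails on a missing file, so no extra
// error handling is sketched here.
void LoadAndReport() {
  paddle::lite::fbs::ProgramDesc prog;
  paddle::lite::fbs::LoadModel("/path/to/model.fbs", &prog);
  LOG(INFO) << "blocks: " << prog.BlocksSize();
}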
diff --git a/lite/model_parser/flatbuffers/io.h b/lite/model_parser/flatbuffers/io.h
new file mode 100644
index 0000000000000000000000000000000000000000..1c81b192bbbcfc026bc4a2e77225c9a4c68208f3
--- /dev/null
+++ b/lite/model_parser/flatbuffers/io.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "lite/model_parser/flatbuffers/program_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace fbs {
+
+void LoadModel(const std::string& path, ProgramDesc* prog);
+
+}  // namespace fbs
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/model_parser/flatbuffers/op_desc.h b/lite/model_parser/flatbuffers/op_desc.h
index b2d78ca68af3d2f0595e710d9c0f75d8cceefbb3..e133ffbc27dce1a8c00eed82cc6d4fca76a8564d 100644
--- a/lite/model_parser/flatbuffers/op_desc.h
+++ b/lite/model_parser/flatbuffers/op_desc.h
@@ -30,7 +30,7 @@ namespace fbs {
 
 class OpDesc : public OpDescAPI {
  public:
-  explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); }
+  explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
 
   std::string Type() const override { return desc_->type()->str(); }
 
@@ -95,7 +95,7 @@ class OpDesc : public OpDescAPI {
 
   OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
     const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
-    CHECK(attr);
+    CHECK(attr) << "Can not find attr: " << name;
     return static_cast<OpDescAPI::AttrType>(attr->type());
   }
 
@@ -124,10 +124,8 @@ class OpDesc : public OpDescAPI {
   template <typename T>
   typename lite::OpDataTypeTrait<T>::RT GetAttr(size_t idx) const;
 
-  OpDesc() = delete;
-
  private:
-  proto::OpDesc* desc_;
+  proto::OpDesc const* desc_;
 
   // To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
   // and flatbuffers::Desc replace each other. However, there is no direct
@@ -138,6 +136,7 @@ class OpDesc : public OpDescAPI {
   // caused by different building options.
 
  public:
+  OpDesc() { NotImplemented(); }
   bool HasInput(const std::string& param) const {
     return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
   }
diff --git a/lite/model_parser/flatbuffers/program_desc.cc b/lite/model_parser/flatbuffers/program_desc.cc
index 36429103a72f7b54651aac8d30671f7b3c41956e..f04954a9dc890a0b5866a7e6c3f3c7b18f2783e4 100644
--- a/lite/model_parser/flatbuffers/program_desc.cc
+++ b/lite/model_parser/flatbuffers/program_desc.cc
@@ -19,9 +19,16 @@ namespace lite {
 namespace fbs {
 
 template <>
-proto::BlockDesc* ProgramDesc::GetBlock<proto::BlockDesc>(int32_t idx) {
+proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>(
+    int32_t idx) const {
   CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
-  return const_cast<proto::BlockDesc*>(desc_->blocks()->Get(idx));
+  return desc_->blocks()->Get(idx);
+}
+
+template <>
+BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
+  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
+  return &blocks_[idx];
 }
 
 }  // namespace fbs
diff --git a/lite/model_parser/flatbuffers/program_desc.h b/lite/model_parser/flatbuffers/program_desc.h
index f41fd996b2533321c2494ea6c15d53ed31a3e7c8..c651d9dc0671aced942bb28466e829a40226c2ba 100644
--- a/lite/model_parser/flatbuffers/program_desc.h
+++ b/lite/model_parser/flatbuffers/program_desc.h
@@ -15,7 +15,10 @@
 #pragma once
 
 #include <memory>
+#include <cstring>
+#include <vector>
 #include "lite/model_parser/base/program_desc.h"
+#include "lite/model_parser/flatbuffers/block_desc.h"
 #include "lite/model_parser/flatbuffers/framework_generated.h"
 #include "lite/utils/all.h"
 
@@ -26,18 +29,40 @@ namespace fbs {
 class ProgramDesc : public ProgramDescAPI {
  public:
   ProgramDesc() = default;
-  explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); }
+  explicit ProgramDesc(std::unique_ptr<char[]> buf, size_t size) {
+    Init(std::move(buf), size);
+  }
 
   size_t BlocksSize() const override { return desc_->blocks()->size(); }
 
+  void Init(std::unique_ptr<char[]> buf, size_t size) {
+    CHECK(buf.get() != nullptr);
+    buf_ = std::move(buf);
+    buf_size_ = size;
+    desc_ = proto::GetProgramDesc(buf_.get());
+    blocks_.reserve(BlocksSize());
+    for (size_t idx = 0; idx < BlocksSize(); ++idx) {
+      blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
+    }
+  }
+
+  void CopyFrom(const ProgramDesc& other) {
+    // The buffer holds binary flatbuffers data (embedded NUL bytes are
+    // normal), so its length is tracked explicitly instead of being
+    // recovered with strlen().
+    std::unique_ptr<char[]> buf(new char[other.buf_size_]);
+    memcpy(buf.get(), other.raw_buf(), other.buf_size_);
+    Init(std::move(buf), other.buf_size_);
+  }
+
   template <typename T>
-  T *GetBlock(int32_t idx);
+  T const* GetBlock(int32_t idx) const;
 
   template <typename T>
-  T const *GetBlock(int32_t idx) const {
-    return GetBlock<T>(idx);
+  T* GetBlock(int32_t idx) {
+    NotImplemented();
+    return nullptr;
   }
 
+  const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }
+
   bool HasVersion() const override { return desc_->version() != nullptr; }
 
   int64_t Version() const override {
@@ -45,8 +70,22 @@ class ProgramDesc : public ProgramDescAPI {
     return desc_->version()->version();
   }
 
+  proto::ProgramDesc const* raw_desc() const { return desc_; }
+
+  const void* raw_buf() const { return buf_.get(); }
+
 private:
-  proto::ProgramDesc *desc_;  // not_own
+  proto::ProgramDesc const* desc_{nullptr};
+  std::unique_ptr<char[]> buf_;
+  size_t buf_size_{0};
+  std::vector<BlockDesc> blocks_;
+
+ private:
+  ProgramDesc& operator=(const ProgramDesc&) = delete;
+  ProgramDesc(const ProgramDesc&) = delete;
+  void NotImplemented() const {
+    LOG(FATAL) << "The additional interfaces of ProgramDesc are temporarily "
+                  "unavailable in read-only mode.";
+  }
 };
 
 }  // namespace fbs
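A short sketch of how the const accessors compose once a program has been initialized (the helper is hypothetical); the mutable `GetBlock`/`GetVar`/`GetOp` overloads deliberately `LOG(FATAL)` in read-only mode, so only const access appears here.

#include "lite/model_parser/flatbuffers/program_desc.h"
#include "lite/utils/cp_logging.h"

// Hypothetical helper: walk every block of a read-only program and dump the
// variable names it declares.
void DumpProgram(const paddle::lite::fbs::ProgramDesc& prog) {
  for (const auto& block : prog.GetBlocks()) {
    for (const auto& var : block.GetVars()) {
      VLOG(4) << "block " << block.Idx() << " var " << var.Name();
    }
  }
}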
diff --git a/lite/model_parser/flatbuffers/var_desc.h b/lite/model_parser/flatbuffers/var_desc.h
index 387e52ec3150e5bc01f365934c310fb1990ce1e4..48d81df30f78ca668bbe9358b4f488fd2f4d3d66 100644
--- a/lite/model_parser/flatbuffers/var_desc.h
+++ b/lite/model_parser/flatbuffers/var_desc.h
@@ -27,7 +27,7 @@ namespace fbs {
 
 class VarDesc : public VarDescAPI {
  public:
-  explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {}
+  explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {}
 
   std::string Name() const override { return desc_->name()->str(); }
 
@@ -48,10 +48,14 @@ class VarDesc : public VarDescAPI {
     return dims_vec;
   }
 
-  VarDesc() = delete;
+  VarDescAPI::Type GetDataType() const {
+    CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
+    return static_cast<VarDescAPI::Type>(
+        desc_->type()->lod_tensor()->tensor()->data_type());
+  }
 
  private:
-  proto::VarDesc* desc_;
+  proto::VarDesc const* desc_;
 
   // To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
   // and flatbuffers::Desc replace each other. However, there is no direct
@@ -62,10 +66,7 @@ class VarDesc : public VarDescAPI {
   // caused by different building options.
 
  public:
-  VarDescAPI::Type GetDataType() const {
-    NotImplemented();
-    return data_type_;
-  }
+  VarDesc() { NotImplemented(); }
   void SetDataType(Type data_type) { NotImplemented(); }
   void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
 
@@ -74,7 +75,6 @@ class VarDesc : public VarDescAPI {
     LOG(FATAL) << "The additional interfaces of VarDesc are temporarily "
                   "unavailable in read-only mode.";
   }
-  Type data_type_;
   std::vector<int64_t> shape_;
 };
diff --git a/lite/model_parser/flatbuffers/vector_view.h b/lite/model_parser/flatbuffers/vector_view.h
index ccb700072690c3ecfe55549a1f39d3d574686c7d..1cc890e98d2a85b3113fcf49a68701595e63964e 100644
--- a/lite/model_parser/flatbuffers/vector_view.h
+++ b/lite/model_parser/flatbuffers/vector_view.h
@@ -104,20 +104,32 @@ class VectorView<std::string, Flatbuffers> {
   explicit VectorView(typename Traits::vector_type const* cvec) {
     cvec_ = cvec;
   }
-  std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); }
+  std::string operator[](size_t i) const {
+    CHECK(cvec_);
+    return cvec_->operator[](i)->str();
+  }
   vector_view::FBSStrIterator begin() const {
+    CHECK(cvec_);
     return vector_view::FBSStrIterator(cvec_->begin());
   }
   vector_view::FBSStrIterator end() const {
+    CHECK(cvec_);
     return vector_view::FBSStrIterator(cvec_->end());
   }
-  size_t size() const { return cvec_->size(); }
+  size_t size() const {
+    if (cvec_ == nullptr) {
+      return 0;
+    }
+    return cvec_->size();
+  }
   operator std::vector<std::string>() const {
     VLOG(5) << "Copying elements out of VectorView will damage performance.";
     std::vector<std::string> tmp;
-    tmp.reserve(cvec_->size());
-    for (auto val : *cvec_) {
-      tmp.push_back(val->str());
+    tmp.reserve(size());
+    if (cvec_ != nullptr) {
+      for (auto val : *cvec_) {
+        tmp.push_back(val->str());
+      }
     }
     return tmp;
   }
diff --git a/lite/model_parser/general/block_desc.cc b/lite/model_parser/general/block_desc.cc
index 0766333d66c1299b738098a33a1a2c6433782337..11d2376bc05a6086036b0fd026666b0b16b2de84 100644
--- a/lite/model_parser/general/block_desc.cc
+++ b/lite/model_parser/general/block_desc.cc
@@ -24,6 +24,12 @@ VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) {
   return &vars_[idx];
 }
 
+template <>
+VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
+  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
+  return &vars_[idx];
+}
+
 template <>
 VarDesc* BlockDesc::AddVar<VarDesc>() {
   vars_.emplace_back();
@@ -36,6 +42,12 @@ OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) {
   return &ops_[idx];
 }
 
+template <>
+OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
+  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
+  return &ops_[idx];
+}
+
 template <>
 OpDesc* BlockDesc::AddOp<OpDesc>() {
   ops_.emplace_back();
diff --git a/lite/model_parser/general/block_desc.h b/lite/model_parser/general/block_desc.h
index 3b1b1ff4e6616c936bd3b09bff563656f6bdbc6a..e618e570c20bfb0915289d2da625865fc5b64676 100644
--- a/lite/model_parser/general/block_desc.h
+++ b/lite/model_parser/general/block_desc.h
@@ -46,12 +46,10 @@ class BlockDesc : public BlockDescAPI {
   template <typename T>
   T* GetVar(int32_t idx);
 
-  std::vector<VarDesc>& GetVars() { return vars_; }
-
   template <typename T>
-  T const* GetVar(int32_t idx) const {
-    return GetVar<T>(idx);
-  }
+  T const* GetVar(int32_t idx) const;
+
+  std::vector<VarDesc>& GetVars() { return vars_; }
 
   template <typename T>
   T* AddVar();
@@ -64,9 +62,7 @@ class BlockDesc : public BlockDescAPI {
   T* GetOp(int32_t idx);
 
   template <typename T>
-  T const* GetOp(int32_t idx) const {
-    return GetOp<T>(idx);
-  }
+  T const* GetOp(int32_t idx) const;
 
   template <typename T>
   T* AddOp();
diff --git a/lite/model_parser/general/program_desc.cc b/lite/model_parser/general/program_desc.cc
index 670c7684312265d5a1f1eb2cbef54ed5fe62b2d2..b767a6f77ca657e8ec02b8e182dd8a8b62b7d6ab 100644
--- a/lite/model_parser/general/program_desc.cc
+++ b/lite/model_parser/general/program_desc.cc
@@ -24,6 +24,12 @@ BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) {
   return &blocks_[idx];
 }
 
+template <>
+BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
+  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
+  return &blocks_[idx];
+}
+
 template <>
 BlockDesc* ProgramDesc::AddBlock<BlockDesc>() {
   blocks_.emplace_back();
diff --git a/lite/model_parser/general/program_desc.h b/lite/model_parser/general/program_desc.h
index 0fbc0742fe149075d3ede2b688fd071727baafc9..bbc045412d2086473375863575e5d16146d84751 100644
--- a/lite/model_parser/general/program_desc.h
+++ b/lite/model_parser/general/program_desc.h
@@ -30,6 +30,13 @@ class ProgramDesc : public ProgramDescAPI {
  public:
   ProgramDesc() = default;
 
+  void CopyFrom(const ProgramDesc& other) {
+    version_ = other.Version();
+    blocks_ = other.blocks();
+  }
+
+  const std::vector<BlockDesc>& blocks() const { return blocks_; }
+
   size_t BlocksSize() const override { return blocks_.size(); }
 
   void ClearBlocks() override { blocks_.clear(); }
@@ -37,12 +44,10 @@ class ProgramDesc : public ProgramDescAPI {
   template <typename T>
   T* GetBlock(int32_t idx);
 
-  std::vector<BlockDesc>& GetBlocks() { return blocks_; }
-
   template <typename T>
-  T const* GetBlock(int32_t idx) const {
-    return GetBlock<T>(idx);
-  }
+  T const* GetBlock(int32_t idx) const;
+
+  std::vector<BlockDesc>& GetBlocks() { return blocks_; }
 
   template <typename T>
   T* AddBlock();
diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc
index 640dd044174c831e4570c5e8cc81af02fa50f0c4..cf93e7f2cedc8db5c5a18d26fa2499dd79c456de 100644
--- a/lite/model_parser/model_parser.cc
+++ b/lite/model_parser/model_parser.cc
@@ -176,7 +176,7 @@ void LoadCombinedParamsPb(const std::string &path,
                           const cpp::ProgramDesc &cpp_prog,
                           bool params_from_memory) {
   CHECK(scope);
-  auto prog = cpp_prog;
+  auto &prog = cpp_prog;
   auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
 
   // Get vars
@@ -310,7 +310,7 @@ void SaveModelPb(const std::string &model_dir,
 void SaveCombinedParamsPb(const std::string &path,
                           const lite::Scope &exec_scope,
                           const cpp::ProgramDesc &cpp_prog) {
-  auto prog = cpp_prog;
+  auto &prog = cpp_prog;
   auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
 
   // Get vars
@@ -526,7 +526,7 @@ void SaveCombinedParamsNaive(const std::string &path,
   naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
   naive_buffer::CombinedParamsDesc desc(&pt_desc);
 
-  auto prog = cpp_prog;
+  auto &prog = cpp_prog;
   auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
   // set unique_var_names to avoid saving shared params repeatedly
   std::set<std::string> unique_var_names;
@@ -681,7 +681,7 @@ void LoadCombinedParamsNaive(const std::string &path,
   }
 
   // Check all params loaded
-  auto prog = cpp_prog;
+  auto &prog = cpp_prog;
   auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
   for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
     auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
diff --git a/lite/model_parser/naive_buffer/block_desc.h b/lite/model_parser/naive_buffer/block_desc.h
index 61c624d9593244a3e680b5541e32cd4aeee949d5..3f99302c4033f3f732e0c79017fc251c6d0c40b5 100644
--- a/lite/model_parser/naive_buffer/block_desc.h
+++ b/lite/model_parser/naive_buffer/block_desc.h
@@ -55,11 +55,6 @@ class BlockDesc : public BlockDescAPI {
   template <typename T>
   T* GetVar(int32_t idx);
 
-  template <typename T>
-  T const* GetVar(int32_t idx) const {
-    return GetVar<T>(idx);
-  }
-
   template <typename T>
   T* AddVar();
 
@@ -70,11 +65,6 @@ class BlockDesc : public BlockDescAPI {
   template <typename T>
   T* GetOp(int32_t idx);
 
-  template <typename T>
-  T const* GetOp(int32_t idx) const {
-    return GetOp<T>(idx);
-  }
-
   template <typename T>
   T* AddOp();
diff --git a/lite/model_parser/naive_buffer/program_desc.h b/lite/model_parser/naive_buffer/program_desc.h
index 1552b6bcdd7ea7f8efd3954e2625712a7684a5f2..6f5277ad32aa2fccf52134a262975cfdbe1b9d6c 100644
--- a/lite/model_parser/naive_buffer/program_desc.h
+++ b/lite/model_parser/naive_buffer/program_desc.h
@@ -45,11 +45,6 @@ class ProgramDesc : public ProgramDescAPI {
   template <typename T>
   T *GetBlock(int32_t idx);
 
-  template <typename T>
-  T const *GetBlock(int32_t idx) const {
-    return GetBlock<T>(idx);
-  }
-
   template <typename T>
   T *AddBlock();
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 6eeb5ed7e2ab65f2947d051f10c77dccf9a2eda9..45b49f91ace12da5934471e01afd91c2832f1d6d 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -168,6 +168,9 @@ add_operator(__xpu__resnet50_op extra SRCS __xpu__resnet50_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
+add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
+add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
+add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
 
 if (NOT LITE_WITH_X86)
   lite_cc_test(test_fc_op SRCS fc_op_test.cc
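The repeated `auto prog = cpp_prog;` to `auto &prog = cpp_prog;` changes above all fix the same pitfall: plain `auto` deduces a value type, so each call was deep-copying the whole `cpp::ProgramDesc`. A standalone sketch of the difference (types are illustrative):

#include <vector>

struct Desc { std::vector<int> blocks; };

void Consume(const Desc& d) {
  auto copy = d;   // deduces Desc: copies every block
  auto& view = d;  // deduces const Desc&: no copy at all
  (void)copy;
  (void)view;
}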
diff --git a/lite/operators/__xpu__mmdnn_op.cc b/lite/operators/__xpu__mmdnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..35024da911ba0659c5005a1adc641fa3adc2f282
--- /dev/null
+++ b/lite/operators/__xpu__mmdnn_op.cc
@@ -0,0 +1,239 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/__xpu__mmdnn_op.h"
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool XPUMmdnnBidEmbGrnnAttOp::CheckShape() const { return true; }
+
+bool XPUMmdnnBidEmbGrnnAttOp::InferShapeImpl() const {
+  auto& id_dims = param_.id0->dims();
+  auto& id_lod = param_.id0->lod()[0];
+  auto& emb_tbl_dims = param_.emb_tbl->dims();
+  auto& grnn_wh_dims = param_.grnn_rv_wh->dims();
+
+  param_.grnn_fw_pool_out->Resize(
+      {(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
+  param_.grnn_rv_pool_out->Resize(
+      {(int64_t)id_lod.size() - 1, grnn_wh_dims[2]});
+  param_.att_pool_out->Resize(
+      {(int64_t)id_lod.size() - 1, 2 * grnn_wh_dims[2]});
+  param_.concat_3in1_out->Resize({id_dims[0], 3 * grnn_wh_dims[2]});
+  param_.concat_3in1_out->set_lod({id_lod});
+  param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]});
+  param_.emb_fw_out->set_lod({id_lod});
+  return true;
+}
+
+bool XPUMmdnnBidEmbGrnnAttOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                         lite::Scope* scope) {
+  param_.id0 =
+      scope->FindVar(op_desc.Input("id0").front())->GetMutable<lite::Tensor>();
+  param_.id1 =
+      scope->FindVar(op_desc.Input("id1").front())->GetMutable<lite::Tensor>();
+  param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front())
+                       ->GetMutable<lite::Tensor>();
+  param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front())
+                        ->GetMutable<lite::Tensor>();
+  param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front())
+                        ->GetMutable<lite::Tensor>();
+
+  param_.grnn_fw_pool_out =
+      scope->FindVar(op_desc.Output("grnn_fw_pool_out").front())
+          ->GetMutable<lite::Tensor>();
+  param_.grnn_rv_pool_out =
+      scope->FindVar(op_desc.Output("grnn_rv_pool_out").front())
+          ->GetMutable<lite::Tensor>();
+  param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front())
+                            ->GetMutable<lite::Tensor>();
+  param_.concat_3in1_out =
+      scope->FindVar(op_desc.Output("concat_3in1_out").front())
+          ->GetMutable<lite::Tensor>();
+  param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front())
+                          ->GetMutable<lite::Tensor>();
+
+  param_.grnn_fw_wh_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_fw_wh_maxs");
+  param_.grnn_fw_wi_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_fw_wi_maxs");
+  param_.grnn_rv_wh_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_rv_wh_maxs");
+  param_.grnn_rv_wi_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_rv_wi_maxs");
+  param_.att_fc_w_max = op_desc.GetAttr<float>("att_fc_w_max");
+  return true;
+}
+
+bool XPUMmdnnBidEmbAttOp::CheckShape() const { return true; }
+
+bool XPUMmdnnBidEmbAttOp::InferShapeImpl() const {
+  auto& id_dims = param_.id0->dims();
+  auto& id_lod = param_.id0->lod()[0];
+  auto& emb_tbl_dims = param_.emb_tbl->dims();
+
+  param_.att_pool_out->Resize({(int64_t)id_lod.size() - 1, emb_tbl_dims[1]});
+  param_.emb_fw_out->Resize({id_dims[0], emb_tbl_dims[1]});
+  param_.emb_fw_out->set_lod({id_lod});
+  return true;
+}
+
+bool XPUMmdnnBidEmbAttOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                     lite::Scope* scope) {
+  param_.id0 =
+      scope->FindVar(op_desc.Input("id0").front())->GetMutable<lite::Tensor>();
+  param_.id1 =
+      scope->FindVar(op_desc.Input("id1").front())->GetMutable<lite::Tensor>();
+  param_.emb_tbl = scope->FindVar(op_desc.Input("emb_tbl").front())
+                       ->GetMutable<lite::Tensor>();
+  param_.att_fc_w = scope->FindVar(op_desc.Input("att_fc_w").front())
+                        ->GetMutable<lite::Tensor>();
+  param_.att_fc_b = scope->FindVar(op_desc.Input("att_fc_b").front())
+                        ->GetMutable<lite::Tensor>();
+
+  param_.att_pool_out = scope->FindVar(op_desc.Output("att_pool_out").front())
+                            ->GetMutable<lite::Tensor>();
+  param_.emb_fw_out = scope->FindVar(op_desc.Output("emb_fw_out").front())
+                          ->GetMutable<lite::Tensor>();
+
+  param_.att_fc_w_max = op_desc.GetAttr<float>("att_fc_w_max");
+  return true;
+}
+
+bool XPUMmdnnMatchConvTopkOp::CheckShape() const { return true; }
+
+bool XPUMmdnnMatchConvTopkOp::InferShapeImpl() const {
+  int channel_num = param_.channel_num;
+  std::vector<int> topks = param_.topks;
+  auto row_dim = param_.input_x->dims();
+  auto num_k = topks.size();
+  auto row_shape_0 = row_dim[0];
+  std::vector<int64_t> vec_out_shape;
+  vec_out_shape.push_back(row_shape_0);
+  vec_out_shape.push_back(channel_num * num_k);
+
+  param_.topk_out->Resize(lite::DDim(vec_out_shape));
+  param_.topk_out->set_lod(param_.input_x->lod());
+  return true;
+}
+
+bool XPUMmdnnMatchConvTopkOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                         lite::Scope* scope) {
+  param_.input_x = scope->FindVar(op_desc.Input("input_x").front())
+                       ->GetMutable<lite::Tensor>();
+  param_.input_y = scope->FindVar(op_desc.Input("input_y").front())
+                       ->GetMutable<lite::Tensor>();
+  param_.input_w = scope->FindVar(op_desc.Input("input_w").front())
+                       ->GetMutable<lite::Tensor>();
+  param_.conv_w = scope->FindVar(op_desc.Input("conv_w").front())
+                      ->GetMutable<lite::Tensor>();
+
+  param_.topk_out = scope->FindVar(op_desc.Output("topk_out").front())
+                        ->GetMutable<lite::Tensor>();
+
+  param_.input_w_max = op_desc.GetAttr<float>("input_w_max");
+  param_.conv_w_max = op_desc.GetAttr<float>("conv_w_max");
+  param_.topks = op_desc.GetAttr<std::vector<int>>("topks");
+  param_.channel_num = op_desc.GetAttr<int>("channel_num");
+  param_.dim_t = op_desc.GetAttr<int>("dim_t");
+  return true;
+}
+
+bool XPUMmdnnMergeAllOp::CheckShape() const { return true; }
+
+bool XPUMmdnnMergeAllOp::InferShapeImpl() const {
+  int64_t dim0 = param_.concat_7in1_x[0]->dims()[0];
+  int64_t dim1 = param_.fc2_w->dims()[0];
+  std::vector<int64_t> vec_out_shape;
+  vec_out_shape.push_back(dim0);
+  vec_out_shape.push_back(dim1);
+
+  param_.out->Resize(lite::DDim(vec_out_shape));
+  return true;
+}
+
+bool XPUMmdnnMergeAllOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                    lite::Scope* scope) {
+  param_.concat_7in1_x.clear();
+  for (auto& name : op_desc.Input("concat_7in1_x")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.concat_7in1_x.push_back(t);
+  }
+  param_.concat_2in1_x.clear();
+  for (auto& name : op_desc.Input("concat_2in1_x")) {
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
+    param_.concat_2in1_x.push_back(t);
+  }
+  param_.grnn_fw_wh = scope->FindVar(op_desc.Input("grnn_fw_wh").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.grnn_fw_wi = scope->FindVar(op_desc.Input("grnn_fw_wi").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.grnn_rv_wh = scope->FindVar(op_desc.Input("grnn_rv_wh").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.grnn_rv_wi = scope->FindVar(op_desc.Input("grnn_rv_wi").front())
+                          ->GetMutable<lite::Tensor>();
+  param_.fc0_w = scope->FindVar(op_desc.Input("fc0_w").front())
+                     ->GetMutable<lite::Tensor>();
+  param_.fc0_b = scope->FindVar(op_desc.Input("fc0_b").front())
+                     ->GetMutable<lite::Tensor>();
+  param_.fc1_w = scope->FindVar(op_desc.Input("fc1_w").front())
+                     ->GetMutable<lite::Tensor>();
+  param_.fc1_b = scope->FindVar(op_desc.Input("fc1_b").front())
+                     ->GetMutable<lite::Tensor>();
+  param_.fc2_w = scope->FindVar(op_desc.Input("fc2_w").front())
+                     ->GetMutable<lite::Tensor>();
+  param_.fc2_b = scope->FindVar(op_desc.Input("fc2_b").front())
+                     ->GetMutable<lite::Tensor>();
+
+  param_.out =
+      scope->FindVar(op_desc.Output("out").front())->GetMutable<lite::Tensor>();
+
+  param_.grnn_fw_wh_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_fw_wh_maxs");
+  param_.grnn_fw_wi_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_fw_wi_maxs");
+  param_.grnn_rv_wh_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_rv_wh_maxs");
+  param_.grnn_rv_wi_maxs =
+      op_desc.GetAttr<std::vector<float>>("grnn_rv_wi_maxs");
+  param_.fc0_w_max = op_desc.GetAttr<float>("fc0_w_max");
+  param_.fc1_w_max = op_desc.GetAttr<float>("fc1_w_max");
+  param_.fc2_w_max = op_desc.GetAttr<float>("fc2_w_max");
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_grnn_att,
+                 paddle::lite::operators::XPUMmdnnBidEmbGrnnAttOp);
+REGISTER_LITE_OP(__xpu__mmdnn_bid_emb_att,
+                 paddle::lite::operators::XPUMmdnnBidEmbAttOp);
+REGISTER_LITE_OP(__xpu__mmdnn_match_conv_topk,
+                 paddle::lite::operators::XPUMmdnnMatchConvTopkOp);
+REGISTER_LITE_OP(__xpu__mmdnn_merge_all,
+                 paddle::lite::operators::XPUMmdnnMergeAllOp);
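In the `InferShapeImpl` bodies above, `id_lod.size() - 1` is the sequence (batch) count: a LoD level with N+1 offsets delimits N sequences. A minimal illustration of that bookkeeping (values made up):

#include <cstdint>
#include <vector>

int main() {
  // Offsets {0, 3, 7, 9} delimit 3 sequences of lengths 3, 4 and 2.
  std::vector<uint64_t> lod_level = {0, 3, 7, 9};
  int64_t seq_num = static_cast<int64_t>(lod_level.size()) - 1;   // == 3
  int64_t total_tokens = static_cast<int64_t>(lod_level.back());  // == 9
  return (seq_num == 3 && total_tokens == 9) ? 0 : 1;
}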
diff --git a/lite/operators/__xpu__mmdnn_op.h b/lite/operators/__xpu__mmdnn_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..7038898cad0823746f905e4e60c06885b57a737c
--- /dev/null
+++ b/lite/operators/__xpu__mmdnn_op.h
@@ -0,0 +1,107 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class XPUMmdnnBidEmbGrnnAttOp : public OpLite {
+ public:
+  XPUMmdnnBidEmbGrnnAttOp() {}
+
+  explicit XPUMmdnnBidEmbGrnnAttOp(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "XPUMmdnnBidEmbGrnnAttOp"; }
+
+ private:
+  mutable XPUMmdnnBidEmbGrnnAttParam param_;
+};
+
+class XPUMmdnnBidEmbAttOp : public OpLite {
+ public:
+  XPUMmdnnBidEmbAttOp() {}
+
+  explicit XPUMmdnnBidEmbAttOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "XPUMmdnnBidEmbAttOp"; }
+
+ private:
+  mutable XPUMmdnnBidEmbAttParam param_;
+};
+
+class XPUMmdnnMatchConvTopkOp : public OpLite {
+ public:
+  XPUMmdnnMatchConvTopkOp() {}
+
+  explicit XPUMmdnnMatchConvTopkOp(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "XPUMmdnnMatchConvTopkOp"; }
+
+ private:
+  mutable XPUMmdnnMatchConvTopkParam param_;
+};
+
+class XPUMmdnnMergeAllOp : public OpLite {
+ public:
+  XPUMmdnnMergeAllOp() {}
+
+  explicit XPUMmdnnMergeAllOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "XPUMmdnnMergeAllOp"; }
+
+ private:
+  mutable XPUMmdnnMergeAllParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/__xpu__resnet_cbam_op.cc b/lite/operators/__xpu__resnet_cbam_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6013f4fa90033c51df7a0d3bb670e02f8bf4628d
--- /dev/null
+++ b/lite/operators/__xpu__resnet_cbam_op.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/__xpu__resnet_cbam_op.h"
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool XPUResNetCbamOp::CheckShape() const { return true; }
+
+bool XPUResNetCbamOp::InferShapeImpl() const {
+  auto input_shape = param_.input->dims();
+  std::vector<int64_t> output_shape_vec{1, 64};
+  paddle::lite::DDim output_shape(output_shape_vec);
+  output_shape[0] = input_shape[0];
+  param_.output->Resize(output_shape);
+  return true;
+}
+
+bool XPUResNetCbamOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                 lite::Scope* scope) {
+  param_.input = const_cast<lite::Tensor*>(
+      &scope->FindVar(op_desc.Input("Input").front())->Get<lite::Tensor>());
+  param_.output = scope->FindVar(op_desc.Output("Output").front())
+                      ->GetMutable<lite::Tensor>();
+
+  param_.filter.clear();
+  for (auto& name : op_desc.Input("Filter")) {
+    auto t =
+        const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+    param_.filter.push_back(t);
+  }
+  param_.bias.clear();
+  for (auto& name : op_desc.Input("Bias")) {
+    if (name.substr(0, 11) == "placeholder") {
+      param_.bias.push_back(nullptr);
+    } else {
+      auto t =
+          const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+      param_.bias.push_back(t);
+    }
+  }
+  param_.max_filter.clear();
+  for (auto& name : op_desc.Input("MaxFilter")) {
+    auto t =
+        const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+    param_.max_filter.push_back(t);
+  }
+
+  param_.pool_p = op_desc.GetAttr<float>("pool_p");
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(__xpu__resnet_cbam, paddle::lite::operators::XPUResNetCbamOp);
diff --git a/lite/operators/__xpu__resnet_cbam_op.h b/lite/operators/__xpu__resnet_cbam_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..26e5bafeae31183e9054e7e77ea46813c95db707
--- /dev/null
+++ b/lite/operators/__xpu__resnet_cbam_op.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class XPUResNetCbamOp : public OpLite {
+ public:
+  XPUResNetCbamOp() {}
+  explicit XPUResNetCbamOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "ResNetCbam"; }
+
+ private:
+  mutable XPUResNetCbamParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/__xpu__search_attention_op.cc b/lite/operators/__xpu__search_attention_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..acd8c817b0d81ef03df1c05417b8bb2f56c00812
--- /dev/null
+++ b/lite/operators/__xpu__search_attention_op.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/__xpu__search_attention_op.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool XPUMmdnnSearchAttentionOp::CheckShape() const { return true; }
+
+bool XPUMmdnnSearchAttentionOp::InferShapeImpl() const {
+  auto& x_dims = param_.X->dims();
+  param_.Out->Resize(x_dims);
+  param_.Out->set_lod(param_.X->lod());
+  return true;
+}
+
+bool XPUMmdnnSearchAttentionOp::AttachImpl(const cpp::OpDesc& op_desc,
+                                           lite::Scope* scope) {
+  auto x = op_desc.Input("X").front();
+  auto w = op_desc.Input("W").front();
+  auto b = op_desc.Input("b").front();
+  auto out = op_desc.Output("Out").front();
+
+  param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.W = scope->FindVar(w)->GetMutable<lite::Tensor>();
+  param_.b = scope->FindVar(b)->GetMutable<lite::Tensor>();
+  param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+
+  param_.W_max = op_desc.GetAttr<float>("W_max");
+  param_.pad_id = op_desc.GetAttr<int>("pad_id");
+  param_.alpha0 = op_desc.GetAttr<float>("alpha0");
+  param_.alpha1 = op_desc.GetAttr<float>("alpha1");
+  param_.mask = op_desc.GetAttr<float>("mask");
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(__xpu__mmdnn_search_attention,
+                 paddle::lite::operators::XPUMmdnnSearchAttentionOp);
diff --git a/lite/operators/__xpu__search_attention_op.h b/lite/operators/__xpu__search_attention_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..81bd366ee8a51dc8d2d7fb4c9cb03d2199bcb4f2
--- /dev/null
+++ b/lite/operators/__xpu__search_attention_op.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class XPUMmdnnSearchAttentionOp : public OpLite {
+ public:
+  XPUMmdnnSearchAttentionOp() {}
+
+  explicit XPUMmdnnSearchAttentionOp(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "XPUMmdnnSearchAttentionOp";
+  }
+
+ private:
+  mutable XPUMmdnnSearchAttentionParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/activation_grad_ops.cc b/lite/operators/activation_grad_ops.cc
index b31163e5dce6d9b77d923ba44ed58952263610a5..a30231be921e2c4445bb4c7a72c9572b14c1c0f5 100644
--- a/lite/operators/activation_grad_ops.cc
+++ b/lite/operators/activation_grad_ops.cc
@@ -41,15 +41,11 @@ bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc,
   if (opdesc.HasInput("X")) {
     auto X_name = opdesc.Input("X").front();
     param_.X = GetVar<lite::Tensor>(scope, X_name);
-  } else {
-    param_.X = param_.X_grad;
   }
 
   if (opdesc.HasInput("Out")) {
     auto Out_name = opdesc.Input("Out").front();
     param_.Out = GetVar<lite::Tensor>(scope, Out_name);
-  } else {
-    param_.Out = param_.Out_grad;
   }
 
   return true;
@@ -60,3 +56,5 @@ bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc,
 }  // namespace paddle
 
 REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
+REGISTER_LITE_OP(relu_grad, paddle::lite::operators::ActivationGradOp);
+REGISTER_LITE_OP(tanh_grad, paddle::lite::operators::ActivationGradOp);
diff --git a/lite/operators/deformable_conv_op.h b/lite/operators/deformable_conv_op.h
index aa736fcef6b6f74740253b8607e8bfcd938d0ff8..69b764758699089bdee0a64e33a01d838b011ec0 100644
--- a/lite/operators/deformable_conv_op.h
+++ b/lite/operators/deformable_conv_op.h
@@ -83,7 +83,7 @@ class DeformableConvOpLite : public OpLite {
     param_.conv_param.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
     param_.conv_param.strides = op_desc.GetAttr<std::vector<int>>("strides");
-    auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
+    std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
     auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
     param_.conv_param.groups = op_desc.GetAttr<int>("groups");
     param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations);
diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc
index e1da396697b96001db4c45d81c9f7bb6f4b538b5..5895bb667aa22507d362004627304ecf78e085f1 100644
--- a/lite/operators/elementwise_ops.cc
+++ b/lite/operators/elementwise_ops.cc
@@ -145,6 +145,7 @@ REGISTER_LITE_OP(elementwise_mul, paddle::lite::operators::ElementwiseOp);
 REGISTER_LITE_OP(elementwise_max, paddle::lite::operators::ElementwiseOp);
 REGISTER_LITE_OP(elementwise_div, paddle::lite::operators::ElementwiseOp);
 REGISTER_LITE_OP(elementwise_mod, paddle::lite::operators::ElementwiseOp);
+REGISTER_LITE_OP(elementwise_pow, paddle::lite::operators::ElementwiseOp);
 
 // #ifdef LITE_WITH_TRAIN
 // REGISTER_LITE_OP(elementwise_sub_grad,
diff --git a/lite/operators/match_matrix_tensor_op.cc b/lite/operators/match_matrix_tensor_op.cc
index 1cc751109f76a96097d363b493322dde182a715d..fd70143131b458c1d985a21a6d9d84c707ba9986 100644
--- a/lite/operators/match_matrix_tensor_op.cc
+++ b/lite/operators/match_matrix_tensor_op.cc
@@ -94,6 +94,18 @@ bool MatchMatrixTensorOpLite::AttachImpl(const cpp::OpDesc& op_desc,
   param_.dim_t = op_desc.GetAttr<int>("dim_t");
 
+  if (op_desc.HasAttr("fuse_relu")) {
+    param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
+  }
+#ifdef LITE_WITH_XPU
+  if (op_desc.HasAttr("__xpu__float_to_fix")) {
+    param_.__xpu__float_to_fix = op_desc.GetAttr<bool>("__xpu__float_to_fix");
+  }
+  if (op_desc.HasAttr("__xpu__w_max")) {
+    param_.__xpu__w_max = op_desc.GetAttr<float>("__xpu__w_max");
+  }
+#endif
+
   return true;
 }
diff --git a/lite/operators/max_pool_with_index_op.h b/lite/operators/max_pool_with_index_op.h
index bd82743c279c4728483c72f017a8fa6e94cf3eb4..dfc220907549dc9ce61726b79cb1626c2734b234 100644
--- a/lite/operators/max_pool_with_index_op.h
+++ b/lite/operators/max_pool_with_index_op.h
@@ -54,7 +54,7 @@ class MaxPoolWithIndexOpLite : public OpLite {
     param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
     param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
     param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
-    auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
+    std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
     if (op_desc.HasAttr("adaptive")) {
       param_.adaptive = op_desc.GetAttr<bool>("adaptive");
     }
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index f6fffeff65335f39f48a69459fb6d16f4a697306..f351e8e5344424d80fa79f8d7c83be3bf367441f 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1129,6 +1129,11 @@ struct VarConv2DParam : ParamBase {
   int kernel_w;
 
   bool fuse_relu{false};
+
+#ifdef LITE_WITH_XPU
+  bool __xpu__float_to_fix{false};  // Is W already converted to int16/int8
+  float __xpu__w_max{0.0f};         // Abs max in W
+#endif
 };
 
 /// ----------------------- shape operators ----------------------
@@ -1378,6 +1383,13 @@ struct SearchFcParam : ParamBase {
   const lite::Tensor* b{};
   lite::Tensor* Out{};
   int out_size{};
+
+  bool fuse_relu{false};
+
+#ifdef LITE_WITH_XPU
+  bool __xpu__float_to_fix{false};  // Is W already converted to int16/int8
+  float __xpu__w_max{0.0f};         // Abs max in W
+#endif
 };
 /// --------------------- match_matrix_tensor operators --------------------
 struct MatchMatrixTensorParam : ParamBase {
@@ -1388,6 +1400,12 @@ struct MatchMatrixTensorParam : ParamBase {
   lite::Tensor* tmp{};
 
   int dim_t;
+  bool fuse_relu{false};
+
+#ifdef LITE_WITH_XPU
+  bool __xpu__float_to_fix{false};  // Is w already converted to int16/int8
+  float __xpu__w_max{0.0f};         // Abs max in w
+#endif
 };
 
 /// --------------------- search_seq_depadding operators --------------------
@@ -1409,6 +1427,12 @@ struct SearchGrnnParam : ParamBase {
   lite::Tensor* tmp_buffer{};
   lite::Tensor* idx_sorted_by_width{};
   lite::Tensor* layout_input{};
+
+#ifdef LITE_WITH_XPU
+  bool __xpu__float_to_fix{false};   // Is wi/wh already converted to int16/int8
+  std::vector<float> __xpu__wi_max;  // Abs max in wi
+  std::vector<float> __xpu__wh_max;  // Abs max in wh
+#endif
 };
 
 struct SplitLodTensorParam : ParamBase {
@@ -1563,6 +1587,106 @@ struct XPUFcParam : ParamBase {
   std::string activation_type{""};
 };
 
+struct XPUResNetCbamParam : ParamBase {
+  lite::Tensor* input{};
+  std::vector<lite::Tensor*> filter;
+  std::vector<lite::Tensor*> bias;
+  std::vector<lite::Tensor*> max_filter;
+  lite::Tensor* output{};
+
+  float pool_p{1.0f};
+};
+
+struct XPUMmdnnSearchAttentionParam : ParamBase {
+  lite::Tensor* X{};
+  lite::Tensor* W{};
+  lite::Tensor* b{};
+  lite::Tensor* Out{};
+
+  float W_max{0.0f};
+  int pad_id{0};
+  float alpha0{1.0f};
+  float alpha1{1.0f};
+  float mask{1.0f};
+};
+
+struct XPUMmdnnBidEmbGrnnAttParam : ParamBase {
+  lite::Tensor* id0{};
+  lite::Tensor* id1{};
+  lite::Tensor* emb_tbl{};
+  lite::Tensor* grnn_fw_wh{};
+  lite::Tensor* grnn_fw_wi{};
+  lite::Tensor* grnn_rv_wh{};
+  lite::Tensor* grnn_rv_wi{};
+  lite::Tensor* att_fc_w{};
+  lite::Tensor* att_fc_b{};
+
+  std::vector<float> grnn_fw_wh_maxs;
+  std::vector<float> grnn_fw_wi_maxs;
+  std::vector<float> grnn_rv_wh_maxs;
+  std::vector<float> grnn_rv_wi_maxs;
+  float att_fc_w_max{0.0f};
+
+  lite::Tensor* grnn_fw_pool_out{};  // 1
+  lite::Tensor* grnn_rv_pool_out{};  // 2
+  lite::Tensor* att_pool_out{};      // 3
+  lite::Tensor* concat_3in1_out{};   // 4
+  lite::Tensor* emb_fw_out{};        // 5
+};
+
+struct XPUMmdnnBidEmbAttParam : ParamBase {
+  lite::Tensor* id0{};
+  lite::Tensor* id1{};
+  lite::Tensor* emb_tbl{};
+  lite::Tensor* att_fc_w{};
+  lite::Tensor* att_fc_b{};
+
+  float att_fc_w_max{0.0f};
+
+  lite::Tensor* att_pool_out{};  // 1
+  lite::Tensor* emb_fw_out{};    // 2
+};
+
+struct XPUMmdnnMatchConvTopkParam : ParamBase {
+  lite::Tensor* input_x{};
+  lite::Tensor* input_y{};
+  lite::Tensor* input_w{};
+  lite::Tensor* conv_w{};
+
+  float input_w_max{0.0f};
+  float conv_w_max{0.0f};
+  std::vector<int> topks;
+  int channel_num{0};
+  int dim_t{0};
+
+  lite::Tensor* topk_out{};
+};
+
+struct XPUMmdnnMergeAllParam : ParamBase {
+  std::vector<lite::Tensor*> concat_7in1_x;
+  std::vector<lite::Tensor*> concat_2in1_x;
+  lite::Tensor* grnn_fw_wh{};
+  lite::Tensor* grnn_fw_wi{};
+  lite::Tensor* grnn_rv_wh{};
+  lite::Tensor* grnn_rv_wi{};
+  lite::Tensor* fc0_w{};
+  lite::Tensor* fc0_b{};
+  lite::Tensor* fc1_w{};
+  lite::Tensor* fc1_b{};
+  lite::Tensor* fc2_w{};
+  lite::Tensor* fc2_b{};
+
+  std::vector<float> grnn_fw_wh_maxs;
+  std::vector<float> grnn_fw_wi_maxs;
+  std::vector<float> grnn_rv_wh_maxs;
+  std::vector<float> grnn_rv_wi_maxs;
+  float fc0_w_max{0.0f};
+  float fc1_w_max{0.0f};
+  float fc2_w_max{0.0f};
+
+  lite::Tensor* out{};
+};
+
 // For DeformableConvolution op
 struct DeformableConvParam : ParamBase {
   lite::Tensor* x{};
diff --git a/lite/operators/search_fc_op.cc b/lite/operators/search_fc_op.cc
index 71e62c2ae729b4e1516a219888b9af3f7d994428..8024c38f9cc4a6d3ba2d47d6c61e716dd57bb362 100644
--- a/lite/operators/search_fc_op.cc
+++ b/lite/operators/search_fc_op.cc
@@ -70,6 +70,18 @@ bool SearchFcOpLite::AttachImpl(const cpp::OpDesc &op_desc,
   param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
   param_.out_size = op_desc.GetAttr<int>("out_size");
 
+  if (op_desc.HasAttr("fuse_relu")) {
+    param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
+  }
+#ifdef LITE_WITH_XPU
+  if (op_desc.HasAttr("__xpu__float_to_fix")) {
+    param_.__xpu__float_to_fix = op_desc.GetAttr<bool>("__xpu__float_to_fix");
+  }
+  if (op_desc.HasAttr("__xpu__w_max")) {
+    param_.__xpu__w_max = op_desc.GetAttr<float>("__xpu__w_max");
+  }
+#endif
+
   return true;
 }
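The `__xpu__float_to_fix` / `__xpu__w_max` attribute pairs added above follow one convention: an offline pass has already rewritten the weight tensor to int16/int8 and recorded the original abs-max so the XPU kernel can rescale at run time. A consumer-side sketch of that contract (the helper is hypothetical, int16 assumed):

#include <cstdint>
#include <vector>

// Hypothetical helper: recover float weights from a fixed-point buffer using
// the recorded abs-max, mirroring how a kernel would interpret __xpu__w_max
// when __xpu__float_to_fix is set.
std::vector<float> DequantizeInt16(const std::vector<int16_t>& w_fix,
                                   float w_max) {
  std::vector<float> w(w_fix.size());
  const float scale = w_max / 32767.0f;  // int16 full-scale value
  for (size_t i = 0; i < w_fix.size(); ++i) {
    w[i] = w_fix[i] * scale;
  }
  return w;
}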
diff --git a/lite/operators/search_grnn_op.cc b/lite/operators/search_grnn_op.cc
index 1ced477c109d8cd93485f0193523887759939f17..6f743693bc782e636064ca398539433b497dc645 100644
--- a/lite/operators/search_grnn_op.cc
+++ b/lite/operators/search_grnn_op.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "lite/operators/search_grnn_op.h"
+#include <vector>
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
 
@@ -84,6 +85,18 @@ bool SearchGrnnOpLite::AttachImpl(const cpp::OpDesc& op_desc,
   param_.layout_input =
       scope->FindVar(layout_input)->GetMutable<lite::Tensor>();
 
+#ifdef LITE_WITH_XPU
+  if (op_desc.HasAttr("__xpu__float_to_fix")) {
+    param_.__xpu__float_to_fix = op_desc.GetAttr<bool>("__xpu__float_to_fix");
+  }
+  if (op_desc.HasAttr("__xpu__wi_max")) {
+    param_.__xpu__wi_max = op_desc.GetAttr<std::vector<float>>("__xpu__wi_max");
+  }
+  if (op_desc.HasAttr("__xpu__wh_max")) {
+    param_.__xpu__wh_max = op_desc.GetAttr<std::vector<float>>("__xpu__wh_max");
+  }
+#endif
+
   return true;
 }
diff --git a/lite/operators/sequence_reverse_op.cc b/lite/operators/sequence_reverse_op.cc
index 19a47cac9da666269fc5ef2a172ff0295b71e95d..fa2b0553aa2ac84f27d5d27d31df5ce9584d82c3 100644
--- a/lite/operators/sequence_reverse_op.cc
+++ b/lite/operators/sequence_reverse_op.cc
@@ -34,6 +34,7 @@ bool SequenceReverseOp::InferShapeImpl() const {
   const auto *input = param_.X;
   auto out_dims = input->dims();
   param_.Out->Resize(out_dims);
+  param_.Out->set_lod(param_.X->lod());
   return true;
 }
 
@@ -45,6 +46,7 @@ bool SequenceReverseOp::AttachImpl(const cpp::OpDesc &opdesc,
       scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
   CHECK(param_.X);
   CHECK(param_.Out);
+
   return true;
 }
diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc
index 8cf11f6465d73646ec9bf846cbe6347bdc4b9f5b..83b6cc6a24ed1537adec8fd7d54a477edf91f873 100644
--- a/lite/operators/var_conv_2d_op.cc
+++ b/lite/operators/var_conv_2d_op.cc
@@ -52,6 +52,15 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   if (opdesc.HasAttr("fuse_relu")) {
     param_.fuse_relu = opdesc.GetAttr<bool>("fuse_relu");
   }
+#ifdef LITE_WITH_XPU
+  if (opdesc.HasAttr("__xpu__float_to_fix")) {
+    param_.__xpu__float_to_fix = opdesc.GetAttr<bool>("__xpu__float_to_fix");
+  }
+  if (opdesc.HasAttr("__xpu__w_max")) {
+    param_.__xpu__w_max = opdesc.GetAttr<float>("__xpu__w_max");
+  }
+#endif
+
   return true;
 }
diff --git a/lite/tests/api/CMakeLists.txt b/lite/tests/api/CMakeLists.txt
index 7e5ddecb082e17a4a70a41fef0f359c354f2e97e..844c3f2ac7146e05b2d93eac76279df022e06652 100644
--- a/lite/tests/api/CMakeLists.txt
+++ b/lite/tests/api/CMakeLists.txt
@@ -16,6 +16,15 @@ if(LITE_WITH_XPU)
     add_dependencies(test_ernie_lite_xpu extern_lite_download_ernie_tar_gz)
     add_dependencies(test_bert_lite_xpu extern_lite_download_bert_tar_gz)
   endif()
+  # TODO(miaotianxiang): enable later
+  #lite_cc_test(test_fpr_lite_xpu SRCS test_fpr_lite_xpu.cc
+    #DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
+    #${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
+    #ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
+  #lite_cc_test(test_mmdnn_lite_xpu SRCS test_mmdnn_lite_xpu.cc
+    #DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
+    #${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
+    #ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
 endif()
 
 if(LITE_WITH_RKNPU)
diff --git a/lite/tests/api/test_fpr_lite_xpu.cc b/lite/tests/api/test_fpr_lite_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..026c25690fe2a673be0a5a97b163d7bbe5fdb4f6
--- /dev/null
+++ b/lite/tests/api/test_fpr_lite_xpu.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(ResnetCbam, test_resnet_cbam_lite_xpu) {
+  lite_api::CxxConfig config;
+  // config.set_model_dir(FLAGS_model_dir);
+  config.set_model_file(FLAGS_model_dir + "/__model__");
+  config.set_param_file(FLAGS_model_dir + "/__params__");
+  config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
+                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
+  config.set_xpu_workspace_l3_size_per_thread();
+  auto predictor = lite_api::CreatePaddlePredictor(config);
+
+  auto input_tensor = predictor->GetInput(0);
+  std::vector<int64_t> input_shape{1, 3, 224, 224};
+  input_tensor->Resize(input_shape);
+  auto* data = input_tensor->mutable_data<float>();
+  int input_num = 1;
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  for (int i = 0; i < input_num; i++) {
+    data[i] = 1;
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor->Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor->Run();
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/api/test_mmdnn_lite_xpu.cc b/lite/tests/api/test_mmdnn_lite_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a2a98821e70cb462b23887f851cfc4bce6b463ca
--- /dev/null
+++ b/lite/tests/api/test_mmdnn_lite_xpu.cc
@@ -0,0 +1,311 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <cstring>
+#include <fstream>
+#include <numeric>
+#include <vector>
+#include "lite/api/lite_api_test_helper.h"
+#include "lite/api/paddle_api.h"
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/api/paddle_use_passes.h"
+#include "lite/api/test_helper.h"
+#include "lite/utils/cp_logging.h"
+#include "lite/utils/string.h"
+
+DEFINE_bool(perf, false, "perf?");
+DEFINE_string(perf_input, "perf_input", "perf_input");
+
+namespace paddle {
+namespace lite {
+
+std::vector<int64_t> input0;
+std::vector<uint64_t> input0_lod = {0};
+std::vector<int64_t> input1;
+std::vector<uint64_t> input1_lod = {0};
+std::vector<int64_t> input2;
+std::vector<uint64_t> input2_lod = {0};
+std::vector<int64_t> input3;
+std::vector<uint64_t> input3_lod = {0};
+std::vector<int64_t> input4;
+std::vector<uint64_t> input4_lod = {0};
+std::vector<int64_t> input5;
+std::vector<uint64_t> input5_lod = {0};
+
+void ParseInput() {
+  std::string raw_input =
+      "0 1;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
+      "760166;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 145 "
+      "10251 839 5 1779 1729 1779 1729 18 2707 6 2707 20 4742 4937 432 6 "
+      "3869;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 767614 "
+      "767614 1020808 769579 793958 793958 1050488 911898 751332 751332 750336 "
+      "750799 750336 751575 751575 751544 751735 751397 751365 751512 751512 "
+      "753011 751562;3719 428 52 18 1102 10327 252 20 153 2897 1146 70 156 6 "
+      "145 10251 839 2 1211 3 3719 720 1540 145 10251 839 9405 4315 5998 4 2 "
+      "600 373 41 3719 428 52 44 10251 4302 1319 7 12 2 768 6 918 6 841 870 8 "
+      "843 8 271;3719 760166 760166 18 1035176 1035176 764393 764393 1259006 "
+      "767614 767614 1020808 769579 793958 793958 1050488 911898 2 773899 "
+      "773899 3719 1118420 1118420 1050488 1050488 911898 9405 4315 5998 4 2 "
+      "785435 785435 41 3719 760166 760166 44 10251 4302 1319 750118 750118 2 "
+      "750465 750465 750274 750398 750233 751252 751252 753447 752830 753112;\n"
+      "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
+      "760166;2109 2467 1805 227 3719 428 52 18 1102 10327 252 20 6 242 78 6 "
+      "532 78;2109 2467 1805 1245431 1245431 760166 760166 18 1035176 1035176 "
+      "764393 764393 752116 242 750370 750370 752081 751247;2109 2467 1805 227 "
+      "3719 428 52 18 1102 10327 252 20 2 145 242 1050 252 3582 2212;2109 2467 "
+      "1805 1245431 1245431 760166 760166 18 1035176 1035176 764393 764393 2 "
+      "871717 871717 757921 757921 3582 2212;\n"
+      "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
+      "760166;145 10251 839 76 31 1337 823 7506 567 65 170 8 21293 3719 5 43 "
+      "394 743 42;1050488 1050488 911898 750016 750016 1337 823 7506 762617 "
+      "762617 866652 8 21293 3719 5 43 914758 914758 757202;145 10251 839 76 "
+      "31 1337 823 7506 567 65 170 8 21293 3719 2 17580 30 523324 3 10251 4104 "
+      "281 3 8511 3719 2217 3 13 226 3083 4 11251 1606 357 9 2 145 10251 839 "
+      "76 31 1337 823 7506 567 65 170 2 7506 2445 8 145 10251 839 528 839 "
+      "19670 6538;1050488 1050488 911898 750016 750016 1337 823 7506 762617 "
+      "762617 866652 2 816626 816626 523324 3 1181698 1181698 "
+      "751656 780821 1063148 3719 2217 3 752498 752498 831323 753602 11251 "
+      "1606 357 9 2 1050488 1050488 911898 750016 750016 1337 823 7506 762617 "
+      "762617 866652 2 7506 753045 753045 756756 1050488 911898 528 839 19670 "
+      "6538;\n"
+      "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
+      "760166;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 2899 "
+      "229 10 10 10;1050488 1050488 911898 807966 750273 1035176 1035176 "
+      "1237875 41 3719 760166 760166 753645 753645 750273 2899 229 750001 "
+      "750001 750001;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 "
+      "2899 229 10 10 10 2 1177 8 145 10251 839 99 4 1102 10327 2196 41 3719 "
+      "428 52 44 99 4 2 101 8 1922 17 2184 2 1154 1922 72 1198 1266 "
+      "4516;1050488 1050488 911898 807966 750273 1035176 1035176 1237875 41 "
+      "3719 760166 760166 753645 753645 750273 2899 229 750001 750001 750001 2 "
+      "750257 750257 756756 1050488 911898 807966 750273 1035176 1035176 "
+      "1237875 41 3719 760166 760166 753645 753645 750273 2 764513 764513 "
+      "851213 851213 854628 2 753018 753018 754317 753328 754085 754070;\n"
+      "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 "
+      "760166;73 5347 112 8 145 10251 839 262 169 22729 3719 6 743 6 339 1156 "
+      "78 136 399 693 128 571;776150 776150 112 756756 756756 1050488 911898 "
+      "791355 791355 22729 3719 6 758277 758277 750137 750234 750241 750178 "
+      "750055 750216 750212 750049;73 5347 112 8 145 10251 839 262 169 22729 "
+      "3719 2 588 415 549 415 115 23;776150 776150 112 756756 756756 1050488 "
+      "911898 791355 791355 22729 3719 2 750221 750221 750262 750277 750277 "
+      "750261;";
+  auto raw_lines = Split(raw_input, "\n");
+  for (auto& raw_line : raw_lines) {
+    auto inputx = Split(raw_line, ";");
+    for (size_t i = 1; i < inputx.size(); ++i) {
+      auto tokens = Split(inputx[i], " ");
+      static std::vector<int64_t>* const input_array[] = {
+          &input0, &input0, &input1, &input2, &input3, &input4, &input5};
+      static std::vector<uint64_t>* const lod_array[] = {&input0_lod,
+                                                         &input0_lod,
+                                                         &input1_lod,
+                                                         &input2_lod,
+                                                         &input3_lod,
+                                                         &input4_lod,
+                                                         &input5_lod};
+      for (auto token : tokens) {
+        input_array[i]->push_back((int64_t)atoi(token.c_str()));
+      }
+      lod_array[i]->push_back((uint64_t)tokens.size() +
+                              (*lod_array[i])[lod_array[i]->size() - 1]);
+    }
+  }
+  return;
+}
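+
+// Each input line consists of semicolon-separated fields: field 0 holds the
+// labels and fields 1-6 hold space-separated token ids that map to inputs
+// 0-5; every line contributes one LoD segment per input. MmdnnReader below
+// reads the same format from a file, up to maxline lines per call.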
750001 " + "750001 750001;145 10251 839 99 4 1102 10327 2196 41 3719 428 52 44 99 4 " + "2899 229 10 10 10 2 1177 8 145 10251 839 99 4 1102 10327 2196 41 3719 " + "428 52 44 99 4 2 101 8 1922 17 2184 2 1154 1922 72 1198 1266 " + "4516;1050488 1050488 911898 807966 750273 1035176 1035176 1237875 41 " + "3719 760166 760166 753645 753645 750273 2899 229 750001 750001 750001 2 " + "750257 750257 756756 1050488 911898 807966 750273 1035176 1035176 " + "1237875 41 3719 760166 760166 753645 753645 750273 2 764513 764513 " + "851213 851213 854628 2 753018 753018 754317 753328 754085 754070;\n" + "0 0;145 10251 839 3719 428 52;1050488 1050488 911898 3719 760166 " + "760166;73 5347 112 8 145 10251 839 262 169 22729 3719 6 743 6 339 1156 " + "78 136 399 693 128 571;776150 776150 112 756756 756756 1050488 911898 " + "791355 791355 22729 3719 6 758277 758277 750137 750234 750241 750178 " + "750055 750216 750212 750049;73 5347 112 8 145 10251 839 262 169 22729 " + "3719 2 588 415 549 415 115 23;776150 776150 112 756756 756756 1050488 " + "911898 791355 791355 22729 3719 2 750221 750221 750262 750277 750277 " + "750261;"; + auto raw_lines = Split(raw_input, "\n"); + for (auto& raw_line : raw_lines) { + auto inputx = Split(raw_line, ";"); + for (size_t i = 1; i < inputx.size(); ++i) { + auto tokens = Split(inputx[i], " "); + static std::vector* const input_array[] = { + &input0, &input0, &input1, &input2, &input3, &input4, &input5}; + static std::vector* const lod_array[] = {&input0_lod, + &input0_lod, + &input1_lod, + &input2_lod, + &input3_lod, + &input4_lod, + &input5_lod}; + for (auto token : tokens) { + input_array[i]->push_back((int64_t)atoi(token.c_str())); + } + lod_array[i]->push_back((uint64_t)tokens.size() + + (*lod_array[i])[lod_array[i]->size() - 1]); + } + } + return; +} + +class MmdnnReader { + std::ifstream ifs; + std::vector StringSplit(const std::string& in, + const std::string& delim) { + std::vector ret; + if (in == "") { + return ret; + } + auto begpos = in.find_first_not_of(delim); + while (begpos != std::string::npos) { + auto endpos = in.find_first_of(delim, begpos); + if (endpos == std::string::npos) { + endpos = in.size(); + } + std::string ssubstr = in.substr(begpos, endpos - begpos); + ret.push_back(ssubstr); + begpos = endpos + 1; + if (endpos >= (in.size() - 1)) { + break; + } + } + return ret; + } + + public: + std::vector data[6]; + std::vector lod[6]; + + void Init(std::string file_name) { ifs.open(file_name); } + + int Read(int maxline) { + for (int i = 0; i < 6; i++) { + data[i].clear(); + } + for (int i = 0; i < 6; i++) { + lod[i].clear(); + lod[i].push_back(0); + } + std::string line; + int cnt = 0; + while (cnt < maxline && getline(ifs, line)) { + std::vector split1 = StringSplit(line, ";"); + for (int i = 1; i < 7; i++) { + std::vector split2 = StringSplit(split1[i], " "); + if (split2.size() == 0) { + split2.push_back("1280000"); + } + for (size_t j = 0; j < split2.size(); j++) { + data[i - 1].push_back(std::stoi(split2[j].c_str(), nullptr, 0)); + } + // if (i % 2 == 1) { + // lod[i / 2].push_back(lod[i / 2].back() + split2.size()); + //} + lod[i - 1].push_back(lod[i - 1].back() + split2.size()); + } + cnt++; + } + return cnt; + } +}; + +TEST(MMDNN, test_mmdnn_lite_xpu) { + lite_api::CxxConfig config; + config.set_model_dir(FLAGS_model_dir); + config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)}, + lite_api::Place{TARGET(kXPU), PRECISION(kInt64)}, + lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, + lite_api::Place{TARGET(kX86), 
+  if (FLAGS_perf) {
+    MmdnnReader reader;
+    reader.Init(FLAGS_perf_input);
+    int UB_batch = 40;  // upper bound of batch
+    int iter = 0;
+    double tsc_sum = 0;
+
+    while (true) {
+      int batch = reader.Read(UB_batch);
+      if (batch <= 0) {
+        break;
+      }
+      ++iter;
+      for (int i = 0; i < 6; ++i) {
+        auto input_x = predictor->GetInput(i);
+        input_x->Resize({(int64_t)reader.data[i].size(), 1});
+        input_x->SetLoD({reader.lod[i]});
+        auto* data_x = input_x->mutable_data<int64_t>();
+        memcpy(data_x,
+               reader.data[i].data(),
+               reader.data[i].size() * sizeof(int64_t));
+      }
+
+      auto start = GetCurrentUS();
+      predictor->Run();
+      auto end = GetCurrentUS();
+      tsc_sum += end - start;
+    }
+    LOG(INFO) << "================== Speed Report ===================";
+    LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num "
+              << FLAGS_threads << ", warmup: " << FLAGS_warmup
+              << ", repeats: " << iter << ", spend " << tsc_sum / iter / 1000.0
+              << " ms in average.";
+
+    return;
+  }
+
+  ParseInput();
+
+  {
+    std::vector<int64_t> input0_shape{(int64_t)input0.size(), 1};
+    auto input_tensor0 = predictor->GetInput(0);
+    input_tensor0->Resize(input0_shape);
+    input_tensor0->SetLoD({input0_lod});
+    auto* data0 = input_tensor0->mutable_data<int64_t>();
+    memcpy(data0, input0.data(), sizeof(int64_t) * input0.size());
+  }
+  {
+    std::vector<int64_t> input1_shape{(int64_t)input1.size(), 1};
+    auto input_tensor1 = predictor->GetInput(1);
+    input_tensor1->Resize(input1_shape);
+    input_tensor1->SetLoD({input1_lod});
+    auto* data1 = input_tensor1->mutable_data<int64_t>();
+    memcpy(data1, input1.data(), sizeof(int64_t) * input1.size());
+  }
+  {
+    std::vector<int64_t> input2_shape{(int64_t)input2.size(), 1};
+    auto input_tensor2 = predictor->GetInput(2);
+    input_tensor2->Resize(input2_shape);
+    input_tensor2->SetLoD({input2_lod});
+    auto* data2 = input_tensor2->mutable_data<int64_t>();
+    memcpy(data2, input2.data(), sizeof(int64_t) * input2.size());
+  }
+  {
+    std::vector<int64_t> input3_shape{(int64_t)input3.size(), 1};
+    auto input_tensor3 = predictor->GetInput(3);
+    input_tensor3->Resize(input3_shape);
+    input_tensor3->SetLoD({input3_lod});
+    auto* data3 = input_tensor3->mutable_data<int64_t>();
+    memcpy(data3, input3.data(), sizeof(int64_t) * input3.size());
+  }
+  {
+    std::vector<int64_t> input4_shape{(int64_t)input4.size(), 1};
+    auto input_tensor4 = predictor->GetInput(4);
+    input_tensor4->Resize(input4_shape);
+    input_tensor4->SetLoD({input4_lod});
+    auto* data4 = input_tensor4->mutable_data<int64_t>();
+    memcpy(data4, input4.data(), sizeof(int64_t) * input4.size());
+  }
+  {
+    std::vector<int64_t> input5_shape{(int64_t)input5.size(), 1};
+    auto input_tensor5 = predictor->GetInput(5);
+    input_tensor5->Resize(input5_shape);
+    input_tensor5->SetLoD({input5_lod});
+    auto* data5 = input_tensor5->mutable_data<int64_t>();
+    memcpy(data5, input5.data(), sizeof(int64_t) * input5.size());
+  }
+
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    predictor->Run();
+  }
+
+  auto start = GetCurrentUS();
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    predictor->Run();
+  }
+
+  auto out = predictor->GetOutput(0);
+  auto out_shape = out->shape();
+  auto out_size = std::accumulate(
+      out_shape.begin(), out_shape.end(), 1, std::multiplies<int64_t>());
+  for (int i = 0; i < out_size; ++i) {
+    LOG(INFO) << "out[" << i << "] = " << out->data<float>()[i];
+  }
+
+  LOG(INFO) << "================== Speed Report ===================";
+  LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
+            << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
+            << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
+            << " ms in average.";
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/kernels/activation_grad_compute_test.cc b/lite/tests/kernels/activation_grad_compute_test.cc
index 5d5046b01dee6c84f341159b68300197c20695e6..2ad5b80a910f323b34b039eabda0ceb4b49784c5 100644
--- a/lite/tests/kernels/activation_grad_compute_test.cc
+++ b/lite/tests/kernels/activation_grad_compute_test.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/kernels/arm/activation_grad_compute.h"
+#include "lite/kernels/host/activation_grad_compute.h"
 #include <gtest/gtest.h>
 #include "lite/core/op_registry.h"
 #include "lite/kernels/arm/activation_compute.h"
@@ -20,13 +20,11 @@
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
 
 using param_t = operators::ActivationParam;
 using grad_param_t = operators::ActivationGradParam;
-using kernel_t = SquareCompute;
-using grad_kernel_t = SquareGradCompute;
 
+template <class kernel_t, class grad_kernel_t>
 class ActivationGradTester {
  public:
   explicit ActivationGradTester(DDim dims) : dims_(dims) {}
@@ -71,22 +69,28 @@ class ActivationGradTester {
   void run_backward(grad_param_t* param,
                     grad_kernel_t* kernel,
                     const std::vector<float>& in_vec,
+                    const std::vector<float>& out_vec,
                     const std::vector<float>& out_grad_vec,
                     float* in_grad_vec) {
     Tensor x;
+    Tensor out;
     Tensor x_grad;
     Tensor out_grad;
     x.Resize(dims_);
+    out.Resize(dims_);
     x_grad.Resize(dims_);
     out_grad.Resize(dims_);
     auto* x_data = x.mutable_data<float>();
+    auto* out_data = out.mutable_data<float>();
    auto* out_grad_data = out_grad.mutable_data<float>();
     for (int i = 0; i < dims_.production(); i++) {
       x_data[i] = in_vec[i];
+      out_data[i] = out_vec[i];
       out_grad_data[i] = out_grad_vec[i];
     }
     param->X = &x;
+    param->Out = &out;
     param->X_grad = &x_grad;
     param->Out_grad = &out_grad;
     kernel->SetParam(*param);
@@ -102,7 +106,9 @@ class ActivationGradTester {
     std::vector<float> x(dims_.production());
     std::vector<float> out(dims_.production());
     for (int i = 0; i < dims_.production(); i++) {
-      x[i] = 1.0 * static_cast<float>(i % 128) * 0.3f - 1.1;
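+      // Spread inputs over roughly [-1, 1] by mixing three different
+      // periods; the small constant offset helps keep x away from exactly
+      // zero, where e.g. the relu gradient is undefined.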
+      x[i] = static_cast<float>(i % 3 - 2.0) / 2.0 * 0.333 +
+             static_cast<float>(i % 19 - 10.0) / 10.0 * 0.333 +
+             static_cast<float>(i % 39 - 20.0) / 20.0 * 0.333 + 0.001213;
     }
     this->run_forward(&param_, &kernel_, x, out.data());
@@ -120,7 +126,8 @@ class ActivationGradTester {
     for (int i = 0; i < dims_.production(); i++) {
       out_grad[i] = 1.0;
     }
-    this->run_backward(&grad_param_, &grad_kernel_, x, out_grad, x_grad.data());
+    this->run_backward(
+        &grad_param_, &grad_kernel_, x, out, out_grad, x_grad.data());
 
     for (int i = 0; i < dims_.production(); i++) {
       EXPECT_NEAR(x_grad[i], (out_delta[i] - out[i]) / delta, max_grad_delta);
@@ -137,31 +144,58 @@ class ActivationGradTester {
   grad_param_t grad_param_;
 };
 
-void TestNormalCase(DDim dims) {
-  std::unique_ptr<ActivationGradTester> tester(new ActivationGradTester(dims));
+void TestSquareGrad(DDim dims) {
+  LOG(INFO) << "Test Square grad";
+  std::unique_ptr<
+      ActivationGradTester<arm::SquareCompute, host::SquareGradCompute>>
+      tester(new ActivationGradTester<arm::SquareCompute,
+                                      host::SquareGradCompute>(dims));
   tester->prepare_kernel();
   float delta = 0.001;
   float max_grad_delta = 0.005;
   tester->check_grad(delta, max_grad_delta);
 }
 
-TEST(activation_grad_arm, compute) {
-  LOG(INFO) << "Test Square grad";
+void TestReluGrad(DDim dims) {
+  LOG(INFO) << "Test Relu grad";
+  std::unique_ptr<
+      ActivationGradTester<arm::ReluCompute, host::ReluGradCompute>>
+      tester(new ActivationGradTester<arm::ReluCompute, host::ReluGradCompute>(
+          dims));
+  tester->prepare_kernel();
+  float delta = 0.001;
+  float max_grad_delta = 0.005;
+  tester->check_grad(delta, max_grad_delta);
+}
+
+void TestTanhGrad(DDim dims) {
+  LOG(INFO) << "Test Tanh grad";
+  std::unique_ptr<
+      ActivationGradTester<arm::TanhCompute, host::TanhGradCompute>>
+      tester(new ActivationGradTester<arm::TanhCompute, host::TanhGradCompute>(
+          dims));
+  tester->prepare_kernel();
+  float delta = 0.001;
+  float max_grad_delta = 0.005;
+  tester->check_grad(delta, max_grad_delta);
+}
+
+TEST(activation_grad_host, compute) {
   DeviceInfo::Init();
-  for (auto n : {2}) {
-    for (auto c : {2}) {
-      for (auto h : {2}) {
-        for (auto w : {2}) {
-          TestNormalCase(DDim(std::vector<int64_t>({n, c, h, w})));
+  for (auto n : {2, 1}) {
+    for (auto c : {2, 9}) {
+      for (auto h : {2, 1}) {
+        for (auto w : {2, 10}) {
+          TestSquareGrad(DDim(std::vector<int64_t>({n, c, h, w})));
+          TestReluGrad(DDim(std::vector<int64_t>({n, c, h, w})));
+          TestTanhGrad(DDim(std::vector<int64_t>({n, c, h, w})));
         }
       }
     }
  }
 }
 
-}  // namespace arm
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
 
 USE_LITE_KERNEL(square, kARM, kFloat, kNCHW, def);
-USE_LITE_KERNEL(square_grad, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(square_grad, kHost, kFloat, kNCHW, def);
+USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(relu_grad, kHost, kFloat, kNCHW, def);
+USE_LITE_KERNEL(tanh, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(tanh_grad, kHost, kFloat, kNCHW, def);
diff --git a/lite/tests/kernels/elementwise_grad_compute_test.cc b/lite/tests/kernels/elementwise_grad_compute_test.cc
index 2b5fbbb65d3d7e17bf90afb71f5c8154f0d88488..04e74e49099f13a7e5920b306f8d2e26650a2574 100644
--- a/lite/tests/kernels/elementwise_grad_compute_test.cc
+++ b/lite/tests/kernels/elementwise_grad_compute_test.cc
@@ -215,18 +215,6 @@ class ElementwiseAddGradTester {
     fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production());
 
     this->run_forward(&param_, &kernel_, x, y, out.data());
 
-    for (int i = 0; i < x_dims_.production(); i++) {
-      LOG(INFO) << "x_" << i << ": " << x[i];
-    }
-
-    for (int i = 0; i < y_dims_.production(); i++) {
-      LOG(INFO) << "y_" << i << ": " << y[i];
-    }
-
-    for (int i = 0; i < out_dims_.production(); i++) {
-      LOG(INFO) << "out_" << i << ": " << out[i];
-    }
-
     // backward
     std::vector<float> out_grad(out_dims_.production());
     std::vector<float> x_grad(x_dims_.production());
@@ -242,14 +230,6 @@ class ElementwiseAddGradTester {
                        x_grad.data(),
                        y_grad.data());
 
-    for (int i = 0; i < x_grad.size(); i++) {
-      LOG(INFO) << "x_grad_" << i << ": " << x_grad[i];
-    }
-
-    for (int i = 0; i < y_grad.size(); i++) {
-      LOG(INFO) << "y_grad_" << i << ": " << y_grad[i];
-    }
-
     // get numeric gradient
     std::vector<float> x_delta(x_dims_.production());
     std::vector<float> y_delta(y_dims_.production());
@@ -443,18 +423,6 @@ class ElementwiseSubGradTester {
     fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production());
 
     this->run_forward(&param_, &kernel_, x, y, out.data());
 
-    for (int i = 0; i < x_dims_.production(); i++) {
-      LOG(INFO) << "x_" << i << ": " << x[i];
-    }
-
-    for (int i = 0; i < y_dims_.production(); i++) {
-      LOG(INFO) << "y_" << i << ": " << y[i];
-    }
-
-    for (int i = 0; i < out_dims_.production(); i++) {
-      LOG(INFO) << "out_" << i << ": " << out[i];
-    }
-
     // backward
     std::vector<float> out_grad(out_dims_.production());
     std::vector<float> x_grad(x_dims_.production());
@@ -470,14 +438,6 @@ class ElementwiseSubGradTester {
                        x_grad.data(),
                        y_grad.data());
 
-    for (int i = 0; i < x_grad.size(); i++) {
-      LOG(INFO) << "x_grad_" << i << ": " << x_grad[i];
-    }
-
-    for (int i = 0; i < y_grad.size(); i++) {
-      LOG(INFO) << "y_grad_" << i << ": " << y_grad[i];
-    }
-
     // get numeric gradient
     std::vector<float> x_delta(x_dims_.production());
    std::vector<float> y_delta(y_dims_.production());
diff --git a/lite/tests/kernels/sequence_conv_compute_test.cc b/lite/tests/kernels/sequence_conv_compute_test.cc
index 84887b2573516d0c82cbb8c9b4cf9336f30ee41d..68afaad04f8e84995e811f81f99a2d4109c845a5 100644
--- a/lite/tests/kernels/sequence_conv_compute_test.cc
+++ b/lite/tests/kernels/sequence_conv_compute_test.cc
@@ -85,21 +85,31 @@ class SequenceConvComputeTester : public arena::TestCase {
     auto output_dims = output->dims();
     auto output_data = output->mutable_data<float>();
     std::vector<std::vector<float>> res;
-    if (contextStart_ == -2) {
+
+    if (contextStart_ == -2 && lod_.size() == 1 &&
+        lod_[0] == std::vector<uint64_t>({0, 4})) {
       res = {{-0.08867277f, -0.17257819f, -0.2564836f},
              {0.194508f, 0.05720823f, -0.08009153f},
              {0.73512584f, 0.5749428f, 0.41475973f},
              {0.5635012f, 0.49485126f, 0.42620137f}};
-    } else if (contextStart_ == -1) {
+    } else if (contextStart_ == -1 && lod_.size() == 1 &&
+               lod_[0] == std::vector<uint64_t>({0, 4})) {
       res = {{0.194508f, 0.05720823f, -0.08009153f},
              {0.73512584f, 0.5749428f, 0.41475973f},
              {0.5635012f, 0.49485126f, 0.42620137f},
              {0.2517162f, 0.23646072f, 0.22120519f}};
-    } else if (contextStart_ == 0) {
+    } else if (contextStart_ == 0 && lod_.size() == 1 &&
+               lod_[0] == std::vector<uint64_t>({0, 4})) {
       res = {{0.73512584f, 0.5749428f, 0.41475973f},
              {0.5635012f, 0.49485126f, 0.42620137f},
              {0.2517162f, 0.23646072f, 0.22120519f},
              {0.02574372f, 0.03337148f, 0.04099924f}};
+    } else if (contextStart_ == -1 && lod_.size() == 1 &&
+               lod_[0] == std::vector<uint64_t>({0, 2, 4})) {
+      res = {{0.194508, 0.05720823, -0.08009153},
+             {0.7093821, 0.57208234, 0.43478262},
+             {0.19450802, 0.17925248, 0.16399695},
+             {0.2517162, 0.23646072, 0.22120519}};
     } else {
       fprintf(stderr, "not supported contextStart_\n");
       exit(-1);
@@ -136,12 +146,25 @@ void TestNormalCase(Place place, float abs_error = 2e-5) {
   }
 }
 
+void TestBatchCase(Place place, float abs_error = 2e-5) {
+  std::vector<std::vector<uint64_t>> lod{{0, 2, 4}};
+  std::vector<int64_t> dims{4, 5};
+  std::vector<int> candidate_pad_idx{-1};
+  for (int pad_idx : candidate_pad_idx) {
+    std::unique_ptr<arena::TestCase> tester(new SequenceConvComputeTester(
+        place, "def", lod, DDim(dims), pad_idx, 1, 3, 3));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
+  }
+}
+
 TEST(sequence_conv, precision) {
 #ifdef LITE_WITH_ARM
   float abs_error = 2e-5;
   Place place(TARGET(kARM));
   TestNormalCase(place, abs_error);
+  TestBatchCase(place, abs_error);
 #endif
 }
diff --git a/lite/tools/build_android.sh b/lite/tools/build_android.sh
index 90f604d97f212e2966326582eefbb8416cc269ad..5713c4e21bb97d12bb840c99d1adbc7f2d781157 100755
--- a/lite/tools/build_android.sh
+++ b/lite/tools/build_android.sh
@@ -269,6 +269,7 @@ function main {
     if [ -z "$1" ]; then
        # compiling result contains light_api lib only, recommanded.
        make_tiny_publish_so $ARCH $TOOLCHAIN $ANDROID_STL
+       exit 0
    fi
 
    # Parse command line.
@@ -358,6 +359,7 @@ function main {
    done
    # compiling result contains light_api lib only, recommanded.
    make_tiny_publish_so
+   exit 0
 }
 
 main $@
diff --git a/lite/tools/build_bm.sh b/lite/tools/build_bm.sh
index 964da15b0b6fcf888812271b0a2c944d9efa63b8..055f6a35c3ab145e9dfe4bc5d46172a2119ffb25 100755
--- a/lite/tools/build_bm.sh
+++ b/lite/tools/build_bm.sh
@@ -43,7 +43,7 @@ function prepare_thirdparty {
     # clone bmlibs
     if [ ! -d ${workspace}/third-party/bmlibs ]; then
         git clone https://github.com/AnBaolei1984/bmlibs.git ${workspace}/third-party/bmlibs
-    fi 
+    fi
 }
 
 # for code gen, a source file is generated after a test, but is dependended by some targets in cmake.
@@ -70,6 +70,13 @@ function build_bm {
     mkdir -p $build_dir
     cd $build_dir
 
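+    # Pick the BM SDK variant from the requested target: BM1684 boards use
+    # the sc5 libs, everything else falls back to the sc3 libs.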
+    if [ $TARGET_NAME == "BM1684" ]; then
+        BM_SDK_ROOT="$workspace/third-party/bmlibs/bm_sc5_libs"
+    else
+        BM_SDK_ROOT="$workspace/third-party/bmlibs/bm_sc3_libs"
+    fi
+    echo $BM_SDK_ROOT
+
     prepare_workspace
     cmake .. \
         ${CMAKE_COMMON_OPTIONS} \
@@ -95,17 +102,7 @@ function main {
         case $i in
             --target_name=*)
                 TARGET_NAME="${i#*=}"
-                shift
-                ;;
-            #--bm_sdk_root=*)
-            #    BM_SDK_ROOT="${i#*=}"
-            #    shift
-            #    ;;
-            bm)
                 build_bm
-                shift
-                ;;
-            *)
                 # unknown option
                 print_usage
                 exit 1
diff --git a/lite/tools/build_ios.sh b/lite/tools/build_ios.sh
index 2c7eeb466f3d82cf491b6a631d79918fa4fd4cd2..3d4337aa8ecc20fd078b8906a950408927ea56c8 100755
--- a/lite/tools/build_ios.sh
+++ b/lite/tools/build_ios.sh
@@ -152,6 +152,7 @@ function main {
         esac
     done
     make_ios $ARCH
+    exit 0
 }
 
 main $@
diff --git a/lite/tools/check_api_approvals.sh b/lite/tools/check_api_approvals.sh
index cebeed1a8540676d4fce342a24783da6d6840679..b2a4659c964121b0a95961195340c296710db2de 100755
--- a/lite/tools/check_api_approvals.sh
+++ b/lite/tools/check_api_approvals.sh
@@ -71,7 +71,7 @@ function CheckLibSizeDiff() {
   if [ $diff_size -gt 10485 ]; then
     echo_line="Your PR has increased basic inference lib for $diff_size Byte, exceeding maximum requirement of 10485 Byte (0.01M). You need Superjomn's (Yunchunwei) approval or you can contact DannyIsFunny(HuZhiqiang).\n"
     echo "****************"
-    echo -e "${echo_list[@]}"
+    echo -e "${echo_line[@]}"
     echo "There is an approved errors."
     echo "****************"
     exit 1