Commit a867dbbf authored by C chonwhite

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into fpga_pr

......@@ -32,10 +32,9 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib
/usr/lib
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib64
)
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
if((${CUDA_VERSION} GREATER 10.0) OR (${CUDA_VERSION} EQUAL 10.0))
find_library(CUBLAS_LIBRARY NAMES libcublas.so PATHS ${CUDNN_CHECK_LIBRARY_DIRS} NO_DEFAULT_PATH)
......
......@@ -46,7 +46,6 @@ void OutputOptModel(const std::string& load_model_dir,
config.set_model_dir(load_model_dir);
std::vector<Place> vaild_places = {
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)},
};
if (FLAGS_is_quantized_model) {
vaild_places.insert(vaild_places.begin(),
......
......@@ -47,7 +47,6 @@ void OutputOptModel(const std::string& load_model_dir,
lite_api::CxxConfig config;
config.set_model_dir(load_model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
auto predictor = lite_api::CreatePaddlePredictor(config);
......
......@@ -153,7 +153,7 @@ class LITE_API CxxConfig : public ConfigBase {
std::string param_file() const { return param_file_; }
bool model_from_memory() const { return model_from_memory_; }
void set_cpu_math_library_math_threads(int threads) {
void set_cpu_math_library_num_threads(int threads) {
cpu_math_library_math_threads_ = threads;
}
int cpu_math_library_num_threads() const {
......
......@@ -31,9 +31,11 @@ USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
USE_MIR_PASS(lite_interpolate_fuse_pass);
USE_MIR_PASS(lite_sequence_pool_concat_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
......
......@@ -30,7 +30,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
std::string model_dir = FLAGS_model_dir;
lite_api::CxxConfig config;
config.set_model_dir(model_dir);
config.set_cpu_math_library_math_threads(10);
config.set_cpu_math_library_num_threads(1);
config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
......
......@@ -836,7 +836,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
/// get workspace
......
......@@ -2151,6 +2151,210 @@ inline void act_switch_c8_fp32(const float* din_ptr,
}
}
#ifdef __aarch64__
#define LOAD_DATA \
"1: \n" \
"ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/
#define DO_RELU \
"fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */
#define DO_RELU6 \
"fmin v0.4s, v0.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define DO_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define DO_STORE \
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"bne 1b \n"
#else
#define LOAD_DATA \
"1: \n" \
"vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d12-d13}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n"
#define DO_RELU \
"vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n"
#define DO_RELU6 \
"vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \
"vmin.f32 q4, q4, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q5, q5, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q6, q6, %q[vsix] @ vmaxq_f32() \n"
#define DO_LEAKY_RELU \
"vcge.f32 q7, q3, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q9, q4, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q11, q5, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q13, q6, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \
"vbif q3, q8, q7 @ choose \n" \
"vbif q4, q10, q9 @ choose \n" \
"vbif q5, q12, q11 @ choose \n" \
"vbif q6, q13, q13 @ choose \n"
#define DO_STORE \
"subs %[cnt], #1 \n" \
"vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"bne 1b \n"
#endif
/*
 * Apply the activation to the data.
 * Currently supports relu, relu6 and leaky_relu.
 */
inline void act_switch_process(float* src,
float* dst,
int size,
const operators::ActivationParam* act_param) {
int cnt = size >> 4;
int remain = size % 16;
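// Note: each iteration of the asm blocks below consumes 16 floats (four 128-bit
// NEON registers), which is why cnt = size >> 4; the remaining size % 16
// elements are handled by the scalar loops at the end of this function.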
float32x4_t vzero = vdupq_n_f32(0.f);
if (act_param != nullptr && act_param->has_active) {
float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha);
if (cnt > 0) {
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
LOAD_DATA DO_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_RELU DO_RELU6 DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
LOAD_DATA DO_RELU DO_RELU6 DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_LEAKY_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale)
: "memory",
"cc",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11");
#else
asm volatile(
LOAD_DATA DO_LEAKY_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale)
: "memory",
"cc",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
// remain
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (int i = 0; i < remain; i++) {
*dst = *src >= 0.f ? *src : 0.f;
src++;
dst++;
}
break;
case lite_api::ActivationType::kRelu6:
for (int i = 0; i < remain; i++) {
float tmp = *src >= 0.f ? *src : 0.f;
*dst = tmp <= act_param->Relu_clipped_coef
? tmp
: act_param->Relu_clipped_coef;
src++;
dst++;
}
break;
case lite_api::ActivationType::kLeakyRelu:
for (int i = 0; i < remain; i++) {
if (*src >= 0.f) {
*dst = *src;
} else {
*dst = *src * act_param->Leaky_relu_alpha;
}
src++;
dst++;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
}
/* write result to output
 * input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
 */
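For orientation, a minimal scalar sketch of the blocked-to-NCHW write-back this comment describes, assuming the [n, c/8, h, w*8] layout stores each group of 8 channels contiguously along the innermost dimension (the function name and tail handling are hypothetical; the actual vectorized implementation is elided below):
// Hypothetical scalar reference for the [n, c/8, h, w*8] -> [n, c, h, w] copy.
inline void write_c8_to_nchw_ref(const float* din, float* dout,
                                 int n, int c, int h, int w) {
  int c_round = (c + 7) / 8;  // number of 8-channel blocks in din
  for (int in = 0; in < n; ++in) {
    for (int ic = 0; ic < c; ++ic) {
      int block = ic / 8;
      int lane = ic % 8;
      for (int ih = 0; ih < h; ++ih) {
        for (int iw = 0; iw < w; ++iw) {
          // din index follows [n, c/8, h, w*8]; dout index follows [n, c, h, w]
          dout[((in * c + ic) * h + ih) * w + iw] =
              din[((in * c_round + block) * h + ih) * w * 8 + iw * 8 + lane];
        }
      }
    }
  }
}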
......
......@@ -52,6 +52,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s1_fp32(const float* din,
......@@ -67,7 +68,6 @@ void conv_depthwise_3x3s1_fp32(const float* din,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx);
......@@ -84,7 +84,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx);
template <typename Dtype>
......
......@@ -584,7 +584,6 @@ void conv_depthwise_3x3_fp32(const void* din,
const int pad_w = paddings[2];
int stride = param.strides[1];
int pad = pad_w;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
......@@ -603,7 +602,6 @@ void conv_depthwise_3x3_fp32(const void* din,
bias,
pad,
flag_bias,
flag_relu,
act_param,
ctx);
} else {
......@@ -638,7 +636,7 @@ void conv_depthwise_3x3_fp32(const void* din,
bias,
pad,
flag_bias,
flag_relu,
act_param,
ctx);
} else {
conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din),
......@@ -653,6 +651,7 @@ void conv_depthwise_3x3_fp32(const void* din,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
}
} else {
......
......@@ -1404,8 +1404,8 @@ void sgemm_prepack_c4_small(int M,
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"bne 1b \n"
"fadd v8.4s, v8.4s, v9.4s \n"
"2:\n"
"fadd v8.4s, v8.4s, v9.4s \n"
"st1 {v8.4s}, [%[c]], #16 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
......@@ -1660,8 +1660,8 @@ void sgemm_prepack_c4_small(int M,
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
"bne 1b \n"
"vadd.f32 q5, q5, q6 \n"
"2:\n"
"vadd.f32 q5, q5, q6 \n"
"vst1.32 {d10-d11}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
......
......@@ -89,9 +89,15 @@ bool CudnnConv2D<PRECISION(kFloat)>::create(const operators::ConvParam& param,
this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
}
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnnMathType_t math_type =
use_tensor_core_ ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc_, math_type));
#endif
if (ic == param.groups && ic == oc && ic != 1) {
this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
} else if (1) {
} else if (!param.var_length) {
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
auto* o_data = param.output->mutable_data<float>(TARGET(kCUDA));
......
......@@ -55,6 +55,8 @@ class Gemm {
PtypeOut* c,
Context<TARGET(kCUDA)>* ctx);
cublasHandle_t get_handle() const { return cu_handle_; }
private:
cudaStream_t exe_stream_;
cublasHandle_t cu_handle_;
......
......@@ -30,7 +30,12 @@ std::unique_ptr<xtcl::network::xRuntimeInstance> Device::Build(
// The XPU compiler builds the graph and fills in all of the constant params;
// all of the graph outputs are packed into a tuple below.
xtcl::xNetwork network = builder->FinalizeNetwork(*((*outputs)[0]));
xtcl::Array<xtcl::xExpr> all_outs;
for (size_t i = 0; i < outputs->size(); i++) {
all_outs.push_back(*outputs->at(i));
}
xtcl::xNetwork network =
builder->FinalizeNetwork(xtcl::relay::TupleNode::make(all_outs));
auto target = xtcl::Target::Create(device_name_);
auto compiler = xtcl::network::xTensorCompiler(network, target);
compiler.SetParams(*params); // Set the data of constant tensors
......
......@@ -35,12 +35,12 @@ void TestCase::CreateInstruction() {
op_desc_.reset(new cpp::OpDesc());
op_desc_->SetType("subgraph");
op_desc_->SetAttr<int32_t>("sub_block", sub_block_idx);
op_desc_->SetInput("Inputs", op_desc_->input_vars());
op_desc_->SetOutput("Outputs", op_desc_->output_vars());
op_desc_->SetAttr<std::vector<std::string>>(
"input_data_names", sub_block_op_desc->input_vars());
op_desc_->SetAttr<std::vector<std::string>>(
"output_data_names", sub_block_op_desc->output_vars());
auto in_names = sub_block_op_desc->input_vars();
auto out_names = sub_block_op_desc->output_vars();
op_desc_->SetInput("Inputs", in_names);
op_desc_->SetOutput("Outputs", out_names);
op_desc_->SetAttr<std::vector<std::string>>("input_data_names", in_names);
op_desc_->SetAttr<std::vector<std::string>>("output_data_names", out_names);
op = LiteOpRegistry::Global().Create(op_desc().Type());
static_cast<operators::SubgraphOp*>(op.get())->SetSubBlock(sub_block_desc);
} else {
......
......@@ -188,13 +188,17 @@ class Arena {
tester_->Prepare();
}
bool TestPrecision() {
bool TestPrecision(const std::vector<std::string>& exclude_outs = {}) {
tester_->RunBaseline(tester_->baseline_scope());
tester_->RunInstruction();
bool success = true;
for (auto& out : tester_->op_desc().OutputArgumentNames()) {
for (auto& var : tester_->op_desc().Output(out)) {
if (std::find(exclude_outs.begin(), exclude_outs.end(), var) !=
exclude_outs.end()) {
continue;
}
success = success && CompareTensor(out, var);
}
}
......@@ -209,7 +213,17 @@ class Arena {
}
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - timer);
LOG(INFO) << "average duration: " << duration.count() << " ms";
timer = std::chrono::high_resolution_clock::now();
for (int i = 0; i < times; i++) {
tester_->RunBaseline(tester_->baseline_scope());
}
auto duration_basic = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - timer);
LOG(INFO) << "average lite duration: " << duration.count() << " ms";
LOG(INFO) << "average basic duration: " << duration_basic.count() << " ms";
LOG(INFO) << "speed up ratio: lite_speed / basic_speed: "
<< static_cast<float>(duration_basic.count()) / duration.count();
}
private:
......
......@@ -16,9 +16,11 @@ lite_cc_library(mir_passes
fusion/interpolate_fuse_pass.cc
fusion/conv_elementwise_fuse_pass.cc
fusion/conv_activation_fuse_pass.cc
fusion/var_conv_2d_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
static_kernel_pick_pass.cc
......
......@@ -10,6 +10,9 @@ lite_cc_library(fuse_conv_elementwise
lite_cc_library(fuse_conv_activation
SRCS conv_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_var_conv_activation
SRCS var_conv_2d_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
......@@ -25,17 +28,22 @@ lite_cc_library(fuse_transpose_softmax_transpose
lite_cc_library(fuse_interpolate
SRCS interpolate_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_sequence_pool_concat
SRCS sequence_pool_concat_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
fuse_shuffle_channel
fuse_conv_elementwise
fuse_conv_activation
fuse_var_conv_activation
fuse_conv_bn
fuse_quant_dequant
fuse_elementwise_add_activation
fuse_transpose_softmax_transpose
fuse_interpolate
fuse_sequence_pool_concat
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::SequencePoolConcatFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_sequence_pool_concat_fuse_pass,
paddle::lite::mir::SequencePoolConcatFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class SequencePoolConcatFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePoolConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
#define STR1(R) #R
#define STR2(R) STR1(R)
#define POOL_CONCAT_PATTERN(num) \
auto* x_##num = VarNode(STR2(sequence_pool_x_##num)) \
->assert_is_op_input("sequence_pool", "X") \
->AsInput(); \
auto* sequence_pool_##num = \
OpNode(STR2(sequence_pool_##num), "sequence_pool")->AsIntermediate(); \
auto* sequence_pool_##num##_out = \
VarNode(STR2(sequence_pool_##num##_out)) \
->assert_is_op_output("sequence_pool", "Out") \
->assert_is_op_nth_input("concat", "X", num - 1) \
->AsIntermediate(); \
auto* sequence_pool_##num##_idx = \
VarNode(STR2(sequence_pool_##num##_idx)) \
->assert_is_op_output("sequence_pool", "MaxIndex") \
->AsIntermediate(); \
*sequence_pool_##num >> *sequence_pool_##num##_idx; \
*x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out;
POOL_CONCAT_PATTERN(1);
POOL_CONCAT_PATTERN(2);
POOL_CONCAT_PATTERN(3);
POOL_CONCAT_PATTERN(4);
POOL_CONCAT_PATTERN(5);
POOL_CONCAT_PATTERN(6);
POOL_CONCAT_PATTERN(7);
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
}
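A short aside on the macro machinery above (illustrative sketch, not project code): STR2 forwards to STR1 so the token-pasted argument is fully expanded before being stringized, which produces the node-name keys that matched.at(...) looks up later.
// Sketch of the two-step stringization used by POOL_CONCAT_PATTERN.
#define STR1(R) #R
#define STR2(R) STR1(R)
#define NODE_NAME(num) STR2(sequence_pool_x_##num)
// NODE_NAME(3) first pastes to sequence_pool_x_3 and then stringizes it to the
// literal "sequence_pool_x_3", the same key used by matched.at("sequence_pool_x_3")
// inside InsertNewNode().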
void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat");
auto concat = matched.at("concat")->stmt()->op();
auto* scope = concat->scope();
auto& valid_places = concat->valid_places();
sequence_pool_concat_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(sequence_pool_concat_op, valid_places);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_1"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_2"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_3"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_4"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_5"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_6"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_7"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
}
cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X",
{matched.at("sequence_pool_x_1")->arg()->name,
matched.at("sequence_pool_x_2")->arg()->name,
matched.at("sequence_pool_x_3")->arg()->name,
matched.at("sequence_pool_x_4")->arg()->name,
matched.at("sequence_pool_x_5")->arg()->name,
matched.at("sequence_pool_x_6")->arg()->name,
matched.at("sequence_pool_x_7")->arg()->name});
std::vector<std::string> pooltypes;
pooltypes.push_back(matched.at("sequence_pool_1")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_2")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_3")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_4")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_5")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_6")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_7")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
op_desc.SetAttr("pooltype", pooltypes);
op_desc.SetOutput("Out", {matched.at("concat_out")->arg()->name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class SequencePoolConcatFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void VarConv2dActivationFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"};
for (auto act_type : act_types) {
fusion::VarConvActivationFuser fuser(act_type, "var_conv_2d");
fuser(graph.get());
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_var_conv_2d_activation_fuse_pass,
paddle::lite::mir::VarConv2dActivationFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class VarConv2dActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void VarConvActivationFuser::BuildPattern() {
// create nodes.
auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput();
auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput();
auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate();
auto* act = OpNode("act", act_type_)->AsIntermediate();
auto* conv2d_out = VarNode("conv2d_out")
->assert_is_op_output(conv_type_, "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
auto* conv2d_out_1 = VarNode("conv2d_out_1")
->assert_is_op_output(conv_type_, "Col")
->AsIntermediate();
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
*conv2d >> *conv2d_out_1;
}
void VarConvActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("var_conv_2d")->stmt()->op();
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
IR_NODE_LINK_TO(matched.at("X"), new_op_node);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc VarConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("var_conv_2d")->stmt()->op_info();
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
}
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class VarConvActivationFuser : public FuseBase {
public:
explicit VarConvActivationFuser(const std::string& act_type,
const std::string& conv_type)
: act_type_(act_type), conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string act_type_;
std::string conv_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -62,12 +62,14 @@ class Optimizer {
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
"lite_var_conv_2d_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
......
......@@ -262,14 +262,10 @@ void Instruction::Run() {
if (op_->run_once() && has_run_) {
return;
}
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "kernel launch";
#endif
// VLOG(4) << "kernel launch";
op_->InferShape();
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
<< TargetToStr(kernel_->target());
#endif
// VLOG(4) << ">> Running kernel: " << op_->op_info()->Repr() << " on Target "
// << TargetToStr(kernel_->target());
kernel_->Launch();
has_run_ = true;
}
......
......@@ -49,6 +49,7 @@ add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_
add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(grid_sampler_compute_arm ARM basic SRCS grid_sampler_compute.cc DEPS ${lite_kernel_deps} math_arm)
## 2. other basic kernels: basic kernels that are not used in basic models
add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -65,20 +65,20 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
no_dilation && flag_dw) {
/// dw conv impl
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking dw conv";
// VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation && pads_all_equal) {
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
// VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 &&
chin * chout < 4 * hin * win && kps_equal && no_dilation) {
/// direct conv impl
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv";
// VLOG(3) << "invoking direct conv";
} else {
impl_ = new GemmLikeConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking gemm like conv";
// VLOG(3) << "invoking gemm like conv";
}
impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param);
......@@ -117,14 +117,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
VLOG(3) << "Run DepthwiseConv Int8";
// VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
VLOG(3) << "Run DirectConv Int8";
// VLOG(3) << "Run DirectConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
VLOG(3) << "Run GemmLikeConvInt8";
// VLOG(3) << "Run GemmLikeConvInt8";
}
impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param);
......@@ -163,14 +163,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run DepthwiseConv Int8";
// VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run DirectConv Int8";
// VLOG(3) << "Run DirectConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
VLOG(3) << "Run GemmLikeConvInt8";
// VLOG(3) << "Run GemmLikeConvInt8";
}
impl_->SetContext(std::move(this->ctx_));
impl_->SetParam(param);
......
......@@ -30,7 +30,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto kw = w_dims[3];
// select dw conv kernel
if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv fp32";
// VLOG(5) << "invoke 3x3 dw conv fp32";
auto paddings = *param.paddings;
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
......@@ -54,7 +54,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
flag_trans_weights_ = true;
}
} else if (kw == 5) {
VLOG(5) << "invoke 5x5 dw conv fp32";
// VLOG(5) << "invoke 5x5 dw conv fp32";
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else {
LOG(FATAL) << "this type dw conv not impl";
......@@ -86,7 +86,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
// VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
......@@ -96,7 +96,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
// VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
......@@ -145,7 +145,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
// VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
......@@ -155,7 +155,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
// VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/grid_sampler_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void GridSamplerCompute::PrepareForRun() {}
void GridSamplerCompute::Run() {
auto& param = this->Param<param_t>();
auto n = param.x->dims()[0];
auto c = param.x->dims()[1];
auto h = param.x->dims()[2];
auto w = param.x->dims()[3];
const float* in = param.x->data<float>();
const float* grid = param.grid->data<float>();
float* out = param.out->mutable_data<float>();
auto& ctx = this->ctx_->template As<ARMContext>();
const size_t coor_size = n * h * w;
const size_t workspace_size = coor_size * 12 * sizeof(float);
ctx.ExtendWorkspace(workspace_size);
int32_t* coor_p = ctx.workspace_data<int>();
float* dis_p = reinterpret_cast<float*>(coor_p) + coor_size * 4;
uint32_t* bound_p = reinterpret_cast<uint32_t*>(dis_p) + coor_size * 4;
float x_max = static_cast<float>(w - 1);
float y_max = static_cast<float>(h - 1);
float32x4_t vxmax = vdupq_n_f32(x_max);
float32x4_t vymax = vdupq_n_f32(y_max);
float32x4_t vone = vdupq_n_f32(1.f);
float32x4_t vzero = vdupq_n_f32(0.f);
// compute coor, dis, bound
int i = coor_size;
for (; i > 3; i -= 4) {
float32x4x2_t xy = vld2q_f32(grid);
float32x4_t grid_x = vmulq_n_f32(vaddq_f32(xy.val[0], vone), 0.5 * x_max);
float32x4_t grid_y = vmulq_n_f32(vaddq_f32(xy.val[1], vone), 0.5 * y_max);
grid += 8;
// compute xw, xe, yn, ys
int32x4x4_t vcoor;
vcoor.val[0] = vcvtq_s32_f32(grid_x);
vcoor.val[2] = vcvtq_s32_f32(grid_y);
float32x4_t vxwf = vcvtq_f32_s32(vcoor.val[0]);
float32x4_t vynf = vcvtq_f32_s32(vcoor.val[2]);
float32x4_t vxef = vaddq_f32(vxwf, vone);
float32x4_t vysf = vaddq_f32(vynf, vone);
vcoor.val[1] = vcvtq_s32_f32(vxef);
vcoor.val[3] = vcvtq_s32_f32(vysf);
vst4q_s32(coor_p, vcoor);
coor_p += 16;
// compute dw, de, dn, ds
float32x4x4_t vdis;
vdis.val[0] = vsubq_f32(grid_x, vxwf);
vdis.val[2] = vsubq_f32(grid_y, vynf);
vdis.val[1] = vsubq_f32(vxef, grid_x);
vdis.val[3] = vsubq_f32(vysf, grid_y);
vst4q_f32(dis_p, vdis);
dis_p += 16;
// compute bound
uint32x4x4_t vbound;
uint32x4_t logic_xw =
vorrq_u32(vcltq_f32(vxwf, vzero), vcgtq_f32(vxwf, vxmax));
uint32x4_t logic_xe =
vorrq_u32(vcltq_f32(vxef, vzero), vcgtq_f32(vxef, vxmax));
uint32x4_t logic_yn =
vorrq_u32(vcltq_f32(vynf, vzero), vcgtq_f32(vynf, vymax));
uint32x4_t logic_ys =
vorrq_u32(vcltq_f32(vysf, vzero), vcgtq_f32(vysf, vymax));
vbound.val[0] = vmvnq_u32(vorrq_u32(logic_xw, logic_yn));
vbound.val[1] = vmvnq_u32(vorrq_u32(logic_xe, logic_yn));
vbound.val[2] = vmvnq_u32(vorrq_u32(logic_xw, logic_ys));
vbound.val[3] = vmvnq_u32(vorrq_u32(logic_xe, logic_ys));
vst4q_u32(bound_p, vbound);
bound_p += 16;
}
for (; i > 0; i--) {
float x = grid[0];
float y = grid[1];
float grid_x = (x + 1) * 0.5 * x_max;
float grid_y = (y + 1) * 0.5 * y_max;
grid += 2;
// compute xw, xe, yn, ys
int32_t xw = static_cast<int32_t>(floor(grid_x));
int32_t xe = xw + 1;
int32_t yn = static_cast<int32_t>(floor(grid_y));
int32_t ys = yn + 1;
*coor_p++ = xw;
*coor_p++ = xe;
*coor_p++ = yn;
*coor_p++ = ys;
// compute dw, de, dn, ds
float dw = grid_x - xw;
float de = xe - grid_x;
float dn = grid_y - yn;
float ds = ys - grid_y;
*dis_p++ = dw;
*dis_p++ = de;
*dis_p++ = dn;
*dis_p++ = ds;
// compute bound
bool logic_xw = (xw < 0.f || xw > x_max);
bool logic_xe = (xe < 0.f || xe > x_max);
bool logic_yn = (yn < 0.f || yn > y_max);
bool logic_ys = (ys < 0.f || ys > y_max);
*bound_p++ = ((logic_xw || logic_yn) ? 0 : 0xffffffff);
*bound_p++ = ((logic_xe || logic_yn) ? 0 : 0xffffffff);
*bound_p++ = ((logic_xw || logic_ys) ? 0 : 0xffffffff);
*bound_p++ = ((logic_xe || logic_ys) ? 0 : 0xffffffff);
}
size_t cube_size = c * h * w;
size_t spatial_size = h * w;
// compute output
for (int i = 0; i < n; ++i) {
const float* in_n = in + i * cube_size;
float* out_n = out + i * cube_size;
int32_t* coor_n = ctx.workspace_data<int>() + i * spatial_size * 4;
float* dis_n = reinterpret_cast<float*>(coor_n) + coor_size * 4;
uint32_t* bound_n = reinterpret_cast<uint32_t*>(dis_n) + coor_size * 4;
#pragma omp parallel for
for (int j = 0; j < c; ++j) {
int32_t* coor_ptr = coor_n;
float* dis_ptr = dis_n;
uint32_t* bound_ptr = bound_n;
const float* in_c = in_n + j * spatial_size;
float* out_c = out_n + j * spatial_size;
for (int k = 0; k < spatial_size; k++) {
int32x4_t vcoor = vld1q_s32(coor_ptr);
float32x4_t vdis = vld1q_f32(dis_ptr);
int32_t xw = vgetq_lane_s32(vcoor, 0);
int32_t xe = vgetq_lane_s32(vcoor, 1);
int32_t yn = vgetq_lane_s32(vcoor, 2);
int32_t ys = vgetq_lane_s32(vcoor, 3);
uint32x4_t vbound = vld1q_u32(bound_ptr);
float dw = vgetq_lane_f32(vdis, 0);
float de = vgetq_lane_f32(vdis, 1);
float dn = vgetq_lane_f32(vdis, 2);
float ds = vgetq_lane_f32(vdis, 3);
uint32_t wnbound = vgetq_lane_u32(vbound, 0);
uint32_t enbound = vgetq_lane_u32(vbound, 1);
uint32_t wsbound = vgetq_lane_u32(vbound, 2);
uint32_t esbound = vgetq_lane_u32(vbound, 3);
float in_wn = wnbound ? in_c[yn * w + xw] : 0.f;
float in_en = enbound ? in_c[yn * w + xe] : 0.f;
float in_ws = wsbound ? in_c[ys * w + xw] : 0.f;
float in_es = esbound ? in_c[ys * w + xe] : 0.f;
coor_ptr += 4;
dis_ptr += 4;
bound_ptr += 4;
*out_c++ =
ds * (in_wn * de + in_en * dw) + dn * (in_ws * de + in_es * dw);
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(grid_sampler,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::GridSamplerCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Grid", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
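For readers new to the kernel above: Run() implements bilinear grid sampling, where each output value is a weighted average of the four input neighbours around the sample point and out-of-bounds neighbours contribute zero. A scalar reference for one output value, assuming <cmath> is available and the grid coordinates are already mapped from [-1, 1] to pixel coordinates (hypothetical helper, illustrative only):
#include <cmath>

// Hypothetical scalar reference; the kernel above vectorizes this and
// precomputes coordinates, distances and bound masks for the whole map.
inline float grid_sample_bilinear_ref(const float* in_c, int h, int w,
                                      float grid_x, float grid_y) {
  int xw = static_cast<int>(std::floor(grid_x)), xe = xw + 1;  // west/east columns
  int yn = static_cast<int>(std::floor(grid_y)), ys = yn + 1;  // north/south rows
  float dw = grid_x - xw, de = xe - grid_x;                    // horizontal weights
  float dn = grid_y - yn, ds = ys - grid_y;                    // vertical weights
  auto at = [&](int y, int x) -> float {                       // zero outside bounds
    return (x < 0 || x > w - 1 || y < 0 || y > h - 1) ? 0.f : in_c[y * w + x];
  };
  return ds * (at(yn, xw) * de + at(yn, xe) * dw) +
         dn * (at(ys, xw) * de + at(ys, xe) * dw);
}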
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GridSamplerCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GridSamplerParam;
void PrepareForRun() override;
void Run() override;
virtual ~GridSamplerCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -41,18 +41,20 @@ void PoolCompute::Run() {
std::vector<int>& paddings = *param.paddings;
std::string& pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
bool ceil_mode = param.ceil_mode;
bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format;
bool pads_equal =
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]);
bool kps_equal = (ksize[0] == ksize[1]) && (strides[0] == strides[1]) &&
(paddings[0] == paddings[2]);
bool pads_equal = (paddings[0] == paddings[1]) &&
(paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]);
bool kps_equal =
(ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_equal;
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && pads_equal;
global_pooling = param.global_pooling || global_pooling;
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0;
......@@ -83,8 +85,7 @@ void PoolCompute::Run() {
return;
}
} else {
if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && pads_equal &&
kps_equal) {
if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din,
dout,
......@@ -110,7 +111,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din,
dout,
......@@ -136,7 +137,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
pads_equal && kps_equal) {
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din,
dout,
......@@ -162,7 +163,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
pads_equal && kps_equal) {
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din,
dout,
......@@ -188,7 +189,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
pads_equal && kps_equal) {
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din,
dout,
......
......@@ -54,7 +54,7 @@ void SplitLodTensorCompute::Run() {
}
lod->clear();
for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) {
VLOG(4) << "mask: " << mask_data[i];
// VLOG(4) << "mask: " << mask_data[i];
if (static_cast<size_t>(mask_data[i]) == t) {
size_t start_idx = i;
auto lod_and_offset = lite::arm::math::GetSubLoDAndAbsoluteOffset(
......
......@@ -36,7 +36,7 @@ class StepExecutor {
auto &op_desc = *block->template GetOp<cpp::OpDesc>(i);
auto op_type = op_desc.Type();
auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type());
VLOG(4) << "while: creating Op [" << op_type << "]";
// VLOG(4) << "while: creating Op [" << op_type << "]";
op_handler->Attach(op_desc, scope);
auto hostplace = place_;
......@@ -51,9 +51,9 @@ class StepExecutor {
void Run() {
for (auto &op_handler : ops_of_block_) {
VLOG(4) << op_handler->op_info()->Repr();
// VLOG(4) << op_handler->op_info()->Repr();
op_handler->InferShape();
VLOG(4) << "while: infered shape";
// VLOG(4) << "while: infered shape";
op_handler->Run();
}
}
......
......@@ -11,6 +11,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
......
......@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/match_matrix_tensor_compute.h"
......@@ -20,6 +21,54 @@ namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
template <typename dtype>
void gpu_transpose(
cublasHandle_t handle, const dtype* src, int M, int N, dtype* dst);
template <>
void gpu_transpose<float>(
cublasHandle_t handle, const float* src, int M, int N, float* dst) {
float alpha = 1.0;
float beta = 0.0;
CUBLAS_CHECK(cublasSgeam(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
M,
N,
&alpha,
src,
N,
&beta,
dst,
M,
dst,
M));
}
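A brief usage note on the helper above (an interpretation, not part of the diff): with alpha = 1 and beta = 0 the cublasSgeam call performs an out-of-place transpose on the GPU, so the helper reads a row-major M x N matrix src and writes its transpose into dst as a row-major N x M matrix.
// Hypothetical call site; d_src, d_dst are device buffers, dims are assumptions.
// Transposes a row-major rows x cols matrix d_src into a row-major cols x rows d_dst.
// gpu_transpose<float>(handle, d_src, /*M=*/rows, /*N=*/cols, d_dst);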
template <typename dtype>
__global__ void padding_out(const dtype* src,
const int* offset,
const int seq_num_r,
const int max_len_r,
const int tl,
const int count,
dtype* dst) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int thread_num = blockDim.x * gridDim.x;
for (tid = threadIdx.x + blockIdx.x * blockDim.x; tid < count;
tid += thread_num) {
int seq_id = tid / (tl * max_len_r);
int tl_id = (tid / (max_len_r)) % tl;
int r_id = tid % max_len_r;
int cur_len = offset[seq_id + 1] - offset[seq_id];
if (r_id < cur_len) {
dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id];
} else {
dst[tid] = 0.f;
}
}
}
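Reading aid for padding_out above (an interpretation of its index math, not part of the diff): it scatters a packed [total_len_r, tl] matrix, whose rows are grouped per sequence by offset, into a zero-padded [seq_num_r, tl, max_len_r] tensor, transposing each sequence's (row, tl) block along the way. A host-side sketch with hypothetical names:
// Hypothetical CPU reference for padding_out; illustrative only.
void padding_out_ref(const float* src, const int* offset, int seq_num_r,
                     int max_len_r, int tl, float* dst) {
  for (int seq_id = 0; seq_id < seq_num_r; ++seq_id) {
    int cur_len = offset[seq_id + 1] - offset[seq_id];
    for (int tl_id = 0; tl_id < tl; ++tl_id) {
      for (int r_id = 0; r_id < max_len_r; ++r_id) {
        int dst_idx = (seq_id * tl + tl_id) * max_len_r + r_id;
        dst[dst_idx] = (r_id < cur_len)
                           ? src[(offset[seq_id] + r_id) * tl + tl_id]
                           : 0.f;  // zero padding past the real sequence length
      }
    }
  }
}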
void MatchMatrixTensorCompute::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
}
......@@ -28,6 +77,7 @@ void MatchMatrixTensorCompute::Run() {
CHECK(ctx_) << "running context should be set first";
auto& param = this->Param<param_t>();
auto& context = this->ctx_->template As<CUDAContext>();
auto stream = context.exec_stream();
auto* x = param.x;
auto* w = param.w;
......@@ -39,76 +89,74 @@ void MatchMatrixTensorCompute::Run() {
const auto& offset_l = x->lod()[0];
const auto& offset_r = y->lod()[0];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
int len_l = offset_l[b + 1] - offset_l[b];
int len_r = offset_r[b + 1] - offset_r[b];
top_size += dim_t * len_l * len_r;
top_offset.push_back(top_size);
std::vector<int> offset_r_int(offset_r.size());
std::transform(offset_r.begin(),
offset_r.end(),
offset_r_int.begin(),
[](int64_t x) -> int { return static_cast<int>(x); });
int batch = offset_r.size() - 1;
int len_l = offset_l[1] - offset_l[0];
for (int i = 1; i < offset_l.size() - 1; i++) {
int cur_len = offset_l[i + 1] - offset_l[i];
CHECK_EQ(cur_len, len_l)
<< "each sequence of left matrix is the same length";
}
auto* bottom_l_data = x->data<float>();
auto* bottom_r_data = y->data<float>();
auto* t_data = w->data<float>();
auto* out_data = out->mutable_data<float>(TARGET(kCUDA));
auto* bottom_l_trans_data = tmp->mutable_data<float>(TARGET(kCUDA));
gemm_impl_->init(
false, false, x->dims()[0], dim_t * dim_in, dim_in, &context);
gemm_impl_->run(
1.0f, 0.0f, bottom_l_data, t_data, bottom_l_trans_data, &context);
for (size_t b = 0; b < x->lod()[0].size() - 1; b++) {
for (int t = 0; t < dim_t; t++) {
int len_l = offset_l[b + 1] - offset_l[b];
int len_r = offset_r[b + 1] - offset_r[b];
auto* top_data = out_data + top_offset[b] + t * len_l * len_r;
const auto* l_t_data =
bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
gemm_impl_->init(false,
true,
len_l,
len_r,
dim_in,
dim_t * dim_in,
dim_in,
len_r,
&context);
gemm_impl_->run(1.0f, 0.0f, l_t_data, r_data, top_data, &context);
}
int max_len_r = 0;
for (int i = 0; i < offset_r.size() - 1; ++i) {
int cur_len = offset_r[i + 1] - offset_r[i];
max_len_r = cur_len > max_len_r ? cur_len : max_len_r;
}
int batch_size = x->lod()[0].size() - 1;
int lod_lv1_size = batch_size * dim_t;
int lod_lv2_size = x->lod()[0].back() * dim_t;
std::vector<size_t> out_lod0(batch_size + 1, 0);
std::vector<size_t> out_lod1(lod_lv1_size + 1, 0);
std::vector<size_t> out_lod2(lod_lv2_size + 1, 0);
for (int i = 0; i < batch_size; i++) {
out_lod0[i + 1] = out_lod0[i] + dim_t;
int len_l = offset_l[i + 1] - offset_l[i];
for (int j = 0; j < dim_t; j++) {
out_lod1[i * dim_t + j + 1] = out_lod1[i * dim_t + j] + len_l;
int len_r = offset_r[i + 1] - offset_r[i];
for (int k = 0; k < len_l; k++) {
out_lod2[offset_l[i] * dim_t + j * len_l + k + 1] =
out_lod2[offset_l[i] * dim_t + j * len_l + k] + len_r;
}
}
_input_l_transform.Resize({batch, dim_t, dim_in, len_l});
_input_l_transform_reorganize.Resize({batch, dim_t, len_l, dim_in});
_output_tmp.Resize({batch, max_len_r, dim_t, len_l});
out->Resize({batch, dim_t, len_l, max_len_r});
_offset_r.Resize({static_cast<int64_t>(offset_r.size())});
TargetWrapperCuda::MemcpyAsync(_offset_r.mutable_data<int>(TARGET(kCUDA)),
&offset_r_int[0],
sizeof(int) * offset_r.size(),
IoDirection::HtoD,
stream);
int len_r = offset_r[offset_r.size() - 1];
const float* input_l = x->data<float>();
const float* input_r = y->data<float>();
const float* weight_data = w->data<float>();
float* input_l_transform =
_input_l_transform.mutable_data<float>(TARGET(kCUDA));
float* input_l_transform_reorganize =
_input_l_transform_reorganize.mutable_data<float>(TARGET(kCUDA));
float* output_tmp = _output_tmp.mutable_data<float>(TARGET(kCUDA));
float* out_data = out->mutable_data<float>(TARGET(kCUDA));
gemm_impl_->init(true, true, dim_t * dim_in, len_l, dim_in, &context);
gemm_impl_->run(
1.0f, 0.0f, weight_data, input_l, input_l_transform, &context);
for (int i = 0; i < dim_t; ++i) {
int offset = i * dim_in * len_l;
gpu_transpose(gemm_impl_->get_handle(),
input_l_transform + offset,
dim_in,
len_l,
input_l_transform_reorganize + offset);
}
LoD out_lod;
out_lod.push_back(top_offset);
out_lod.push_back(offset_l);
out_lod.push_back(offset_r);
out->set_lod(out_lod);
gemm_impl_->init(false, true, len_r, dim_t * len_l, dim_in, &context);
gemm_impl_->run(
1.0f, 0.0f, input_r, input_l_transform_reorganize, output_tmp, &context);
int seq_num = offset_r.size() - 1;
int count = seq_num * max_len_r * dim_t * len_l;
const int blocks = 512;
const int grids = (count + blocks - 1) / blocks;
padding_out<float><<<grids, blocks, 0, stream>>>(_output_tmp.data<float>(),
_offset_r.data<int>(),
seq_num,
max_len_r,
dim_t * len_l,
count,
out_data);
out->set_lod(y->lod());
}
} // namespace cuda
......
......@@ -34,6 +34,10 @@ class MatchMatrixTensorCompute
private:
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
lite::Tensor _input_l_transform;
lite::Tensor _input_l_transform_reorganize;
lite::Tensor _output_tmp;
lite::Tensor _offset_r;
};
} // namespace cuda
......
......@@ -16,92 +16,6 @@ namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
static void anakin_NV_gemv(cublasHandle_t handle,
const bool TransA,
const int M,
const int N,
const T alpha,
const T* A,
const T* x,
const T beta,
T* y);
template <>
void anakin_NV_gemv<float>(cublasHandle_t handle,
const bool TransA,
const int M,
const int N,
const float alpha,
const float* A,
const float* x,
const float beta,
float* y) {
cublasOperation_t cuTransA = (TransA == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBLAS_CHECK(
cublasSgemv(handle, cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1));
}
template <typename T>
static void anakin_NV_gemm(cublasHandle_t handle,
const bool TransA,
const bool TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const T* B,
const T beta,
T* C);
template <>
void anakin_NV_gemm<float>(cublasHandle_t handle,
const bool TransA,
const bool TransB,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const float* B,
const float beta,
float* C) {
// Note that cublas follows fortran order.
int lda = (!TransA /* == CblasNoTrans*/) ? K : M;
int ldb = (!TransB /* == CblasNoTrans*/) ? N : K;
cublasOperation_t cuTransA =
(!TransA /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(!TransB /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T;
CUBLAS_CHECK(cublasSgemm(handle,
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
A,
lda,
&beta,
C,
N));
}
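The "fortran order" note above is the usual row-major/column-major trick: cuBLAS treats every buffer as column-major, and a row-major M x N buffer read that way is its N x M transpose, so row-major C = A * B is requested as column-major C^T = B^T * A^T. That is why the deleted helper swaps the operands, passes the dimensions as (N, M, K), and uses lda = K / ldb = N in the no-transpose path. A tiny self-contained check of that identity (plain C++, illustrative only):
#include <cassert>
#include <vector>
// Plain column-major GEMM, i.e. what cuBLAS computes: C(M x N) = A(M x K) * B(K x N).
static void gemm_colmajor(int M, int N, int K, const float* A, int lda,
                          const float* B, int ldb, float* C, int ldc) {
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m + k * lda] * B[k + n * ldb];
      C[m + n * ldc] = acc;
    }
}
int main() {
  const int M = 2, K = 3, N = 4;
  std::vector<float> A = {1, 2, 3, 4, 5, 6};                    // M x K, row-major
  std::vector<float> B = {1, 0, 2, 0, 3, 1, 0, 4, 1, 2, 0, 5};  // K x N, row-major
  std::vector<float> ref(M * N, 0.f), C(M * N, 0.f);
  for (int m = 0; m < M; ++m)                                   // row-major reference
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k) ref[m * N + n] += A[m * K + k] * B[k * N + n];
  // The swap used above: dims (N, M, K), operands (B, A), leading dims (N, K), output ld N.
  gemm_colmajor(N, M, K, B.data(), N, A.data(), K, C.data(), N);
  assert(C == ref);  // identical buffers: column-major C^T is row-major C
  return 0;
}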
template <>
void anakin_NV_gemm<char>(cublasHandle_t handle,
const bool TransA,
const bool TransB,
const int M,
const int N,
const int K,
const char alpha,
const char* A,
const char* B,
const char beta,
char* C) {
LOG(FATAL) << "int8 gemm is not implemented";
}
template <typename T>
static __global__ void add_bias(int n,
......@@ -115,6 +29,11 @@ static __global__ void add_bias(int n,
}
}
template <typename T>
void SearchFcCompute<T>::PrepareForRun() {
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
}
template <typename T>
void SearchFcCompute<T>::Run() {
auto& param = this->Param<param_t>();
......@@ -132,22 +51,10 @@ void SearchFcCompute<T>::Run() {
const T* weight = w_tensor->data<T>();
const Tensor* b_tensor = param.b;
const T* bias = b_tensor->data<T>();
cublasCreate(&_handle);
if (_M == 1 && _K > 50000) {
anakin_NV_gemv<T>(_handle, false, _N, _K, (T)1, weight, din, (T)0, dout);
} else {
anakin_NV_gemm<T>(_handle,
false,
!_flag_trans_weights,
_M,
_N,
_K,
(T)1,
din,
weight,
(T)0,
dout);
}
CHECK(gemm_impl_->init(false, true, _M, _N, _K, &ctx));
gemm_impl_->run(1.0f, 0.0f, din, weight, dout, &ctx);
int total_size = _M * _N;
add_bias<T><<<CUDA_GET_BLOCKS(total_size), CUDA_NUM_THREADS, 0, stream>>>(
total_size, _N, bias, dout);
......
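The replacement path above drops the ad-hoc cuBLAS helpers in favour of the shared lite::cuda::math::Gemm wrapper: configure once with init(), execute with run(). The old `_M == 1 && _K > 50000` gemv shortcut is gone; that shape now goes through the same init/run path. A minimal sketch of the call pattern, assuming the Paddle-Lite headers this file already includes (the transpose flags and the 1.0f/0.0f scalars are copied from the diff; the function name and error handling are illustrative):
#include <memory>
#include "lite/backends/cuda/math/gemm.h"
// dout[M x N] = din[M x K] * weight[N x K]^T, i.e. the fully-connected product
// the deleted gemv/gemm branch computed by hand.
template <typename Ctx>
bool search_fc_gemm_sketch(const float* din, const float* weight, float* dout,
                           int M, int N, int K, Ctx* ctx) {
  std::unique_ptr<paddle::lite::cuda::math::Gemm<float, float>> gemm(
      new paddle::lite::cuda::math::Gemm<float, float>);
  if (!gemm->init(false /*trans din*/, true /*trans weight*/, M, N, K, ctx)) {
    return false;  // surfaced to the caller instead of CHECK-ing, purely for the sketch
  }
  gemm->run(1.0f, 0.0f, din, weight, dout, ctx);
  return true;
}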
......@@ -14,7 +14,9 @@
#pragma once
#include <cudnn.h>
#include <memory>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
namespace paddle {
......@@ -34,16 +36,15 @@ template <typename T>
class SearchFcCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SearchFcParam;
void PrepareForRun() override;
void Run() override;
virtual ~SearchFcCompute() = default;
private:
bool _flag_trans_weights{false};
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_{nullptr};
int _M;
int _K;
int _N;
cublasHandle_t _handle;
bool _is_continue_buf{true};
};
} // namespace cuda
......
......@@ -22,43 +22,44 @@ namespace lite {
namespace kernels {
namespace cuda {
const int CUDA_NUM_THREADS = 512;
template <typename T>
inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs) {
std::vector<size_t> result;
result.resize(xs[0]->lod()[0].size());
for (size_t i = 1; i < result.size(); ++i) {
size_t sum = 0;
for (size_t j = 0; j < xs.size(); ++j) {
auto& x_lod = xs[j]->lod()[0];
sum += x_lod[i];
}
result[i] = sum;
template <typename dtype>
__global__ void concat_impl_cuda(const int nthreads,
const dtype* in_data,
const int num_concats,
const int concat_size,
const int top_concat_axis,
const int bottom_concat_axis,
const int offset_concat_axis,
dtype* out_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
const int total_concat_size = concat_size * bottom_concat_axis;
const int concat_num = index / total_concat_size;
const int concat_index = index % total_concat_size;
const int top_index =
concat_index +
(concat_num * top_concat_axis + offset_concat_axis) * concat_size;
out_data[top_index] = in_data[index];
}
LoD lod;
lod.emplace_back(result);
return lod;
}
template <typename Dtype>
__global__ void ker_sequence_concat(Dtype* out_data,
const uint64_t* in_locate_data,
const int* o2i_map,
const int* o2i_w_map,
const int seq_num,
const int emb_size,
const int count) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int tid = idx; tid < count; tid += blockDim.x * gridDim.x) {
int emb_id = tid % emb_size;
int word_id = tid / emb_size;
int input_id = o2i_map[word_id];
int cur_work_id = o2i_w_map[word_id];
const Dtype* in_data = reinterpret_cast<const Dtype*>(
reinterpret_cast<uintptr_t>(in_locate_data[input_id]));
out_data[tid] = in_data[cur_work_id * emb_size + emb_id];
template <typename dtype>
__global__ void concat_impl_2d_impl(const int inner_size,
const int num_concats,
const dtype* in_data,
const int concat_size,
const int out_concat_axis,
const int offset_concat_axis,
dtype* out_data) {
int idx_inner = threadIdx.x + blockIdx.x * blockDim.x;
int idx_outer = threadIdx.y + blockIdx.y * blockDim.y;
if (idx_inner < inner_size && idx_outer < num_concats) {
int idx_input = idx_outer * inner_size + idx_inner;
int idx_output =
(idx_outer * out_concat_axis + offset_concat_axis) * concat_size +
idx_inner;
out_data[idx_output] = in_data[idx_input];
}
}
......@@ -66,73 +67,75 @@ void SequenceConcatCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
int seq_num = param.X[0]->lod()[0].size() - 1;
const int emb_size = param.X[0]->numel() / param.X[0]->dims()[0];
std::vector<uint64_t> in_locate_vec;
for (size_t i = 0; i < param.X.size(); ++i) {
in_locate_vec.push_back(
reinterpret_cast<uintptr_t>(param.X[i]->data<float>()));
}
in_locate_tensor.Resize({static_cast<int64_t>(in_locate_vec.size())});
const int BLOCK_SIZE = 32;
const int axis = 1;
int num_concats = param.X[0]->dims().count(0, axis);
int concat_input_size =
param.X[0]->dims().count(axis + 1, param.X[0]->dims().size());
std::vector<int> out2in_map;
std::vector<int> out2in_word_map;
for (int i = 0; i < seq_num; ++i) {
for (int j = 0; j < param.X.size(); ++j) {
auto offset = param.X[j]->lod()[0];
int cur_len = offset[i + 1] - offset[i];
for (int k = 0; k < cur_len; ++k) {
out2in_map.push_back(j);
out2in_word_map.push_back(offset[i] + k);
int input_size = param.X.size();
std::vector<std::vector<int64_t>> shapes_in(input_size);
for (int i = 0; i < input_size; ++i) {
shapes_in[i] = param.X[i]->dims().Vectorize();
}
std::vector<int64_t> shape_out = shapes_in[0];
// compute output shape
for (int i = 1; i < input_size; ++i) {
for (int j = 0; j < shapes_in[i].size(); ++j) {
if (j == axis) {
continue;
} else if (shapes_in[i][j] != -1) {
CHECK_EQ(shape_out[j], shapes_in[i][j])
<< "All inputs must have the same shape, except at concat_axis.";
}
}
shape_out[axis] += shapes_in[i][axis];
}
int word_num = out2in_map.size();
out2in_map_tensor.Resize({word_num});
out2in_word_map_tensor.Resize({word_num});
int* gpu_o2i_map_data = out2in_map_tensor.mutable_data<int>(TARGET(kCUDA));
int* gpu_o2i_w_map_data =
out2in_word_map_tensor.mutable_data<int>(TARGET(kCUDA));
uint64_t* gpu_in_locate_data =
in_locate_tensor.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(gpu_o2i_map_data,
out2in_map.data(),
sizeof(int) * out2in_map.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(gpu_o2i_w_map_data,
out2in_word_map.data(),
sizeof(int) * out2in_word_map.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(gpu_in_locate_data,
in_locate_vec.data(),
sizeof(uint64_t) * in_locate_vec.size(),
IoDirection::HtoD,
stream);
param.Out->set_lod(ConcatLoD<float>(param.X));
int count = param.X[0]->numel();
for (int i = 1; i < param.X.size(); ++i) {
count += param.X[i]->numel();
param.Out->Resize(shape_out);
float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
int offset_concat_axis = 0;
const int out_concat_axis = shape_out[axis];
for (int i = 0; i < input_size; ++i) {
std::vector<int64_t> in_shape = param.X[i]->dims().Vectorize();
const auto* in_data = param.X[i]->data<float>();
const int in_concat_axis = in_shape[axis];
const int in_concat_size = in_concat_axis * concat_input_size;
const int nthreads = in_concat_size * num_concats;
float ratio = static_cast<float>(in_concat_size) / num_concats;
bool is_balance = (ratio > 0.1 && ratio < 10);
if (is_balance) {
int block_x = BLOCK_SIZE;
int block_y = BLOCK_SIZE;
int grid_x = (in_concat_size + block_x - 1) / block_x;
int grid_y = (num_concats + block_y - 1) / block_y;
dim3 block(block_x, block_y);
dim3 grid(grid_x, grid_y);
concat_impl_2d_impl<float><<<grid, block, 0, stream>>>(in_concat_size,
num_concats,
in_data,
concat_input_size,
out_concat_axis,
offset_concat_axis,
out_data);
} else {
int grid = (nthreads + BLOCK_SIZE - 1) / BLOCK_SIZE;
concat_impl_cuda<float><<<grid, BLOCK_SIZE, 0, stream>>>(
nthreads,
in_data,
num_concats,
concat_input_size,
out_concat_axis,
in_concat_axis,
offset_concat_axis,
out_data);
}
offset_concat_axis += in_concat_axis;
}
int blocks = (count + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
ker_sequence_concat<float><<<blocks, CUDA_NUM_THREADS, 0, stream>>>(
out_data,
gpu_in_locate_data,
gpu_o2i_map_data,
gpu_o2i_w_map_data,
seq_num,
emb_size,
count);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
param.Out->set_lod(param.X[0]->lod());
}
} // namespace cuda
......
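Both kernels above implement the same axis-1 concat index mapping, only with different launch shapes; the ratio test (0.1 < in_concat_size / num_concats < 10) picks the 2D grid when the two extents are comparable, so neither grid dimension degenerates. A small CPU reference of that mapping (names follow the kernels; illustrative only, not code from the tree):
#include <cstddef>
#include <vector>
void concat_axis1_cpu(const std::vector<const float*>& inputs,
                      const std::vector<int>& in_concat_axis,  // per-input size at the concat axis
                      int num_concats,                         // product of dims before the axis
                      int concat_size,                         // product of dims after the axis
                      float* out) {
  int out_concat_axis = 0;
  for (int a : in_concat_axis) out_concat_axis += a;
  int offset_concat_axis = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const int inner_size = in_concat_axis[i] * concat_size;   // rows this input contributes per slice
    for (int outer = 0; outer < num_concats; ++outer) {
      for (int inner = 0; inner < inner_size; ++inner) {
        const int idx_input = outer * inner_size + inner;
        const int idx_output =
            (outer * out_concat_axis + offset_concat_axis) * concat_size + inner;
        out[idx_output] = inputs[i][idx_input];
      }
    }
    offset_concat_axis += in_concat_axis[i];
  }
}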
......@@ -27,11 +27,6 @@ class SequenceConcatCompute
void Run() override;
virtual ~SequenceConcatCompute() = default;
private:
lite::Tensor out2in_map_tensor;
lite::Tensor out2in_word_map_tensor;
lite::Tensor in_locate_tensor;
};
} // namespace cuda
......

......@@ -21,6 +21,8 @@ namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
const int CUDA_NUM_THREADS = 512;
extern __shared__ char tile[];
template <typename dtype>
__global__ void sharemem_softmax_kernel(int total_size,
......@@ -149,6 +151,15 @@ __global__ void softmax_divid_output_kernel(int total_size,
}
}
void SoftmaxCompute::PrepareForRun() {
int device_id;
cudaGetDevice(&device_id);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, device_id);
sharedmem_size = deviceProp.sharedMemPerBlock;
max_dimsize = sharedmem_size / sizeof(float) / CUDA_NUM_THREADS;
}
void SoftmaxCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
......@@ -165,18 +176,10 @@ void SoftmaxCompute::Run() {
int total_threads = inner_num * outer_num;
int axis_size = x_dims[axis];
int device_id;
const int threads = 512;
const int threads = CUDA_NUM_THREADS;
const int blocks = (total_threads + threads - 1) / threads;
cudaGetDevice(&device_id);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, device_id);
size_t sharedmem_size = deviceProp.sharedMemPerBlock;
int max_dimsize = sharedmem_size / sizeof(float) / threads;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
TargetWrapperCuda::MemsetSync(
output_data, 0, param.output->numel() * sizeof(float));
if (axis_size <= max_dimsize) {
int use_sharemem_size = axis_size * threads * sizeof(float);
sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
......
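PrepareForRun now queries the device once and caches sharedmem_size / max_dimsize, instead of repeating the cudaGetDeviceProperties call on every Run. The arithmetic behind the cached bound, with an assumed 48 KB of per-block shared memory (the real value comes from deviceProp.sharedMemPerBlock at runtime):
#include <cstddef>
#include <cstdio>
int main() {
  const std::size_t sharedmem_size = 48 * 1024;  // deviceProp.sharedMemPerBlock, assumed here
  const int CUDA_NUM_THREADS = 512;              // matches the constant in the file
  const int max_dimsize =
      static_cast<int>(sharedmem_size / sizeof(float) / CUDA_NUM_THREADS);
  // 48 * 1024 / 4 / 512 = 24, so only softmax axes with axis_size <= 24 take
  // sharemem_softmax_kernel; longer axes fall back to the non-shared-memory path.
  std::printf("max_dimsize = %d\n", max_dimsize);
  return 0;
}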