Commit 7126524d authored by DannyIsFunny

Merge remote-tracking branch 'origin' into test_result

@@ -48,7 +48,7 @@ The CUDA build output is located in `build_cuda/inference_lite_lib`
4. `demo` folder: C++ demos.
-If the Python option is enabled at build time, `lite_core.so` will be generated under `build_cuda/inference_lite_lib/python/lib/`.
+If the Python option is enabled at build time, `lite.so` will be generated under `build_cuda/inference_lite_lib/python/lib/`.
## Run
@@ -66,7 +66,7 @@ wget https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/kite.jpg
Step 2: Run
-**NOTE:**This example uses the Python API.
+**NOTE:** This example uses the Python API.
``` python
#-*- coding: utf-8 -*-
@@ -75,7 +75,7 @@ import sys
import numpy as np
import cv2
sys.path.append('build_cuda/inference_lite_lib/python/lib')
-from lite_core import *
+from lite import *
def read_img(im_path, resize_h, resize_w):
im = cv2.imread(im_path).astype('float32')
......
@@ -181,7 +181,7 @@ class LITE_API CxxConfig : public ConfigBase {
#endif
#ifdef LITE_WITH_CUDA
void set_multi_stream(bool multi_stream) { multi_stream_ = multi_stream; }
-  int multi_stream() const { return multi_stream_; }
+  bool multi_stream() const { return multi_stream_; }
#endif
#ifdef LITE_WITH_MLU
......
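
For context, the one-line change above makes the `multi_stream()` getter return `bool`, matching the `bool multi_stream_` member it exposes. A minimal usage sketch (the model path and place list here are illustrative assumptions, not part of this diff):

```cpp
#include "paddle_api.h"  // Paddle-Lite public C++ API header

using namespace paddle::lite_api;

int main() {
  CxxConfig config;
  config.set_model_dir("./model_dir");  // hypothetical model location
  config.set_valid_places({Place{TARGET(kCUDA), PRECISION(kFloat)}});
  config.set_multi_stream(true);         // run CUDA kernels on multiple streams
  bool enabled = config.multi_stream();  // now typed as bool, not int
  auto predictor = CreatePaddlePredictor<CxxConfig>(config);
  return enabled ? 0 : 1;
}
```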
@@ -52,6 +52,7 @@ USE_MIR_PASS(mlu_postprocess_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
+USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
......
This diff is collapsed.
@@ -40,6 +40,15 @@ void scale_compute_basic(const operators::ScaleParam& param) {
template <typename T>
void scale(const T* din, T* dout, int num, T scale, T bias);
+
+template <typename T>
+void scale_relu(const T* din, T* dout, int num, T scale, T bias);
+
+template <typename T>
+void scale_relu6(const T* din, T* dout, int num, T scale, T bias, T alpha);
+
+template <typename T>
+void scale_leaky_relu(const T* din, T* dout, int num, T scale, T bias, T alpha);
template <typename T>
void scale(const T* din,
T* dout,
......
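
The hunk above only adds the fused-kernel declarations; the NEON implementations are elided. As a plain-C++ sketch of the intended semantics (a reference, not the actual `lite::arm::math` code), each variant applies the affine transform and then the activation:

```cpp
// Reference semantics for the fused kernels; the real versions are NEON-vectorized.
template <typename T>
void scale_relu_ref(const T* din, T* dout, int num, T scale, T bias) {
  for (int i = 0; i < num; ++i) {
    T v = din[i] * scale + bias;
    dout[i] = v > T(0) ? v : T(0);  // relu
  }
}

template <typename T>
void scale_relu6_ref(const T* din, T* dout, int num, T scale, T bias, T alpha) {
  for (int i = 0; i < num; ++i) {
    T v = din[i] * scale + bias;
    v = v > T(0) ? v : T(0);
    dout[i] = v < alpha ? v : alpha;  // clip at alpha (the relu6 threshold, 6 by default)
  }
}

template <typename T>
void scale_leaky_relu_ref(const T* din, T* dout, int num, T scale, T bias, T alpha) {
  for (int i = 0; i < num; ++i) {
    T v = din[i] * scale + bias;
    dout[i] = v > T(0) ? v : v * alpha;  // slope alpha on the negative side
  }
}
```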
@@ -21,6 +21,7 @@ lite_cc_library(mir_passes
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
+fusion/scale_activation_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
......
@@ -31,6 +31,9 @@ lite_cc_library(fuse_interpolate
lite_cc_library(fuse_sequence_pool_concat
SRCS sequence_pool_concat_fuser.cc
DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_scale_activation
+  SRCS scale_activation_fuser.cc
+  DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
@@ -44,6 +47,7 @@ set(mir_fusers
fuse_transpose_softmax_transpose
fuse_interpolate
fuse_sequence_pool_concat
+fuse_scale_activation
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scale_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/scale_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ScaleActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto act_type : {"relu", "relu6", "leaky_relu"}) {
fusion::ScaleActivationFuser fuser(act_type);
fuser(graph.get());
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_scale_activation_fuse_pass,
paddle::lite::mir::ScaleActivationFusePass)
.BindTargets({TARGET(kARM)})
.BindKernel("scale");
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ScaleActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scale_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ScaleActivationFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
// create op nodes
auto* scale =
OpNode("scale", "scale")->assert_is_op("scale")->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* scale_out = VarNode("scale_out")
->assert_is_op_output("scale", "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
*x >> *scale >> *scale_out;
*scale_out >> *act >> *out;
}
void ScaleActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto scale_op = LiteOpRegistry::Global().Create("scale");
auto scale = matched.at("scale")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();
scale_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(scale_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("scale")->stmt()->op_info();
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
op_desc.SetAttr("activation_type", act_type_);
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "relu6") {
float alpha = act_op_desc.GetAttr<float>("threshold");
op_desc.SetAttr("alpha", alpha);
} else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("alpha", alpha);
}
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ScaleActivationFuser : public FuseBase {
public:
explicit ScaleActivationFuser(const std::string& act_type) {
act_type_ = act_type;
}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string act_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
@@ -86,6 +86,7 @@ class Optimizer {
"identity_scale_eliminate_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
......
-project(demo CXX C)
cmake_minimum_required(VERSION 2.8)
+project(demo CXX C)
add_definitions(-DLITE_WITH_CUDA)
set(TARGET demo)
set(CMAKE_CXX_FLAGS "-std=c++11 -O3")
-set(LITE_LIB "${PROJECT_SOURCE_DIR}/../../cxx")
-set(PROTOBUF_LIB "${PROJECT_SOURCE_DIR}/../../third_party/protobuf")
+set(LITE_ROOT "${PROJECT_SOURCE_DIR}/../../cxx")
+set(PROTOBUF_ROOT "${PROJECT_SOURCE_DIR}/../../third_party/protobuf")
-include_directories("${LITE_LIB}/include")
-link_directories("${LITE_LIB}/lib")
-link_directories("${PROTOBUF_LIB}/lib")
+include_directories("${LITE_ROOT}/include")
+link_directories("${LITE_ROOT}/lib")
+link_directories("${PROTOBUF_ROOT}/lib")
# cuda lib
link_directories("/usr/local/cuda/lib64/")
add_executable(${TARGET} ${TARGET}.cc)
-set(DEPS ${LITE_LIB}/lib/libpaddle_full_api_shared.so)
+set(DEPS ${LITE_ROOT}/lib/libpaddle_full_api_shared.so)
set(DEPS ${DEPS} protobuf-lite)
-set(DEPS ${DEPS} "-lrt -lpthread -ldl")
+set(DEPS ${DEPS} "-lrt -lpthread -ldl -lcudart")
target_link_libraries(${TARGET} ${DEPS})
@@ -31,7 +31,18 @@ void ScaleCompute<T, PType>::Run() {
if (!param.bias_after_scale) {
bias *= scale;
}
-  lite::arm::math::scale<T>(x_data, output_data, num, scale, bias);
+  T alpha = param.alpha;
+  if (param.activation_type == "") {  // no act
+    lite::arm::math::scale<T>(x_data, output_data, num, scale, bias);
+  } else if (param.activation_type == "relu") {  // do relu
+    lite::arm::math::scale_relu<T>(x_data, output_data, num, scale, bias);
+  } else if (param.activation_type == "relu6") {  // do relu6
+    lite::arm::math::scale_relu6<T>(
+        x_data, output_data, num, scale, bias, alpha);
+  } else if (param.activation_type == "leaky_relu") {  // do leaky_relu
+    lite::arm::math::scale_leaky_relu<T>(
+        x_data, output_data, num, scale, bias, alpha);
+  }
if (!param.x->lod().empty()) {
param.output->set_lod(param.x->lod());
}
......
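
A note on the `bias *= scale` folding in the hunk above: the two `bias_after_scale` conventions differ only in where the bias enters, since

```
bias_after_scale == true:   y = scale * x + bias
bias_after_scale == false:  y = scale * (x + bias) = scale * x + (scale * bias)
```

so pre-multiplying the bias once lets every activation branch share the single `scale * x + bias` kernel form, with the fused activation applied to the result.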
@@ -244,6 +244,9 @@ struct ScaleParam : ParamBase {
float scale{1.};
float bias{};
bool bias_after_scale{true};
+  std::string activation_type{""};
+  bool fuse_relu{false};
+  float alpha{6.};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
......
@@ -38,6 +38,20 @@ bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.scale = op_desc.GetAttr<float>("scale");
param_.bias = op_desc.GetAttr<float>("bias");
param_.bias_after_scale = op_desc.GetAttr<bool>("bias_after_scale");
+  if (op_desc.HasAttr("activation_type")) {
+    auto act_type = op_desc.GetAttr<std::string>("activation_type");
+    param_.activation_type = act_type;
+    if (act_type == "relu") {
+      param_.fuse_relu = true;
+    } else if (act_type == "relu6") {
+      param_.alpha = op_desc.GetAttr<float>("alpha");  // 6.f by default
+    } else if (act_type == "leaky_relu") {
+      param_.alpha = op_desc.GetAttr<float>("alpha");
+    } else {
+      CHECK(false)
+          << "The fused scale op only supports relu, relu6 and leaky_relu";
+    }
+  }
CHECK(param_.x);
CHECK(param_.output);
return true;
......
@@ -350,6 +350,7 @@ function make_cuda {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \
-DWITH_TESTING=OFF \
-DLITE_WITH_ARM=OFF \
+-DLITE_WITH_STATIC_CUDA=OFF \
-DLITE_WITH_PYTHON=${BUILD_PYTHON} \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_XPU=$BUILD_XPU \
......
@@ -13,70 +13,75 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void bilinear_interp(__read_only image2d_t input, __write_only image2d_t output,
-                              __private const float scale_h, __private const float scale_w,
-                              __private const int in_dims_h, __private const int out_dims_h,
-                              __private const int in_dims_w, __private const int out_dims_w,
-                              __private const float align_delta) {
-    const int c = get_global_id(0);
-    const int w = get_global_id(1);
-    const int nh = get_global_id(2);
+__kernel void bilinear_interp(
+    __read_only image2d_t input, __write_only image2d_t output,
+    __private const float scale_h, __private const float scale_w,
+    __private const int in_dims_h, __private const int out_dims_h,
+    __private const int in_dims_w, __private const int out_dims_w,
+    __private const float align_delta) {
+  const int c = get_global_id(0);
+  const int w = get_global_id(1);
+  const int nh = get_global_id(2);

-    int2 output_pos;
-    output_pos.x = c * out_dims_w + w;
-    output_pos.y = nh;
+  int2 output_pos;
+  output_pos.x = c * out_dims_w + w;
+  output_pos.y = nh;

-    // calculate center pixel's pos
-    int out_n = nh / out_dims_h;
-    int out_h = nh % out_dims_h;
-    float center_w = (w + align_delta) * scale_w - align_delta;
-    float center_h = (out_h + align_delta) * scale_h - align_delta;
+  // calculate center pixel's pos
+  int out_n = nh / out_dims_h;
+  int out_h = nh % out_dims_h;
+  float center_w = (w + align_delta) * scale_w - align_delta;
+  float center_h = (out_h + align_delta) * scale_h - align_delta;

-    int floor_w = (int)center_w;
-    int floor_h = (int)center_h;
-    int ceil_w = floor_w + 1;
-    int ceil_h = floor_h + 1;
+  int floor_w = (int)center_w;
+  int floor_h = (int)center_h;
+  int ceil_w = floor_w + 1;
+  int ceil_h = floor_h + 1;

-    if (ceil_w > in_dims_w) {
-      ceil_w = floor_w;
-    }
-    if (ceil_h > in_dims_h) {
-      ceil_h = floor_h;
-    }
-    float wight0_w = center_w - floor_w;
-    float wight0_h = center_h - floor_h;
-    float wight1_w = 1.0 - wight0_w;
-    float wight1_h = 1.0 - wight0_h;
+  if (ceil_w > in_dims_w) {
+    ceil_w = floor_w;
+  }
+  if (ceil_h > in_dims_h) {
+    ceil_h = floor_h;
+  }
+  float wight0_w = center_w - floor_w;
+  float wight0_h = center_h - floor_h;
+  float wight1_w = 1.0f - wight0_w;
+  float wight1_h = 1.0f - wight0_h;

-    const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

-    // get left up pixel data
-    int2 left_up;
-    left_up.x = c * in_dims_w + floor_w;
-    left_up.y = out_n * in_dims_h + ceil_h;
-    half4 left_up_data = read_imageh(input, sampler, left_up);
+  // get left up pixel data
+  int2 left_up;
+  left_up.x = c * in_dims_w + floor_w;
+  left_up.y = out_n * in_dims_h + ceil_h;
+  half4 left_up_data = read_imageh(input, sampler, left_up);

-    // get left down pixel data
-    int2 left_down;
-    left_down.x = c * in_dims_w + floor_w;
-    left_down.y = out_n * in_dims_h + floor_h;
-    half4 left_down_data = read_imageh(input, sampler, left_down);
+  // get left down pixel data
+  int2 left_down;
+  left_down.x = c * in_dims_w + floor_w;
+  left_down.y = out_n * in_dims_h + floor_h;
+  half4 left_down_data = read_imageh(input, sampler, left_down);

-    // get right up pixel data
-    int2 right_up;
-    right_up.x = c * in_dims_w + ceil_w;
-    right_up.y = out_n * in_dims_h + ceil_h;
-    half4 right_up_data = read_imageh(input, sampler, right_up);
+  // get right up pixel data
+  int2 right_up;
+  right_up.x = c * in_dims_w + ceil_w;
+  right_up.y = out_n * in_dims_h + ceil_h;
+  half4 right_up_data = read_imageh(input, sampler, right_up);

-    // get right down pixel's data
-    int2 right_down;
-    right_down.x = c * in_dims_w + ceil_w;
-    right_down.y = out_n * in_dims_h + floor_h;
-    half4 right_down_data = read_imageh(input, sampler, right_down);
+  // get right down pixel's data
+  int2 right_down;
+  right_down.x = c * in_dims_w + ceil_w;
+  right_down.y = out_n * in_dims_h + floor_h;
+  half4 right_down_data = read_imageh(input, sampler, right_down);

-    // calculate output data
-    half4 data = (left_down_data * wight1_w + right_down_data * wight0_w) * wight1_h
-                 + (left_up_data * wight1_w + right_up_data * wight0_w) * wight0_h;
+  // calculate output data
+  half4 data =
+      (left_down_data * (half)wight1_w + right_down_data * (half)wight0_w) *
+          (half)wight1_h +
+      (left_up_data * (half)wight1_w + right_up_data * (half)wight0_w) *
+          (half)wight0_h;

-    write_imageh(output, output_pos, data);
+  write_imageh(output, output_pos, data);
}
\ No newline at end of file
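
The kernel change above is a pure reformat (wrapped signature, `1.0f` literals, explicit `half` casts); the math itself is standard bilinear interpolation. A scalar C++ sketch of the same computation for one output pixel (single-channel float input assumed; the OpenCL version relies on `CLK_ADDRESS_CLAMP` for out-of-range reads, which the sketch handles by clamping the ceil indices):

```cpp
// Scalar reference mirroring the kernel's math for one output pixel.
// align_delta is typically 0.5f for align_corners == false, 0.f otherwise.
float bilinear_at(const float* in, int in_h, int in_w,
                  int out_y, int out_x,
                  float scale_h, float scale_w, float align_delta) {
  // Map the output pixel back to a (possibly fractional) input position.
  float center_w = (out_x + align_delta) * scale_w - align_delta;
  float center_h = (out_y + align_delta) * scale_h - align_delta;

  int floor_w = (int)center_w, floor_h = (int)center_h;
  int ceil_w = floor_w + 1, ceil_h = floor_h + 1;
  if (ceil_w > in_w - 1) ceil_w = floor_w;  // clamp at the border
  if (ceil_h > in_h - 1) ceil_h = floor_h;

  float w0_w = center_w - floor_w, w0_h = center_h - floor_h;  // weight toward ceil
  float w1_w = 1.0f - w0_w, w1_h = 1.0f - w0_h;                // weight toward floor

  // Blend the four neighbours, horizontally then vertically.
  float down = in[floor_h * in_w + floor_w] * w1_w + in[floor_h * in_w + ceil_w] * w0_w;
  float up   = in[ceil_h * in_w + floor_w] * w1_w + in[ceil_h * in_w + ceil_w] * w0_w;
  return down * w1_h + up * w0_h;
}
```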
@@ -30,8 +30,6 @@ bool InstanceNormKernel<GPU_CL, float>::Init(InstanceNormParam<GPU_CL> *param) {
build_options = "-DLOCAL_MEM_128";
} else if (h == 64) {
build_options = "-DLOCAL_MEM_64";
-  } else if (h > 256) {
-    PADDLE_MOBILE_THROW_EXCEPTION("instance norm unsupported input height");
}
this->cl_helper_.AddKernel("instancenorm", "instancenorm_kernel.cl",
build_options);
......
@@ -26,13 +26,11 @@ bool InstanceNormReluKernel<GPU_CL, float>::Init(
FusionInstanceNormReluParam<GPU_CL> *param) {
auto &dims = param->Out()->dims();
const int h = dims[2];
-  std::string build_options = "-DRELU";
+  std::string build_options = " -DRELU";
if (h == 128) {
build_options += " -DLOCAL_MEM_128";
} else if (h == 64) {
build_options += " -DLOCAL_MEM_64";
-  } else if (h > 256) {
-    PADDLE_MOBILE_THROW_EXCEPTION("instance norm unsupported input height");
}
this->cl_helper_.AddKernel("instancenorm", "instancenorm_kernel.cl",
build_options);
......
@@ -442,9 +442,9 @@ endif()
if (FILL_CONSTANT_OP)
add_definitions(-DFILL_CONSTANT_OP)
endif()
-if (FUSION_CONVADD_OP)
-  add_definitions(-DFUSION_CONVADD_OP)
-endif()
+# if (FUSION_CONVADD_OP)
+#   add_definitions(-DFUSION_CONVADD_OP)
+# endif()
if (FUSION_CONVADDRELU_OP)
add_definitions(-DFUSION_CONVADDRELU_OP)
endif()
......