Unverified commit ccf7892e, authored by huzhiqiang, committed by GitHub

Merge branch 'develop' into test_result

......@@ -63,6 +63,16 @@ test/models/
test/images/
*.pyc
# model
*.nb
*.svg
*.dot
# vim intermediate files
*.swp
# Emacs intermediate files
*~
......
......@@ -45,7 +45,7 @@ else()
# we changed the source code to adapt it for Windows compilation
# git diffs : (1) unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
######################################################################################################
URL https://paddlelite-data.bj.bcebos.com/third_party_libs/eigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip
URL http://paddlelite-data.bj.bcebos.com/third_party_libs/eigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip
DOWNLOAD_DIR ${EIGEN_SOURCECODE_DIR}
DOWNLOAD_NO_PROGRESS 1
PREFIX ${EIGEN_SOURCE_DIR}
......
......@@ -297,6 +297,7 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc,
// `inner_places` is used by the optimization passes
std::vector<Place> inner_places = valid_places;
for (auto &valid_place : valid_places) {
if (valid_place.target == TARGET(kOpenCL)) continue;
inner_places.emplace_back(
Place(TARGET(kHost), valid_place.precision, valid_place.layout));
}
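// Note on the hunk above: for every non-OpenCL valid place, a kHost place
// with the same precision and layout is appended so the optimization passes
// can fall back to host kernels; OpenCL places are skipped, presumably
// because a kHost kernel cannot consume the OpenCL image layout.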
......
......@@ -102,6 +102,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_->GetParamNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_.GetParamNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_->GetOutputNames();
}
......
......@@ -55,7 +55,7 @@ DEFINE_string(model_file, "", "model file path of the combined-param model");
DEFINE_string(param_file, "", "param file path of the combined-param model");
DEFINE_string(
optimize_out_type,
"protobuf",
"naive_buffer",
"store type of the output optimized model. protobuf/naive_buffer");
DEFINE_bool(display_kernels, false, "Display kernel information");
DEFINE_bool(record_tailoring_info,
......@@ -207,7 +207,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
}
std::cout << std::setiosflags(std::ios::internal);
std::cout << std::setw(maximum_optype_length) << "OP_name";
for (int i = 0; i < targets.size(); i++) {
for (size_t i = 0; i < targets.size(); i++) {
std::cout << std::setw(10) << targets[i].substr(1);
}
std::cout << std::endl;
......@@ -215,7 +215,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
for (auto it = supported_ops.begin(); it != supported_ops.end(); it++) {
std::cout << std::setw(maximum_optype_length) << it->first;
auto ops_valid_places = it->second;
for (int i = 0; i < targets.size(); i++) {
for (size_t i = 0; i < targets.size(); i++) {
if (std::find(ops_valid_places.begin(),
ops_valid_places.end(),
targets[i]) != ops_valid_places.end()) {
......@@ -235,7 +235,7 @@ void PrintOpsInfo(std::set<std::string> valid_ops = {}) {
}
// Print OP info.
auto ops_valid_places = supported_ops.at(*op);
for (int i = 0; i < targets.size(); i++) {
for (size_t i = 0; i < targets.size(); i++) {
if (std::find(ops_valid_places.begin(),
ops_valid_places.end(),
targets[i]) != ops_valid_places.end()) {
......@@ -288,11 +288,11 @@ void ParseInputCommand() {
auto valid_places = paddle::lite_api::ParserValidPlaces();
// get valid_targets string
std::vector<TargetType> target_types = {};
for (int i = 0; i < valid_places.size(); i++) {
for (size_t i = 0; i < valid_places.size(); i++) {
target_types.push_back(valid_places[i].target);
}
std::string targets_str = TargetToStr(target_types[0]);
for (int i = 1; i < target_types.size(); i++) {
for (size_t i = 1; i < target_types.size(); i++) {
targets_str = targets_str + TargetToStr(target_types[i]);
}
......@@ -301,7 +301,7 @@ void ParseInputCommand() {
target_types.push_back(TARGET(kUnk));
std::set<std::string> valid_ops;
for (int i = 0; i < target_types.size(); i++) {
for (size_t i = 0; i < target_types.size(); i++) {
auto ops = supported_ops_target[static_cast<int>(target_types[i])];
valid_ops.insert(ops.begin(), ops.end());
}
......@@ -318,7 +318,7 @@ void CheckIfModelSupported() {
auto valid_unktype_ops = supported_ops_target[static_cast<int>(TARGET(kUnk))];
valid_ops.insert(
valid_ops.end(), valid_unktype_ops.begin(), valid_unktype_ops.end());
for (int i = 0; i < valid_places.size(); i++) {
for (size_t i = 0; i < valid_places.size(); i++) {
auto target = valid_places[i].target;
auto ops = supported_ops_target[static_cast<int>(target)];
valid_ops.insert(valid_ops.end(), ops.begin(), ops.end());
......@@ -340,7 +340,7 @@ void CheckIfModelSupported() {
std::set<std::string> unsupported_ops;
std::set<std::string> input_model_ops;
for (int index = 0; index < cpp_prog.BlocksSize(); index++) {
for (size_t index = 0; index < cpp_prog.BlocksSize(); index++) {
auto current_block = cpp_prog.GetBlock<lite::cpp::BlockDesc>(index);
for (size_t i = 0; i < current_block->OpsSize(); ++i) {
auto& op_desc = *current_block->GetOp<lite::cpp::OpDesc>(i);
......@@ -364,13 +364,13 @@ void CheckIfModelSupported() {
unsupported_ops_str = unsupported_ops_str + ", " + *op_str;
}
std::vector<TargetType> targets = {};
for (int i = 0; i < valid_places.size(); i++) {
for (size_t i = 0; i < valid_places.size(); i++) {
targets.push_back(valid_places[i].target);
}
std::sort(targets.begin(), targets.end());
targets.erase(unique(targets.begin(), targets.end()), targets.end());
std::string targets_str = TargetToStr(targets[0]);
for (int i = 1; i < targets.size(); i++) {
for (size_t i = 1; i < targets.size(); i++) {
targets_str = targets_str + "," + TargetToStr(targets[i]);
}
......
......@@ -82,27 +82,56 @@ void OptBase::SetValidPlaces(const std::string& valid_places) {
"command argument 'valid_targets'";
}
void OptBase::SetOptimizeOut(const std::string& optimized_out_path) {
optimize_out_path_ = optimized_out_path;
void OptBase::SetLiteOut(const std::string& lite_out_name) {
lite_out_name_ = lite_out_name;
}
void OptBase::RunOptimize(bool record_strip_info) {
void OptBase::RecordModelInfo(bool record_strip_info) {
record_strip_info_ = record_strip_info;
}
void OptBase::Run() {
CheckIfModelSupported(false);
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
opt_config_.set_valid_places(valid_places_);
if (model_set_dir_ != "") {
RunOptimizeFromModelSet(record_strip_info);
RunOptimizeFromModelSet(record_strip_info_);
} else {
auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_);
opt_predictor->SaveOptimizedModel(
optimize_out_path_, model_type_, record_strip_info);
lite_out_name_, model_type_, record_strip_info_);
auto resulted_model_name =
record_strip_info ? "information of striped model" : "optimized model";
record_strip_info_ ? "information of striped model" : "optimized model";
std::cout << "Save the " << resulted_model_name
<< " into :" << optimize_out_path_ << "successfully";
<< " into :" << lite_out_name_ << "successfully";
}
}
void OptBase::RunOptimize(const std::string& model_dir_path,
const std::string& model_path,
const std::string& param_path,
const std::string& valid_places,
const std::string& optimized_out_path) {
SetModelDir(model_dir_path);
SetModelFile(model_path);
SetParamFile(param_path);
SetValidPlaces(valid_places);
SetLiteOut(optimized_out_path);
CheckIfModelSupported(false);
OpKernelInfoCollector::Global().SetKernel2path(kernel2path_map);
opt_config_.set_valid_places(valid_places_);
if (model_set_dir_ != "") {
RunOptimizeFromModelSet(record_strip_info_);
} else {
auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_);
opt_predictor->SaveOptimizedModel(
lite_out_name_, model_type_, record_strip_info_);
auto resulted_model_name =
record_strip_info_ ? "information of striped model" : "optimized model";
std::cout << "Save the " << resulted_model_name
<< " into :" << lite_out_name_ << "successfully";
}
}
// collect ops info of modelset
void CollectModelMetaInfo(const std::string& output_dir,
const std::vector<std::string>& models,
......@@ -125,7 +154,7 @@ void OptBase::SetModelSetDir(const std::string& model_set_path) {
}
void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
// 1. mkdir for the output optimized model set.
lite::MkDirRecur(optimize_out_path_);
lite::MkDirRecur(lite_out_name_);
auto model_dirs = lite::ListDir(model_set_dir_, true);
if (model_dirs.size() == 0) {
LOG(FATAL) << "[" << model_set_dir_ << "] does not contain any model";
......@@ -138,7 +167,7 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
std::string input_model_dir =
lite::Join<std::string>({model_set_dir_, name}, "/");
std::string output_model_dir =
lite::Join<std::string>({optimize_out_path_, name}, "/");
lite::Join<std::string>({lite_out_name_, name}, "/");
if (opt_config_.model_file() != "" && opt_config_.param_file() != "") {
auto model_file_path =
......@@ -155,7 +184,7 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
auto opt_predictor = lite_api::CreatePaddlePredictor(opt_config_);
opt_predictor->SaveOptimizedModel(
optimize_out_path_, model_type_, record_strip_info);
lite_out_name_, model_type_, record_strip_info);
std::cout << "Optimize done. ";
}
......@@ -164,46 +193,60 @@ void OptBase::RunOptimizeFromModelSet(bool record_strip_info) {
if (record_strip_info) {
// Collect all models information
CollectModelMetaInfo(
optimize_out_path_, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME);
lite_out_name_, model_dirs, lite::TAILORD_OPS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo(
lite_out_name_, model_dirs, lite::TAILORD_OPS_LIST_NAME);
CollectModelMetaInfo(
optimize_out_path_, model_dirs, lite::TAILORD_OPS_LIST_NAME);
CollectModelMetaInfo(optimize_out_path_,
model_dirs,
lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
lite_out_name_, model_dirs, lite::TAILORD_KERNELS_SOURCE_LIST_FILENAME);
CollectModelMetaInfo(
optimize_out_path_, model_dirs, lite::TAILORD_KERNELS_LIST_NAME);
lite_out_name_, model_dirs, lite::TAILORD_KERNELS_LIST_NAME);
std::cout << "Record the information of stripped models into :"
<< optimize_out_path_ << "successfully";
<< lite_out_name_ << "successfully";
}
}
void OptBase::PrintHelpInfo() {
const std::string opt_version = lite::version();
const char help_info[] =
"At least one argument should be inputed. Valid arguments are listed "
"below:\n"
"------------------------------------------------------------------------"
"-----------------------------------------------------------\n"
" Valid arguments of Paddle-Lite opt are listed below:\n"
"------------------------------------------------------------------------"
"-----------------------------------------------------------\n"
" Arguments of help information:\n"
" `help()` Print help infomation\n"
" Arguments of model optimization:\n"
"\n"
" Arguments of model transformation:\n"
" `set_model_dir(model_dir)`\n"
" `set_model_file(model_file_path)`\n"
" `set_param_file(param_file_path)`\n"
" `set_model_type(protobuf|naive_buffer)`\n"
" `set_optimize_out(output_optimize_model_dir)`\n"
" `set_model_type(protobuf|naive_buffer)`: naive_buffer by "
"default\n"
" `set_lite_out(output_optimize_model_dir)`\n"
" `set_valid_places(arm|opencl|x86|npu|xpu|rknpu|apu)`\n"
" `run_optimize(false|true)`\n"
" ` ----fasle&true refer to whether to record ops info for "
"tailoring lib, false by default`\n"
" Arguments of model checking and ops information:\n"
" `record_model_info(false|true)`: refer to whether to record ops "
"info for striping lib, false by default`\n"
" `run() : start model transformation`\n"
" eg. `opt.set_model_dir(\"./mobilenetv1\"); "
"opt.set_lite_out(\"mobilenetv1_opt\"); opt.set_valid_places(\"arm\"); "
"opt.run();`\n"
"\n"
" You can also transform model through a single input argument:\n"
" `run_optimize(model_dir, model_file_path, param_file_path, "
"model_type, valid_places, lite_out_name) `\n"
" eg. `opt.run_optimize(\"./mobilenetv1\", \"\", \"\", "
"\"naive_buffer\", \"arm\", \"mobilenetv1_opt\");`"
"\n"
" Arguments of checking model and printing ops information:\n"
" `print_all_ops()` Display all the valid operators of "
"Paddle-Lite\n"
" `print_supported_ops` Display supported operators of valid "
"places\n"
" `check_if_model_supported()` Check if the input model is "
"supported\n";
std::cout << "opt version:" << opt_version << std::endl
<< help_info << std::endl;
"supported\n"
"------------------------------------------------------------------------"
"-----------------------------------------------------------\n";
std::cout << "opt version:" << opt_version << std::endl << help_info;
}
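// A minimal usage sketch of the two-step flow described in the help text
// above. The method names come from OptBase in this change; the header path,
// function name, and model directory below are placeholders.
#include "lite/api/opt_base.h"

void TransformMobileNetV1() {
  paddle::lite_api::OptBase opt;
  opt.SetModelDir("./mobilenetv1");   // uncombined-param model directory
  opt.SetValidPlaces("arm");          // target hardware
  opt.SetLiteOut("mobilenetv1_opt");  // output name of the optimized model
  opt.RecordModelInfo(false);         // skip recording ops info for tailoring
  opt.Run();                          // run the transformation
  // Equivalent single-call form added by this change:
  // opt.RunOptimize("./mobilenetv1", "", "", "arm", "mobilenetv1_opt");
}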
// 2. Print supported info of input ops
void OptBase::PrintOpsInfo(const std::set<std::string>& valid_ops) {
......
......@@ -44,16 +44,21 @@ class LITE_API OptBase {
public:
OptBase() = default;
void SetModelSetDir(const std::string &model_set_path);
void SetModelDir(const std::string &model_path);
void SetModelDir(const std::string &model_dir_path);
void SetModelFile(const std::string &model_path);
void SetParamFile(const std::string &param_path);
void SetValidPlaces(const std::string &valid_places);
void SetOptimizeOut(const std::string &optimized_out_path);
void SetLiteOut(const std::string &lite_out_name);
void RecordModelInfo(bool record_strip_info = true);
// set optimized_model type
void SetModelType(std::string model_type);
// transform and save the optimized model
void RunOptimize(bool record_strip_info = false);
void Run();
void RunOptimize(const std::string &model_dir_path = "",
const std::string &model_path = "",
const std::string &param_path = "",
const std::string &valid_places = "",
const std::string &optimized_out_path = "");
// functions of printing info
// 1. help info
void PrintHelpInfo();
......@@ -71,12 +76,12 @@ class LITE_API OptBase {
// valid places for the optimized_model
std::vector<Place> valid_places_;
// filename of the optimized_model
std::string optimize_out_path_;
std::string lite_out_name_;
// type of the optimized_model, kNaiveBuffer default.
LiteModelType model_type_{LiteModelType::kNaiveBuffer};
// Dir path of a set of models, this should be combined with model
std::string model_set_dir_;
bool record_strip_info_{false};
void RunOptimizeFromModelSet(bool record_strip_info = false);
};
......
......@@ -213,6 +213,8 @@ class LITE_API CxxConfig : public ConfigBase {
// current thread.
void set_xpu_workspace_l3_size_per_thread(int l3_size = 0xfffc00);
// XPU only, specify the target device ID for the current thread.
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
void set_xpu_dev_per_thread(int dev_no = 0);
};
......
......@@ -19,7 +19,13 @@
#pragma once
// some platform-independent definitions
#include "lite/utils/macros.h"
#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif
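// A small sketch of how the shims above are used: MSVC provides neither
// __attribute__((unused)) nor __builtin_expect, so on _WIN32 UNUSED expands
// to nothing and __builtin_expect(EXP, C) degrades to plain (EXP). The
// function names below are hypothetical.
static int touch_op_example() UNUSED;

inline bool DimIsValid(int dim) {
  // Branch-prediction hint on GCC/Clang; a no-op under MSVC.
  if (__builtin_expect(dim <= 0, 0)) return false;
  return true;
}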
#define USE_LITE_OP(op_type__) \
extern int touch_op_##op_type__(); \
......
......@@ -33,6 +33,7 @@ USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
USE_MIR_PASS(lite_interpolate_fuse_pass);
USE_MIR_PASS(lite_sequence_pool_concat_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(identity_dropout_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
......@@ -53,3 +54,5 @@ USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
USE_MIR_PASS(__xpu__fc_fuse_pass);
......@@ -62,8 +62,10 @@ void BindLiteOpt(py::module *m) {
.def("set_model_file", &OptBase::SetModelFile)
.def("set_param_file", &OptBase::SetParamFile)
.def("set_valid_places", &OptBase::SetValidPlaces)
.def("set_optimize_out", &OptBase::SetOptimizeOut)
.def("set_lite_out", &OptBase::SetLiteOut)
.def("set_model_type", &OptBase::SetModelType)
.def("record_model_info", &OptBase::RecordModelInfo)
.def("run", &OptBase::Run)
.def("run_optimize", &OptBase::RunOptimize)
.def("help", &OptBase::PrintHelpInfo)
.def("print_supported_ops", &OptBase::PrintSupportedOps)
......
......@@ -50,7 +50,7 @@ if '${WITH_MKL}' == 'ON':
# link lite.so to paddlelite.libs
if os.name != 'nt':
COMMAND = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}\
/inference_lite_lib/python/install/lite/lite.so"
/inference_lite_lib/python/install/lite/lite.so"
if os.system(COMMAND) != 0:
raise Exception("patch third_party libs failed, command: %s" % COMMAND)
......
......@@ -21,6 +21,17 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
int AdaptStartIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
floor(static_cast<double>(ph * input_size) / output_size));
}
int AdaptEndIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
ceil(static_cast<double>((ph + 1) * input_size) / output_size));
}
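// A quick self-check of the index math above: pooling 5 input rows
// adaptively to 3 outputs should yield the windows [0,2), [1,4), [3,5),
// which cover every input element and may overlap (Paddle's adaptive_pool
// semantics).
inline bool CheckAdaptiveWindows() {
  return AdaptStartIndex(0, 5, 3) == 0 && AdaptEndIndex(0, 5, 3) == 2 &&
         AdaptStartIndex(1, 5, 3) == 1 && AdaptEndIndex(1, 5, 3) == 4 &&
         AdaptStartIndex(2, 5, 3) == 3 && AdaptEndIndex(2, 5, 3) == 5;
}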
void pooling_basic(const float* din,
float* dout,
int num,
......@@ -88,15 +99,27 @@ void pooling_basic(const float* din,
#pragma omp parallel for
for (int ind_c = 0; ind_c < chin; ++ind_c) {
for (int ind_h = 0; ind_h < hout; ++ind_h) {
int sh = ind_h * stride_h;
int eh = sh + kernel_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > hin ? hin : eh - pad_h;
int sh, eh;
if (adaptive) {
sh = AdaptStartIndex(ind_h, hin, hout);
eh = AdaptEndIndex(ind_h, hin, hout);
} else {
sh = ind_h * stride_h;
eh = sh + kernel_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > hin ? hin : eh - pad_h;
}
for (int ind_w = 0; ind_w < wout; ++ind_w) {
int sw = ind_w * stride_w;
int ew = sw + kernel_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > win ? win : ew - pad_w;
int sw, ew;
if (adaptive) {
sw = AdaptStartIndex(ind_w, win, wout);
ew = AdaptEndIndex(ind_w, win, wout);
} else {
sw = ind_w * stride_w;
ew = sw + kernel_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > win ? win : ew - pad_w;
}
float result = static_cast<float>(0);
int dst_ind = (ind_n * chout + ind_c) * size_channel_out +
ind_h * wout + ind_w;
......
......@@ -19,6 +19,7 @@ namespace lite {
#ifdef LITE_WITH_XPU
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
#endif
} // namespace lite
......
......@@ -151,14 +151,23 @@ class Context<TargetType::kXPU> {
if (_tls_raw_ctx == nullptr) {
_tls_raw_ctx = xdnn::create_context();
CHECK(_tls_raw_ctx);
int r = xdnn::set_workspace_l3_size(_tls_raw_ctx,
_workspace_l3_size_per_thread);
if (r != 0) {
LOG(WARNING) << "xdnn::set_workspace_l3_size() failed, r = " << r
<< ", _workspace_l3_size_per_thread = "
<< _workspace_l3_size_per_thread;
}
}
return _tls_raw_ctx;
}
static void SetWorkspaceL3Size(int l3_size = 0xfffc00) {
xdnn::set_workspace_l3_size(GetRawContext(), l3_size);
_workspace_l3_size_per_thread = l3_size;
}
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread
static void SetDev(int dev_no = 0) {
const char* dev_env = getenv("LITE_XPU_DEV");
if (dev_env) {
......@@ -173,6 +182,7 @@ class Context<TargetType::kXPU> {
private:
static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread;
};
#endif
......
......@@ -23,7 +23,10 @@ lite_cc_library(mir_passes
fusion/sequence_pool_concat_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace {
class Eliminator : public FuseBase {
public:
void BuildPattern() override {
// the previous op's output needs to be updated
auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block");
// TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("dropout", "X");
auto* dropout_op = OpNode("dropout", "dropout")
->assert_op_attr<int>("is_test", 1)
->assert_op_attr<std::string>(
"dropout_implementation", "upscale_in_train");
auto* out = VarNode("out")->assert_is_op_output("dropout", "Out");
auto* mask = VarNode("mask")->assert_is_op_output("dropout", "Mask");
*pre_op >> *x >> *dropout_op >> *out;
*dropout_op >> *mask;
// The dropout op will be eliminated; pre_op is kept and its output is rewired to `out`.
x->AsIntermediate(); // x is pre_op's output, need to update
dropout_op->AsIntermediate();
mask->AsIntermediate();
}
private:
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto& pre_op = matched.at("preop")->AsStmt();
auto op_info = *pre_op.op_info();
op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
matched.at("out")->AsArg().name);
pre_op.ResetOp(op_info, graph->valid_places());
IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
}
};
} // namespace
class IdentityDropoutEliminatePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
Eliminator eliminator;
eliminator(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(identity_dropout_eliminate_pass,
paddle::lite::mir::IdentityDropoutEliminatePass)
.BindTargets({TARGET(kXPU)});
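// Why the elision is sound: with dropout_implementation == "upscale_in_train",
// kept activations are already scaled by 1/(1 - p) during training, so at
// inference (is_test == 1) dropout reduces to the identity. InsertNewNode
// therefore rewires pre_op's output straight to "out" and drops the dropout
// op together with its Mask output.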
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class XPUEmbeddingWithEltwiseAddFuser : public FuseBase {
public:
explicit XPUEmbeddingWithEltwiseAddFuser(int n_embedding)
: n_embedding_(n_embedding) {}
void BuildPattern() override {
auto* ids0 =
VarNode("ids0")->assert_is_op_input("lookup_table", "Ids")->AsInput();
auto* table0 =
VarNode("table0")->assert_is_op_input("lookup_table", "W")->AsInput();
auto* embedding0 = OpNode("embedding0", "lookup_table");
auto* embedding_out0 = VarNode("embedding_out0")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* ids1 =
VarNode("ids1")->assert_is_op_input("lookup_table", "Ids")->AsInput();
auto* table1 =
VarNode("table1")->assert_is_op_input("lookup_table", "W")->AsInput();
auto* embedding1 = OpNode("embedding1", "lookup_table")->AsIntermediate();
auto* embedding_out1 = VarNode("embedding_out1")
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("elementwise_add", "Y")
->AsIntermediate();
auto* ewadd01 = OpNode("ewadd01", "elementwise_add")->AsIntermediate();
auto* ewadd01_out = VarNode("ewadd01_out")
->assert_is_op_output("elementwise_add", "Out")
->AsIntermediate();
embedding0->LinksFrom({ids0, table0});
embedding0->LinksTo({embedding_out0});
embedding1->LinksFrom({ids1, table1});
embedding1->LinksTo({embedding_out1});
ewadd01->LinksFrom({embedding_out0, embedding_out1});
ewadd01->LinksTo({ewadd01_out});
auto* last_ewadd_out = ewadd01_out;
for (int i = 2; i < n_embedding_; ++i) {
auto ids_name = paddle::lite::string_format("ids%d", i);
auto table_name = paddle::lite::string_format("table%d", i);
auto embedding_name = paddle::lite::string_format("embedding%d", i);
auto embedding_out_name =
paddle::lite::string_format("embedding_out%d", i);
auto* new_ids = VarNode(ids_name)
->assert_is_op_input("lookup_table", "Ids")
->AsInput();
auto* new_table = VarNode(table_name)
->assert_is_op_input("lookup_table", "W")
->AsInput();
auto* new_embedding =
OpNode(embedding_name, "lookup_table")->AsIntermediate();
auto* new_embedding_out = VarNode(embedding_out_name)
->assert_is_op_output("lookup_table", "Out")
->assert_is_op_input("elementwise_add", "Y")
->AsIntermediate();
new_embedding->LinksFrom({new_ids, new_table});
new_embedding->LinksTo({new_embedding_out});
auto ewadd_name = paddle::lite::string_format("ewadd%d%d", i - 1, i);
auto ewadd_out_name = ewadd_name + "_out";
auto* new_ewadd = OpNode(ewadd_name, "elementwise_add")->AsIntermediate();
auto* new_ewadd_out = VarNode(ewadd_out_name)
->assert_is_op_output("elementwise_add", "Out")
->AsIntermediate();
new_ewadd->LinksFrom({last_ewadd_out, new_embedding_out});
new_ewadd->LinksTo({new_ewadd_out});
last_ewadd_out = new_ewadd_out;
}
last_ewadd_out->AsOutput();
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
cpp::OpDesc op_desc;
op_desc.SetType("__xpu__embedding_with_eltwise_add");
std::vector<std::string> ids_names;
std::vector<std::string> table_names;
for (int i = 0; i < n_embedding_; ++i) {
auto ids_name = paddle::lite::string_format("ids%d", i);
ids_names.push_back(matched.at(ids_name)->arg()->name);
auto table_name = paddle::lite::string_format("table%d", i);
table_names.push_back(matched.at(table_name)->arg()->name);
}
op_desc.SetInput("Ids", ids_names);
op_desc.SetInput("Tables", table_names);
auto output_name = paddle::lite::string_format(
"ewadd%d%d_out", n_embedding_ - 2, n_embedding_ - 1);
op_desc.SetOutput("Output", {matched.at(output_name)->arg()->name});
op_desc.SetAttr<int>("n_embedding", n_embedding_);
auto* embedding0_op_info = matched.at("embedding0")->stmt()->op_info();
op_desc.SetAttr<int64_t>(
"padding_idx", embedding0_op_info->GetAttr<int64_t>("padding_idx"));
auto* new_stmt = matched.at("embedding0")->stmt();
auto new_op = LiteOpRegistry::Global().Create(op_desc.Type());
new_op->Attach(op_desc, new_stmt->op()->scope());
new_op->SetValidPlaces(new_stmt->op()->valid_places());
auto kernels = new_op->CreateKernels(new_op->valid_places());
new_stmt->SetOp(new_op);
new_stmt->SetKernels(std::move(kernels));
for (int i = 0; i < n_embedding_; ++i) {
auto ids_name = paddle::lite::string_format("ids%d", i);
auto table_name = paddle::lite::string_format("table%d", i);
DirectedLink(matched.at(ids_name), matched.at("embedding0"));
DirectedLink(matched.at(table_name), matched.at("embedding0"));
}
IR_OP_VAR_LINK(matched.at("embedding0"), matched.at(output_name));
}
private:
int n_embedding_;
};
} // namespace fusion
class XPUEmbeddingWithEltwiseAddFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
for (int n_embedding : {4, 3}) {
fusion::XPUEmbeddingWithEltwiseAddFuser fuser(n_embedding);
fuser(graph.get());
}
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass,
paddle::lite::mir::XPUEmbeddingWithEltwiseAddFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("lookup_table");
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "lite/backends/xpu/math.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class XPUFcFuser : public FuseBase {
public:
explicit XPUFcFuser(bool with_relu) : with_relu_(with_relu) {}
void BuildPattern() override {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b")->assert_is_persistable_var();
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
auto* Out = VarNode("Out");
// create topology.
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
if (with_relu_) {
auto* add_out = VarNode("add_out");
auto* relu = OpNode("relu", "relu");
std::vector<PMNode*> relu_inputs{add_out};
add_inputs >> *add >> *add_out;
relu_inputs >> *relu >> *Out;
add_out->AsIntermediate();
relu->AsIntermediate();
} else {
add_inputs >> *add >> *Out;
}
}
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope();
// convert W from float to int16, and transpose W
auto weight_name = matched.at("W")->arg()->name;
auto* weight_t = scope->FindMutableTensor(weight_name);
auto weight_dims = weight_t->dims();
int weight_len = weight_t->numel();
float* weight_on_host = weight_t->mutable_data<float>();
float max_f =
paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
std::unique_ptr<int16_t[]> weight_int16(new int16_t[weight_len]);
std::unique_ptr<int16_t[]> weight_trans_int16(new int16_t[weight_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16(
weight_on_host, weight_int16.get(), max_f, weight_len);
paddle::lite::xpu::math::Transpose(weight_int16.get(),
weight_trans_int16.get(),
weight_dims[0],
weight_dims[1]);
memcpy(
weight_on_host, weight_trans_int16.get(), weight_len * sizeof(int16_t));
auto op_desc = GenOpDesc(matched, max_f, true);
auto fc_op = LiteOpRegistry::Global().Create("__xpu__fc");
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("b"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched,
float w_max,
bool transpose_w) {
cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
op_desc.mutable_inputs()->clear();
op_desc.mutable_outputs()->clear();
op_desc.SetType("__xpu__fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
op_desc.SetAttr("w_max", w_max);
op_desc.SetAttr("transpose_w", transpose_w);
if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"});
}
return op_desc;
}
bool with_relu_;
};
} // namespace fusion
class XPUFcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
fusion::XPUFcFuser fuser(true /* with_relu */);
fuser(graph.get());
fusion::XPUFcFuser fuser2(false /* with_relu */);
fuser2(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(__xpu__fc_fuse_pass, paddle::lite::mir::XPUFcFusePass)
.BindTargets({TARGET(kXPU)})
.BindKernel("fc");
......@@ -16,6 +16,7 @@
#include <vector>
#include "lite/backends/xpu/math.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h" // For UpdateInputs()
#include "lite/core/mir/xpu_pattern_matcher_high_api.h"
#include "lite/operators/subgraph_op.h"
......@@ -588,8 +589,7 @@ class XPUMultiEncoderFuser {
multi_encoder_stmt->SetOp(multi_encoder_op);
multi_encoder_stmt->SetKernels(std::move(kernels));
// temp remove useless cast
std::unordered_set<const Node*> to_remove2;
// remove dangling/useless cast
Node* stack = nullptr;
for (auto* node : graph->StmtTopologicalOrder()) {
CHECK(node->IsStmt());
......@@ -597,16 +597,39 @@ class XPUMultiEncoderFuser {
stack = node;
}
}
Node* stack_out = stack->outlinks.front();
for (Node* cast : stack_out->outlinks) {
Node* cast_out = cast->outlinks.front();
if (cast_out->outlinks.size() == 0) {
// remove
to_remove2.insert(cast_out);
to_remove2.insert(cast);
if (stack) {
std::unordered_set<const Node*> to_remove2;
Node* stack_out = stack->outlinks.front();
// avoid modification while traversing
auto stack_out_outlinks = stack_out->outlinks;
for (Node* cast : stack_out_outlinks) {
if (cast->stmt()->op_info()->Type() != "cast") {
continue;
}
Node* cast_out = cast->outlinks.front();
if (cast_out->outlinks.size() == 0) {
// dangling cast
to_remove2.insert(cast);
to_remove2.insert(cast_out);
VLOG(3) << "Remove dangling cast [" << cast_out->arg()->name << "]";
} else if (cast_out->outlinks.size() == 1) {
// useless cast
to_remove2.insert(cast);
to_remove2.insert(cast_out);
VLOG(3) << "Remove useless cast [" << cast_out->arg()->name << "]";
auto* multi_encoder = cast_out->outlinks.front();
DirectedLink(stack_out, multi_encoder);
UpdateInputs(multi_encoder->stmt()->op().get(),
cast_out->arg()->name,
stack_out->arg()->name);
auto update_op_info = *multi_encoder->stmt()->op_info();
multi_encoder->stmt()->ResetOp(update_op_info, graph->valid_places());
}
}
GraphSafeRemoveNodes(graph, to_remove2);
}
GraphSafeRemoveNodes(graph, to_remove2);
}
};
......
......@@ -92,6 +92,10 @@ class Optimizer {
#endif
"__xpu__resnet_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
"identity_dropout_eliminate_pass", // should be placed after
// xpu fusion
"quantized_op_attributes_inference_pass", // Only for fully
// quantized model, infer
// the output scale and
......
......@@ -53,6 +53,24 @@ static bool write_tensorfile(const Tensor* tensor, const std::string& locate) {
return true;
}
static bool write_precision_summary_tofile(const std::string& string,
const std::string& log_dir = "") {
if (log_dir == "") {
LOG(INFO) << "The `log_dir` of precision summary file is not set. log_dir:"
<< log_dir;
return false;
}
FILE* fp = fopen(log_dir.c_str(), "a");
if (fp == nullptr) {
LOG(INFO) << "Open precision summary file:" << log_dir << "failed.";
return false;
} else {
fprintf(fp, "%s\n", string.c_str());
}
fclose(fp);
return true;
}
class PrecisionProfiler {
public:
// TODO(ysh329): need to remove `explicit PrecisionProfiler`
......@@ -68,7 +86,7 @@ class PrecisionProfiler {
using std::left;
using std::fixed;
STL::stringstream ss;
ss << "========================================= "
ss << "\n\n========================================= "
<< "Detailed Precision Profiler Summary "
<< "=========================================" << std::endl;
ss << setw(45) << left << "operator:(kernel_info)"
......@@ -78,6 +96,13 @@ class PrecisionProfiler {
<< " " << setw(15) << left << "std_deviation"
<< " " << setw(15) << left << "ave_grow_rate*" << std::endl;
// write to file with path: `log_dir`
if (log_dir_ != "") {
FILE* fp = fopen(log_dir_.c_str(), "a");
std::string header_str{ss.str()};
fprintf(fp, "%s\n", header_str.c_str());
fclose(fp);
}
return ss.str();
}
......@@ -195,6 +220,7 @@ class PrecisionProfiler {
}
#ifdef LITE_WITH_OPENCL
} else if (target_type == TARGET(kOpenCL)) {
CLRuntime::Global()->command_queue().finish();
switch (layout_type) {
case DATALAYOUT(kImageDefault): {
paddle::lite::CLImageConverterDefault default_convertor;
......@@ -361,8 +387,12 @@ class PrecisionProfiler {
}
}
}
write_precision_summary_tofile(ss.str(), log_dir_);
return ss.str();
}
private:
std::string log_dir_{"/storage/emulated/0/precision.log"};
};
} // namespace profile
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/host/compare_compute.h"
#include <math.h>
#include <vector>
namespace paddle {
......
......@@ -111,7 +111,8 @@ lite_cc_test(test_box_coder_image_opencl SRCS box_coder_image_compute_test.cc
#add_kernel(pool_opencl OPENCL basic SRCS pool_buffer_compute.cc DEPS ${cl_kernel_deps})
#add_kernel(concat_opencl OPENCL basic SRCS concat_buffer_compute.cc DEPS ${cl_kernel_deps})
add_kernel(fc_opencl OPENCL basic SRCS fc_buffer_compute.cc DEPS ${cl_kernel_deps})
add_kernel(mul_opencl OPENCL basic SRCS mul_buffer_compute.cc DEPS ${cl_kernel_deps})
# NOTE(ysh329): fc is used in place of `mul`; the mul kernel is not used.
#add_kernel(mul_opencl OPENCL basic SRCS mul_buffer_compute.cc DEPS ${cl_kernel_deps})
#add_kernel(elementwise_add_opencl OPENCL basic SRCS elementwise_add_buffer_compute.cc DEPS ${cl_kernel_deps})
#add_kernel(fusion_elementwise_add_activation_opencl
# OPENCL basic SRCS fusion_elementwise_add_activation_buffer_compute.cc
......@@ -147,8 +148,8 @@ add_kernel(io_copy_opencl OPENCL basic SRCS io_copy_buffer_compute.cc DEPS ${ten
lite_cc_test(test_fc_buffer_opencl SRCS fc_buffer_compute_test.cc
DEPS fc_opencl op_registry program context)
lite_cc_test(test_mul_buffer_opencl SRCS mul_buffer_compute_test.cc
DEPS mul_opencl op_registry program context)
#lite_cc_test(test_mul_buffer_opencl SRCS mul_buffer_compute_test.cc
# DEPS mul_opencl op_registry program context)
#lite_cc_test(test_elementwise_add_buffer_opencl SRCS elementwise_add__buffer_compute_test.cc
# DEPS elementwise_add_opencl op_registry program context)
......
......@@ -176,7 +176,6 @@ TEST(bilinear_interp_image2d, compute) {
input_v.data(), x_image_data.data(), in_dim);
auto* x_image = x.mutable_data<half_t, cl::Image2D>(
x_image_shape[0], x_image_shape[1], x_image_data.data());
// LOG(INFO) << "x_image:" << x_image;
DDim out_image_shape =
default_converter->InitImageDimInfoWith(out_dim);
......@@ -184,9 +183,8 @@ TEST(bilinear_interp_image2d, compute) {
<< out_image_shape[1];
auto* out_image = out.mutable_data<half_t, cl::Image2D>(
out_image_shape[0], out_image_shape[1]);
// LOG(INFO) << "out_image:" << out_image;
kernel->Launch();
kernel->Launch();
CLRuntime::Global()->command_queue().finish();
std::unique_ptr<float[]> out_ref(
......
......@@ -41,9 +41,8 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
boxcoder_param_->box_normalized == true) {
kernel_func_name_ = "decode_center_size";
} else {
printf("This code_type %s doesn't support \n",
boxcoder_param_->code_type.c_str());
return;
LOG(FATAL) << "This code_type " << boxcoder_param_->code_type
<< " doesn't support";
}
CHECK(context.cl_context() != nullptr);
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
......
......@@ -400,16 +400,28 @@ void ConvImageCompute::PrepareForRun() {
VLOG(1) << "kernel_func_names_[0]:" << kernel_func_names_[0]
<< " kernel_func_paths_[0]:" << kernel_func_paths_[0];
// build options
std::string build_options_single(" -DCL_DTYPE_half");
// relu options
if (relu_fused) {
build_options_single += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_single += " -DRELU6";
} else {
// do nothing, may add more activation fuse
VLOG(3) << "relu_fused:" << relu_fused
<< " param.activation_param.active_type:"
<< static_cast<int>(param.activation_param.active_type)
<< " param.activation_param.has_active:"
<< param.activation_param.has_active;
if (param.activation_param.has_active) {
if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu) { // Note: judging via `relu_fused`
// would also work
build_options_single += " -DRELU";
} else if (param.activation_param.active_type ==
lite_api::ActivationType::kRelu6) {
build_options_single += " -DRELU6";
} else {
LOG(FATAL) << "Unsupported activation type:"
<< static_cast<int>(param.activation_param.active_type);
}
}
// bias options
const bool has_bias = param.bias != nullptr;
const bool is_element_wise_bias =
......@@ -648,7 +660,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
int filter_height = filter_dims[2];
int filter_channel = filter_dims[1];
auto out_image_shape = InitImageDimInfoWith(output_dims);
auto* out_image = param.output->mutable_data<uint16_t, cl::Image2D>(
auto* out_image = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
const bool has_bias = param.bias != nullptr;
......@@ -724,7 +736,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_->data<uint16_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto& context = ctx_->As<OpenCLContext>();
......
......@@ -197,15 +197,23 @@ TEST(conv2d, compute_image2d_1x1) {
if (bias_flag) {
param.bias = &bias;
}
if (relu_flag == "relu") {
param.fuse_relu = true;
param.fuse_relu = true; // relu only
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu;
} else if (relu_flag == "None") {
param.fuse_relu = false;
param.activation_param.has_active = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
} else {
param.fuse_relu = false; // relu only
param.activation_param.has_active = false;
}
std::vector<int> paddings = {pad, pad, pad, pad};
......@@ -337,7 +345,7 @@ TEST(conv2d, compute_image2d_1x1) {
SHADOW_LOG << "(" << i << ")" << Half2Float(x_image_v[i]);
}
// auto* filter_image2d =
// filter.mutable_data<uint16_t, cl::Image2D>(
// filter.mutable_data<half_t, cl::Image2D>(
// filter_image_width,
// filter_image_height,
// filter_image_v.data());
......@@ -563,15 +571,23 @@ const int stride = 2;
if (bias_flag) {
param.bias = &bias;
}
if (relu_flag == "relu") {
param.fuse_relu = true;
param.fuse_relu = true; // relu only
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu;
} else if (relu_flag == "None") {
param.fuse_relu = false;
param.activation_param.has_active = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
} else {
param.fuse_relu = false; // relu only
param.activation_param.has_active = false;
}
std::vector<int> paddings = {pad, pad, pad, pad};
......@@ -912,14 +928,21 @@ TEST(conv2d, compute_image2d_5x5) {
param.bias = &bias;
}
if (relu_flag == "relu") {
param.fuse_relu = true;
param.fuse_relu = true; // relu only
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu;
} else if (relu_flag == "None") {
param.fuse_relu = false;
param.activation_param.has_active = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
} else {
param.fuse_relu = false; // relu only
param.activation_param.has_active = false;
}
std::vector<int> paddings = {pad, pad, pad, pad};
......@@ -1244,16 +1267,25 @@ TEST(conv2d, compute_image2d_7x7) {
if (bias_flag) {
param.bias = &bias;
}
if (relu_flag == "relu") {
param.fuse_relu = true;
param.fuse_relu = true; // relu only
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu;
} else if (relu_flag == "None") {
param.fuse_relu = false;
param.activation_param.has_active = false;
} else if (relu_flag == "relu6") {
param.activation_param.Relu_clipped_coef = 6.f;
param.activation_param.has_active = true;
param.activation_param.active_type =
lite_api::ActivationType::kRelu6;
} else {
param.fuse_relu = false; // relu only
param.activation_param.has_active = false;
}
std::vector<int> paddings = {pad, pad, pad, pad};
std::vector<int> dilations = {dilation, dilation};
......
......@@ -162,15 +162,27 @@ TEST(fc, compute) {
// run opencl kernel
kernel->Launch();
CLRuntime::Global()->command_queue().finish();
#if 0 // NOTE(ysh329): note event
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<float, cl::Buffer>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
CLRuntime::Global()->command_queue().finish();
#if 0
double start_nanos =
event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
double stop_nanos =
event.getProfilingInfo<CL_PROFILING_COMMAND_END>();
double elapsed_micros = (stop_nanos - start_nanos) / 1000.0;
LOG(INFO) << "Kernel Run Cost Time: " << elapsed_micros << " us.";
} else {
LOG(FATAL)
<< "Could not find the sync event for the target cl tensor.";
}
#endif
std::vector<float> out_data_from_gpu(out_dim.production());
......@@ -201,18 +213,17 @@ TEST(fc, compute) {
out_data_from_gpu.data()[eidx]);
auto relative_diff = COMPUTE_RELATIVE_DIFF(
out_ref_data[eidx], out_data_from_gpu.data()[eidx]);
// EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) ||
// (abs_diff <= FP16_MAX_DIFF),
// true);
EXPECT_EQ(
(relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
true);
if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
LOG(ERROR) << "error idx:" << eidx << ", out_ref_data[" << eidx
LOG(FATAL) << "error idx:" << eidx << ", out_ref_data[" << eidx
<< "]:" << out_ref_data[eidx]
<< ", out_data_from_gpu.data()[" << eidx
<< "]:" << out_data_from_gpu.data()[eidx]
<< " abs_diff:" << abs_diff
<< " relative_diff:" << relative_diff
<< " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
return;
}
}
......
......@@ -118,8 +118,11 @@ class LayoutComputeBufferChwToImageDefault
status = kernel.setArg(++arg_idx, static_cast<const int>(Stride2));
CL_CHECK_FATAL(status);
#ifndef LITE_SHUTDOWN_LOG
VLOG(2) << "gws:[3D]" << ((new_dims[1] + 3) / 4) << " " << new_dims[3]
<< " " << (new_dims[0] * new_dims[2]);
#endif
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>((new_dims[1] + 3) / 4),
static_cast<cl::size_type>(new_dims[3]),
......
......@@ -84,7 +84,8 @@ TEST(slice_image2d_fp16, compute) {
}
LOG(INFO) << "prepare input";
CLImageConverterDefault* default_converter = new CLImageConverterDefault();
std::unique_ptr<CLImageConverterDefault> default_converter(
new CLImageConverterDefault());
DDim image_shape = default_converter->InitImageDimInfoWith(in_dim);
LOG(INFO) << "image_shape = " << image_shape[0] << " " << image_shape[1];
std::vector<half_t> x_image_data(image_shape.production() * 4); // 4 : RGBA
......
......@@ -24,4 +24,6 @@ else()
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__multi_encoder_compute_xpu XPU extra SRCS __xpu__multi_encoder_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__embedding_with_eltwise_add_compute_xpu XPU extra SRCS __xpu__embedding_with_eltwise_add_compute.cc DEPS ${lite_kernel_deps})
add_kernel(__xpu__fc_compute_xpu XPU extra SRCS __xpu__fc_compute.cc DEPS ${lite_kernel_deps})
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUEmbeddingWithEltwiseAddCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
arg_ids_.reserve(param.Ids.size());
arg_tables_.reserve(param.Tables.size());
for (auto* table : param.Tables) {
auto& table_dims = table->dims();
CHECK_EQ(table_dims.size(), 2); /* shape like [table_len, embed_dim] */
table_lens_cpu_.push_back(table_dims[0]);
}
void* lens_ptr = nullptr;
size_t lens_size = table_lens_cpu_.size() * sizeof(int);
xpu_malloc(&lens_ptr, lens_size);
xpu_memcpy(lens_ptr, &table_lens_cpu_[0], lens_size, XPU_HOST_TO_DEVICE);
table_lens_guard_.reset(lens_ptr);
}
void XPUEmbeddingWithEltwiseAddCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
for (size_t i = 0; i < param.Ids.size(); ++i) {
arg_ids_[i] = param.Ids[i]->data<int64_t>();
}
for (size_t i = 0; i < param.Tables.size(); ++i) {
arg_tables_[i] = param.Tables[i]->data<float>();
}
auto& id_dims = param.Ids[0]->dims();
auto& table_dims = param.Tables[0]->dims();
int idx_len = id_dims[0] * id_dims[1];
int embed_dim = table_dims[1];
int emb_layer_num = param.Ids.size();
int r = xdnn::embedding_with_ewadd<float, int64_t, false, false>(
ctx.GetRawContext(), /* context */
embed_dim, /* embed_dim */
idx_len, /* idx_len */
emb_layer_num, /* emb_layer_num */
param.padding_idx, /* padding_idx */
&arg_tables_[0], /* tables */
&arg_ids_[0], /* indices */
static_cast<int*>(table_lens_guard_.get()), /* table_lens */
nullptr, /* scale_after_emb */
nullptr, /* scale_after_ewadd */
param.Out->mutable_data<float>(TARGET(kXPU)) /* top */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
__xpu__embedding_with_eltwise_add,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUEmbeddingWithEltwiseAddCompute,
def)
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("Tables", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
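// A plain-CPU reference of what xdnn::embedding_with_ewadd computes in the
// kernel above (a sketch of the math, not the XPU implementation;
// padding_idx handling omitted):
#include <cstdint>
#include <vector>

inline void EmbeddingWithEwaddRef(const std::vector<const int64_t*>& ids,
                                  const std::vector<const float*>& tables,
                                  int idx_len, int embed_dim, float* out) {
  for (int i = 0; i < idx_len; ++i) {
    for (int d = 0; d < embed_dim; ++d) {
      float sum = 0.f;
      for (size_t e = 0; e < tables.size(); ++e) {
        sum += tables[e][ids[e][i] * embed_dim + d];  // row ids[e][i], col d
      }
      out[i * embed_dim + d] = sum;
    }
  }
}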
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUEmbeddingWithEltwiseAddCompute
: public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUEmbeddingWithEltwiseAddParam;
void PrepareForRun() override;
void Run() override;
private:
std::vector<const int64_t*> arg_ids_;
std::vector<const float*> arg_tables_;
std::unique_ptr<void, XPUFreeDeleter> table_lens_guard_;
std::vector<int> table_lens_cpu_;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/__xpu__fc_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
void XPUFcCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
auto input_dims = param.input->dims();
param.in_mat_dims = input_dims.Flatten2D(param.in_num_col_dims);
int m = param.in_mat_dims[0];
int k = param.in_mat_dims[1];
int n = param.w->dims()[1];
const float* bias = param.bias ? param.bias->data<float>() : nullptr;
xdnn::Activation_t act_type = (param.activation_type == "relu")
? xdnn::Activation_t::RELU
: xdnn::Activation_t::LINEAR;
int r = xdnn::fc_int16(
ctx.GetRawContext(), /* context */
false, /* TransA */
param.transpose_w, /* TransB */
m, /* m */
n, /* n */
k, /* k */
1.0f, /* alpha */
param.input->data<float>(), /* A */
reinterpret_cast<const int16_t*>(param.w->data<float>()), /* B */
param.w_max, /* max_b */
0.0f, /* beta */
param.output->mutable_data<float>(TARGET(kXPU)), /* C */
bias, /* bias */
act_type /* act_type */);
CHECK_EQ(r, 0);
}
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(__xpu__fc,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::XPUFcCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
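// A plain-CPU reference of the fc offloaded above, assuming row-major
// A[m,k], pre-transposed int16 weights B[n,k] (transpose_w == true),
// optional bias[n] and fused ReLU; the max_b factor mirrors the int16
// weight packing (a sketch of the math, not xdnn::fc_int16 itself):
#include <algorithm>
#include <cstdint>

inline void FcInt16Ref(const float* A, const int16_t* B, float max_b,
                       const float* bias, float* C, int m, int n, int k,
                       bool relu) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        // dequantize B on the fly: int16 -> float via max_b / 32767
        acc += A[i * k + p] * (B[j * k + p] * max_b / 32767.f);
      }
      if (bias != nullptr) acc += bias[j];
      C[i * n + j] = relu ? std::max(acc, 0.f) : acc;
    }
  }
}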
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
class XPUFcCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::XPUFcParam;
virtual void Run();
virtual ~XPUFcCompute() = default;
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/xpu/stack_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
namespace paddle {
......
......@@ -16,18 +16,14 @@
#include <memory>
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/kernel.h"
#include "lite/kernels/xpu/utils.h" // XPUFreeDeleter
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
struct XPUFreeDeleter {
void operator()(void* p) const { xpu_free(p); }
};
class StackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::StackParam;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/xpu/xpu_header_sitter.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
struct XPUFreeDeleter {
void operator()(void* p) const { xpu_free(p); }
};
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
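// ---- editor's usage sketch (not part of the diff) ---------------------------
// XPUFreeDeleter pairs with std::unique_ptr so device memory obtained from
// xpu_malloc is released automatically; the xpu_malloc signature assumed below
// is inferred from the xpu_free call above, not taken from the XPU headers.
#include <cstddef>
#include <memory>
inline std::unique_ptr<void, XPUFreeDeleter> MakeXpuBuffer(size_t size) {
  void* raw = nullptr;
  if (xpu_malloc(&raw, size) != 0) {  // assumed: int xpu_malloc(void**, size_t)
    raw = nullptr;
  }
  return std::unique_ptr<void, XPUFreeDeleter>(raw);  // xpu_free on destruction
}
// -----------------------------------------------------------------------------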
......@@ -154,6 +154,8 @@ add_operator(sgd_op train SRCS sgd_op.cc DEPS ${op_DEPS})
# Only for XPU
add_operator(__xpu__resnet50_op extra SRCS __xpu__resnet50_op.cc DEPS ${op_DEPS})
add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc DEPS ${op_DEPS})
add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc DEPS ${op_DEPS})
add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
......
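# ---- editor's note (not part of the diff) -----------------------------------
# A further XPU fusion op would follow the same registration pattern; the op
# name below is hypothetical:
#   add_operator(__xpu__my_fusion_op extra SRCS __xpu__my_fusion_op.cc DEPS ${op_DEPS})
# ------------------------------------------------------------------------------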
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__embedding_with_eltwise_add_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUEmbeddingWithEltwiseAddOp::CheckShape() const {
CHECK_OR_FALSE(param_.Ids.size() == param_.Tables.size());
auto& id_dims = param_.Ids[0]->dims();
auto& table_dims = param_.Tables[0]->dims();
int id_rank = id_dims.size();
CHECK_EQ_OR_FALSE(table_dims.size(), 2);
CHECK_EQ_OR_FALSE(id_dims[id_rank - 1], 1);
return true;
}
bool XPUEmbeddingWithEltwiseAddOp::InferShapeImpl() const {
auto& id_dims = param_.Ids[0]->dims();
auto& table_dims = param_.Tables[0]->dims();
auto out_dims = id_dims;
int id_rank = id_dims.size();
out_dims[id_rank - 1] = table_dims[1];
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.Ids[0]->lod());
return true;
}
bool XPUEmbeddingWithEltwiseAddOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
param_.Out = scope->FindVar(op_desc.Output("Output").front())
->GetMutable<lite::Tensor>();
param_.Ids.clear();
for (auto& name : op_desc.Input("Ids")) {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.Ids.push_back(t);
}
param_.Tables.clear();
for (auto& name : op_desc.Input("Tables")) {
auto t =
const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
param_.Tables.push_back(t);
}
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__embedding_with_eltwise_add,
paddle::lite::operators::XPUEmbeddingWithEltwiseAddOp);
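// ---- editor's shape walk-through (not part of the diff) ---------------------
// Example for the checks above: with Ids[0] dims [64, 1] and Tables[0] dims
// [30000, 768], CheckShape passes (table rank 2, trailing id dim 1) and
// InferShapeImpl emits Out dims [64, 768] -- the trailing id dimension is
// replaced by the embedding width table_dims[1].
// ------------------------------------------------------------------------------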
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUEmbeddingWithEltwiseAddOp : public OpLite {
public:
XPUEmbeddingWithEltwiseAddOp() {}
explicit XPUEmbeddingWithEltwiseAddOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "EmbeddingWithEltwiseAdd"; }
private:
mutable XPUEmbeddingWithEltwiseAddParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/__xpu__fc_op.h"
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool XPUFcOp::CheckShape() const {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.w);
// bias is optional.
const auto input_dims = param_.input->dims();
const auto w_dims = param_.w->dims();
CHECK_EQ_OR_FALSE(w_dims.size(), 2UL);
int64_t w_dims_1 = w_dims[1];
if (param_.bias) {
const auto bias_dims = param_.bias->dims();
if (bias_dims.size() == 2) {
CHECK_EQ_OR_FALSE(bias_dims[0], 1);
CHECK_EQ_OR_FALSE(bias_dims[1], w_dims_1);
} else if (bias_dims.size() == 1) {
CHECK_EQ_OR_FALSE(bias_dims[0], w_dims_1);
}
}
CHECK_GT_OR_FALSE(input_dims.size(),
static_cast<size_t>(param_.in_num_col_dims));
param_.in_mat_dims = input_dims.Flatten2D(param_.in_num_col_dims);
CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
return true;
}
bool XPUFcOp::InferShapeImpl() const {
const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims;
int64_t w_dims_1 = w_dims[1];
// Set output dims
std::vector<DDim::value_type> output_dims(in_num_col_dims + 1);
for (int i = 0; i < in_num_col_dims; ++i) {
output_dims[i] = input_dims[i];
}
output_dims[in_num_col_dims] = w_dims_1;
param_.output->Resize(output_dims);
// share LoD
param_.output->set_lod(param_.input->lod());
return true;
}
bool XPUFcOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front();
auto out = op_desc.Output("Out").front();
param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.w = scope->FindVar(W)->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() > 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
param_.bias = bias_var->GetMutable<lite::Tensor>();
}
}
}
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
param_.w_max = op_desc.GetAttr<float>("w_max");
if (op_desc.HasAttr("activation_type")) {
param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
}
if (op_desc.HasAttr("transpose_w")) {
param_.transpose_w = op_desc.GetAttr<bool>("transpose_w");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(__xpu__fc, paddle::lite::operators::XPUFcOp);
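// ---- editor's shape walk-through (not part of the diff) ---------------------
// Example for CheckShape/InferShapeImpl above: input dims [2, 3, 4] with
// in_num_col_dims = 2 and W dims [4, 5] give in_mat_dims = Flatten2D(...) =
// [6, 4], whose second entry must equal w_dims[0] (= 4); the output is then
// resized to [2, 3, 5], i.e. the leading in_num_col_dims input dims followed
// by w_dims[1].
// ------------------------------------------------------------------------------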
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class XPUFcOp : public OpLite {
public:
XPUFcOp() {}
explicit XPUFcOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "XPUFc"; }
private:
mutable XPUFcParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -340,7 +340,7 @@ struct ConcatParam : ParamBase {
struct ActivationParam : ParamBase {
const lite::Tensor* X{};
lite::Tensor* Out{};
lite_api::ActivationType active_type;
lite_api::ActivationType active_type{lite_api::ActivationType::kIndentity};
bool has_active{false};
float Leaky_relu_alpha{0}; // leaky_relu param
float Relu_clipped_coef{6}; // relu_clipped param
......@@ -1491,6 +1491,26 @@ struct XPUMultiEncoderParam : ParamBase {
std::string act_type{};
};
struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
std::vector<lite::Tensor*> Ids;
std::vector<lite::Tensor*> Tables;
lite::Tensor* Out{};
int64_t padding_idx{-1};
};
struct XPUFcParam : ParamBase {
lite::Tensor* input{nullptr};
lite::Tensor* w{nullptr};
lite::Tensor* bias{nullptr};
lite::Tensor* output{nullptr};
int in_num_col_dims{1};
lite::DDim in_mat_dims;
float w_max{0.0f};
bool transpose_w{true};
std::string activation_type{""};
};
} // namespace operators
} // namespace lite
} // namespace paddle
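// ---- editor's note (not part of the diff) -----------------------------------
// Cross-reference: XPUFcParam::transpose_w (default true) is forwarded as the
// TransB flag of xdnn::fc_int16 in the XPU kernel above, and w_max carries the
// int16 quantization scale of w, read from the "w_max" op attribute in
// XPUFcOp::AttachImpl.
// ------------------------------------------------------------------------------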
......@@ -32,6 +32,7 @@ APU_DDK_ROOT="$(pwd)/apu_sdk_lib/"
BUILD_RKNPU=OFF
RKNPU_DDK_ROOT="$(pwd)/rknpu/"
LITE_WITH_ARM_LANG=OFF
PYTHON_EXECUTABLE_OPTION=""
readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
......@@ -388,7 +389,8 @@ function make_x86 {
-DLITE_WITH_XPU=$BUILD_XPU \
-DLITE_WITH_XTCL=$BUILD_XTCL \
-DXPU_SDK_ROOT=$XPU_SDK_ROOT \
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_BUILD_TYPE=Release \
$PYTHON_EXECUTABLE_OPTION
make publish_inference -j$NUM_PROC
cd -
......@@ -482,7 +484,7 @@ function main {
--build_dir=*)
BUILD_DIR="${i#*=}"
shift
;;
;;
--opt_model_dir=*)
OPTMODEL_DIR="${i#*=}"
shift
......@@ -515,6 +517,10 @@ function main {
XPU_SDK_ROOT="${i#*=}"
shift
;;
--python_executable=*)
PYTHON_EXECUTABLE_OPTION="-DPYTHON_EXECUTABLE=${i#*=}"
shift
;;
--build_apu=*)
BUILD_APU="${i#*=}"
shift
......
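# ---- editor's note (not part of the diff) -----------------------------------
# Example use of the new --python_executable flag, assuming this hunk belongs
# to lite/tools/build.sh and that an x86 build entry invokes make_x86; the
# interpreter path is illustrative:
#   ./lite/tools/build.sh --python_executable=/usr/bin/python3 x86
# ------------------------------------------------------------------------------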
#!/bin/bash
set +x
#####################################################################################################
# 1. global variables, you can change them according to your requirements
#####################################################################################################
# armv7 or armv8, default armv8.
ARM_ABI=armv8
# c++_static or c++_shared, default c++_static.
ANDROID_STL=c++_static
# gcc or clang, default gcc.
TOOLCHAIN=gcc
# ON or OFF, default OFF.
WITH_EXTRA=OFF
# ON or OFF, default ON.
WITH_JAVA=ON
# controls whether to compile cv functions into lib, default is OFF.
WITH_CV=OFF
# controls whether to hide log information, default is ON.
SHUTDOWN_LOG=ON
# options of stripping lib according to input model.
OPTMODEL_DIR=""
WITH_STRIP=OFF
# options of compiling NPU lib.
WITH_HUAWEI_KIRIN_NPU=OFF
HUAWEI_KIRIN_NPU_SDK_ROOT="$(pwd)/ai_ddk_lib/" # Download HiAI DDK from https://developer.huawei.com/consumer/cn/hiai/
# options of compiling OPENCL lib.
WITH_OPENCL=OFF
# options of adding training ops
WITH_TRAIN=OFF
# number of threads used during compiling.
readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
#####################################################################################################
#####################################################################################################
# 2. local variables, these variables should not be changed.
#####################################################################################################
# url that stores third-party zip file to accelerate third-party lib installation
readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
# absolute path of Paddle-Lite.
readonly workspace=$PWD/$(dirname $0)/../../
# basic options for android compiling.
readonly CMAKE_COMMON_OPTIONS="-DWITH_LITE=ON \
-DLITE_WITH_ARM=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DLITE_WITH_X86=OFF \
-DWITH_TESTING=OFF \
-DARM_TARGET_OS=android"
# on mac environment, we should expand the maximum file num to compile successfully
os_name=`uname -s`
if [ ${os_name} == "Darwin" ]; then
ulimit -n 1024
fi
#####################################################################################################
####################################################################################################
# 3. functions of prepare workspace before compiling
####################################################################################################
# 3.1 generate `__generated_code__.cc`, which is depended on by some targets in cmake.
# here we fake an empty file to make cmake work.
function prepare_workspace {
local root_dir=$1
local build_dir=$2
# 1. Prepare gen_code file
GEN_CODE_PATH_PREFIX=$build_dir/lite/gen_code
mkdir -p ${GEN_CODE_PATH_PREFIX}
touch ${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
# 2.Prepare debug tool
DEBUG_TOOL_PATH_PREFIX=$build_dir/lite/tools/debug
mkdir -p ${DEBUG_TOOL_PATH_PREFIX}
cp $root_dir/lite/tools/debug/analysis_tool.py ${DEBUG_TOOL_PATH_PREFIX}/
}
# 3.2 prepare source code of opencl lib
# here we pack all .cl files into one .cc file so that all opencl kernels end up in a single lib
function prepare_opencl_source_code {
local root_dir=$1
local build_dir=$2
# in build directory
# Prepare opencl_kernels_source.cc file
GEN_CODE_PATH_OPENCL=$root_dir/lite/backends/opencl
rm -f $GEN_CODE_PATH_OPENCL/opencl_kernels_source.cc
OPENCL_KERNELS_PATH=$root_dir/lite/backends/opencl/cl_kernel
mkdir -p ${GEN_CODE_PATH_OPENCL}
touch $GEN_CODE_PATH_OPENCL/opencl_kernels_source.cc
python $root_dir/lite/tools/cmake_tools/gen_opencl_code.py $OPENCL_KERNELS_PATH $GEN_CODE_PATH_OPENCL/opencl_kernels_source.cc
}
# 3.3 prepare third_party libraries for compiling
# here we store third_party libraries into Paddle-Lite/third-party
function prepare_thirdparty {
if [ ! -d $workspace/third-party -o -f $workspace/third-party-05b862.tar.gz ]; then
rm -rf $workspace/third-party
if [ ! -f $workspace/third-party-05b862.tar.gz ]; then
wget $THIRDPARTY_TAR
fi
tar xzf third-party-05b862.tar.gz
else
git submodule update --init --recursive
fi
}
####################################################################################################
####################################################################################################
# 4. compiling functions
####################################################################################################
# 4.1 function of tiny_publish compiling
# here we only compile light_api lib
function make_tiny_publish_so {
build_dir=$workspace/build.lite.android.$ARM_ABI.$TOOLCHAIN
if [ "${WITH_OPENCL}" == "ON" ]; then
build_dir=${build_dir}.opencl
fi
if [ "${WITH_npu}" == "ON" ]; then
build_dir=${build_dir}.npu
fi
if [ -d $build_dir ]
then
rm -rf $build_dir
fi
mkdir -p $build_dir
cd $build_dir
if [ "${WITH_OPENCL}" == "ON" ]; then
prepare_opencl_source_code $workspace $build_dir
fi
local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_SHUTDOWN_LOG=$SHUTDOWN_LOG \
-DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DLITE_WITH_JAVA=$WITH_JAVA \
-DLITE_WITH_CV=$WITH_CV \
-DLITE_WITH_NPU=$WITH_HUAWEI_KIRIN_NPU \
-DNPU_DDK_ROOT=$HUAWEI_KIRIN_NPU_SDK_ROOT \
-DLITE_WITH_OPENCL=$WITH_OPENCL \
-DARM_TARGET_ARCH_ABI=$ARM_ABI \
-DARM_TARGET_LANG=$TOOLCHAIN \
-DANDROID_STL_TYPE=$ANDROID_STL"
cmake $workspace \
${CMAKE_COMMON_OPTIONS} \
${cmake_mutable_options} \
-DLITE_ON_TINY_PUBLISH=ON
# todo: third_party of opencl should be moved into git submodule and cmake later
if [ "${WITH_OPENCL}" == "ON" ]; then
make opencl_clhpp -j$NUM_PROC
fi
make publish_inference -j$NUM_PROC
cd - > /dev/null
}
# 4.2 function of full_publish compiling
# here we compile both light_api lib and full_api lib
function make_full_publish_so {
prepare_thirdparty
build_directory=$workspace/build.lite.android.$ARM_ABI.$TOOLCHAIN
if [ -d $build_directory ]
then
rm -rf $build_directory
fi
mkdir -p $build_directory
cd $build_directory
prepare_workspace $workspace $build_directory
if [ "${WITH_OPENCL}" == "ON" ]; then
prepare_opencl_source_code $workspace $build_directory
fi
local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_SHUTDOWN_LOG=$SHUTDOWN_LOG \
-DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DLITE_WITH_JAVA=$WITH_JAVA \
-DLITE_WITH_CV=$WITH_CV \
-DLITE_WITH_NPU=$WITH_HUAWEI_KIRIN_NPU \
-DNPU_DDK_ROOT=$HUAWEI_KIRIN_NPU_SDK_ROOT \
-DLITE_WITH_OPENCL=$WITH_OPENCL \
-DARM_TARGET_ARCH_ABI=$ARM_ABI \
-DARM_TARGET_LANG=$TOOLCHAIN \
-DLITE_WITH_TRAIN=$WITH_TRAIN \
-DANDROID_STL_TYPE=$ANDROID_STL"
cmake $workspace \
${CMAKE_COMMON_OPTIONS} \
${cmake_mutable_options}
# todo: third_party of opencl should be moved into git submodule and cmake later
if [ "${WITH_OPENCL}" == "ON" ]; then
make opencl_clhpp -j$NUM_PROC
fi
make publish_inference -j$NUM_PROC
cd - > /dev/null
}
# 4.3 function of print help information
function print_usage {
echo "----------------------------------------------------------------------------------------------------------------------------------------"
echo -e "| Methods of compiling Padddle-Lite Android library: |"
echo "----------------------------------------------------------------------------------------------------------------------------------------"
echo -e "| compile android library: (armv8, gcc, c++_static) |"
echo -e "| ./lite/tools/build_android.sh |"
echo -e "| print help information: |"
echo -e "| ./lite/tools/build_android.sh help |"
echo -e "| |"
echo -e "| optional argument: |"
echo -e "| --arm_abi: (armv8|armv7), default is armv8 |"
echo -e "| --toolchain: (gcc|clang), defalut is gcc |"
echo -e "| --android_stl: (c++_static|c++_shared|gnu_static|gnu_shared), default is c++_static |"
echo -e "| --with_java: (OFF|ON); controls whether to publish java api lib, default is ON |"
echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |"
echo -e "| --shutdown_log: (OFF|ON); controls whether to hide log information, default is ON |"
echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |"
echo -e "| |"
echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |"
echo -e "| ./lite/tools/build_android.sh --with_strip=ON --opt_model_dir=YourOptimizedModelDir |"
echo -e "| --with_strip: (OFF|ON); controls whether to strip lib accrding to input model, default is OFF |"
echo -e "| --opt_model_dir: (absolute path to optimized model dir) required when compiling striped library |"
echo -e "| detailed information about striping lib: https://paddle-lite.readthedocs.io/zh/latest/user_guides/library_tailoring.html |"
echo -e "| |"
echo -e "| arguments of npu library compiling:(armv8, gcc, c++_static) |"
echo -e "| ./lite/tools/build_android.sh --with_huawei_kirin_npu=ON --huawei_kirin_npu_sdk_root=YourNpuSdkPath |"
echo -e "| --with_huawei_kirin_npu: (OFF|ON); controls whether to compile lib for huawei_kirin_npu, default is OFF |"
echo -e "| --huawei_kirin_npu_sdk_root: (path to huawei HiAi DDK file) required when compiling npu library |"
echo -e "| you can download huawei HiAi DDK from: https://developer.huawei.com/consumer/cn/hiai/ |"
echo -e "| detailed information about Paddle-Lite NPU: https://paddle-lite.readthedocs.io/zh/latest/demo_guides/npu.html |"
echo -e "| |"
echo -e "| arguments of opencl library compiling:(armv8, gcc, c++_static) |"
echo -e "| ./lite/tools/build_android.sh --with_opencl=ON |"
echo -e "| --with_opencl: (OFF|ON); controls whether to compile lib for opencl, default is OFF |"
echo "----------------------------------------------------------------------------------------------------------------------------------------"
echo
}
####################################################################################################
####################################################################################################
# 5. main functions: choose compiling method according to input argument
####################################################################################################
function main {
if [ -z "$1" ]; then
# compiling result contains light_api lib only, recommanded.
make_tiny_publish_so $ARM_ABI $TOOLCHAIN $ANDROID_STL
fi
# Parse command line.
for i in "$@"; do
case $i in
# armv7 or armv8, default armv8
--arm_abi=*)
ARM_ABI="${i#*=}"
shift
;;
# gcc or clang, default gcc
--toolchain=*)
TOOLCHAIN="${i#*=}"
shift
;;
# c++_static or c++_shared, default c++_static
--android_stl=*)
ANDROID_STL="${i#*=}"
shift
;;
# ON or OFF, default OFF
--with_extra=*)
WITH_EXTRA="${i#*=}"
shift
;;
# ON or OFF, default OFF
--with_cv=*)
WITH_CV="${i#*=}"
shift
;;
# ON or OFF, default ON
--with_java=*)
WITH_JAVA="${i#*=}"
shift
;;
# ON or OFF, default OFF
--with_strip=*)
WITH_STRIP="${i#*=}"
shift
;;
# string, absolute path to optimized model dir
--opt_model_dir=*)
OPTMODEL_DIR="${i#*=}"
shift
;;
# ON or OFF, default ON
--shutdown_log=*)
SHUTDOWN_LOG="${i#*=}"
shift
;;
# compiling lib which can operate on opencl and cpu.
--with_opencl=*)
WITH_OPENCL="${i#*=}"
shift
;;
# compiling lib which can operate on huawei npu.
--with_huawei_kirin_npu=*)
WITH_HUAWEI_KIRIN_NPU="${i#*=}"
shift
;;
--huawei_kirin_npu_sdk_root=*)
HUAWEI_KIRIN_NPU_SDK_ROOT="${i#*=}"
shift
;;
# compiling result contains both light_api and cxx_api lib.
full_publish)
make_full_publish_so
exit 0
;;
# compiling lib with training ops.
--with_train=*)
WITH_TRAIN="${i#*=}"
shift
;;
help)
# print help info
print_usage
exit 0
;;
*)
# unknown option
echo "Error: unsupported argument \"${i#*=}\""
print_usage
exit 1
;;
esac
done
# compiling result contains light_api lib only, recommended.
make_tiny_publish_so
}
main $@
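# ---- editor's note (not part of the diff) -----------------------------------
# Example invocations, using only flags documented in print_usage above:
#   ./lite/tools/build_android.sh --arm_abi=armv7 --toolchain=clang
#   ./lite/tools/build_android.sh --with_strip=ON --opt_model_dir=/path/to/opt_model
#   ./lite/tools/build_android.sh full_publish
# ------------------------------------------------------------------------------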
#!/bin/bash
set +x
#####################################################################################################
# 1. global variables, you can change them according to your requirements
#####################################################################################################
# armv7 or armv8, default armv8.
ARM_ABI=armv8
# ON or OFF, default OFF.
WITH_EXTRA=OFF
# controls whether to compile cv functions into lib, default is OFF.
WITH_CV=OFF
# controls whether to hide log information, default is ON.
SHUTDOWN_LOG=ON
# absolute path of Paddle-Lite.
workspace=$PWD/$(dirname $0)/../../
# options of stripping lib according to input model.
OPTMODEL_DIR=""
WITH_STRIP=OFF
# number of threads used during compiling.
readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
#####################################################################################################
#####################################################################################################
# 2. local variables, these variables should not be changed.
#####################################################################################################
# on mac environment, we should expand the maximum file num to compile successfully
os_name=`uname -s`
if [ ${os_name} == "Darwin" ]; then
ulimit -n 1024
fi
#####################################################################################################
####################################################################################################
# 3. compiling functions
####################################################################################################
function make_ios {
local abi=$1
if [ ${abi} == "armv8" ]; then
local os=ios64
elif [ ${abi} == "armv7" ]; then
local os=ios
else
echo -e "Error: unsupported arm_abi: ${abi} \t --arm_abi: armv8|armv7"
exit 1
fi
build_dir=$workspace/build.ios.${os}.${abi}
if [ -d $build_dir ]
then
rm -rf $build_dir
fi
echo "building ios target into $build_dir"
echo "target abi: $abi"
mkdir -p ${build_dir}
cd ${build_dir}
GEN_CODE_PATH_PREFIX=lite/gen_code
mkdir -p ./${GEN_CODE_PATH_PREFIX}
touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
cmake $workspace \
-DWITH_LITE=ON \
-DLITE_WITH_ARM=ON \
-DLITE_ON_TINY_PUBLISH=ON \
-DLITE_WITH_OPENMP=OFF \
-DWITH_ARM_DOTPROD=OFF \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DLITE_SHUTDOWN_LOG=$SHUTDOWN_LOG \
-DLITE_BUILD_TAILOR=$WITH_STRIP \
-DLITE_OPTMODEL_DIR=$OPTMODEL_DIR \
-DARM_TARGET_ARCH_ABI=$abi \
-DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_WITH_CV=$WITH_CV \
-DARM_TARGET_OS=$os
make publish_inference -j$NUM_PROC
cd -
}
function print_usage {
echo "----------------------------------------------------------------------------------------------------------------------------------------"
echo -e "| Methods of compiling Padddle-Lite iOS library: |"
echo "----------------------------------------------------------------------------------------------------------------------------------------"
echo -e "| compile iOS armv8 library: |"
echo -e "| ./lite/tools/build_ios.sh |"
echo -e "| compile iOS armv7 library: |"
echo -e "| ./lite/tools/build_ios.sh --arm_abi=armv7 |"
echo -e "| print help information: |"
echo -e "| ./lite/tools/build_ios.sh help |"
echo -e "| |"
echo -e "| optional argument: |"
echo -e "| --arm_abi: (armv8|armv7), default is armv8 |"
echo -e "| --with_cv: (OFF|ON); controls whether to compile cv functions into lib, default is OFF |"
echo -e "| --shutdown_log: (OFF|ON); controls whether to hide log information, default is ON |"
echo -e "| --with_extra: (OFF|ON); controls whether to publish extra operators and kernels for (sequence-related model such as OCR or NLP) |"
echo -e "| |"
echo -e "| arguments of striping lib according to input model:(armv8, gcc, c++_static) |"
echo -e "| ./lite/tools/build_android.sh --with_strip=ON --opt_model_dir=YourOptimizedModelDir |"
echo -e "| --with_strip: (OFF|ON); controls whether to strip lib accrding to input model, default is OFF |"
echo -e "| --opt_model_dir: (absolute path to optimized model dir) required when compiling striped library |"
echo -e "| detailed information about striping lib: https://paddle-lite.readthedocs.io/zh/latest/user_guides/library_tailoring.html |"
echo "----------------------------------------------------------------------------------------------------------------------------------------"
}
function main {
if [ -z "$1" ]; then
make_ios $ARM_ABI
exit -1
fi
# Parse command line.
for i in "$@"; do
case $i in
--arm_abi=*)
ARM_ABI="${i#*=}"
shift
;;
--with_extra=*)
WITH_EXTRA="${i#*=}"
shift
;;
--with_cv=*)
WITH_CV="${i#*=}"
shift
;;
--opt_model_dir=*)
OPTMODEL_DIR="${i#*=}"
shift
;;
--with_strip=*)
WITH_STRIP="${i#*=}"
shift
;;
--shutdown_log=*)
SHUTDOWN_LOG="${i#*=}"
shift
;;
help)
print_usage
exit 0
;;
*)
# unknown option
print_usage
exit 1
;;
esac
done
make_ios $ARM_ABI
}
main $@
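# ---- editor's note (not part of the diff) -----------------------------------
# Example invocations:
#   ./lite/tools/build_ios.sh                                 # default armv8 build
#   ./lite/tools/build_ios.sh --arm_abi=armv7 --with_extra=ON
# ------------------------------------------------------------------------------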
......@@ -18,6 +18,44 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void softmax_basic_axis_float(const float *din, float *dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
#pragma omp parallel for
for (int i = 0; i < compute_size; ++i) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
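// ---- editor's index walk-through (not part of the diff) ---------------------
// For input shape [2, 3, 4] and axis = 1: outer_num = 2, inner_num = 4,
// axis_size = 3. Grid element i (0 <= i < 8) starts at
// real_index = (i / 4) * 3 * 4 + (i % 4) and steps by inner_num = 4 across
// the softmax axis, so each (outer, inner) pair is normalized independently.
// ------------------------------------------------------------------------------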
template <typename P>
void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
const Tensor *in_x = param.InputX();
......@@ -25,7 +63,29 @@ void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
auto x_dims = in_x->dims();
out->Resize(x_dims);
out->mutable_data<float>();
math::SoftmaxFuntor<CPU, float>()(in_x, out);
if (param.has_axis_) {
int axis = param.axis_;
int axis_size = x_dims[axis];
auto x_rank = x_dims.size();
DLOG << "x_rank :" << x_rank;
if (axis < 0) {
axis += x_rank;
}
DLOG << "axis :" << axis;
int outer_num = framework::product(framework::slice_ddim(x_dims, 0, axis));
DLOG << "outer_num :" << outer_num;
int inner_num =
framework::product(framework::slice_ddim(x_dims, axis + 1, x_rank));
DLOG << "inner_num :" << inner_num;
softmax_basic_axis_float(in_x->data<float>(), out->data<float>(), axis_size,
inner_num, outer_num);
} else {
math::SoftmaxFuntor<CPU, float>()(in_x, out);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -1180,10 +1180,17 @@ class SoftmaxParam : public OpParam {
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
if (HasAttr("axis", attrs)) {
axis_ = GetAttr<int>("axis", attrs);
has_axis_ = true;
}
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
int axis_ = -1;
bool has_axis_ = false;
private:
GType *input_x_;
GType *out_;
......