提交 6fbeafe0 编写于 作者: C Chunwei

add high level API

上级 13b39df2
......@@ -6,6 +6,8 @@ if(LITE_WITH_CUDA)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
endif()
lite_cc_library(place_lite SRCS place.cc DEPS glog)
lite_cc_library(lite_api_test_helper SRCS lite_api_test_helper.cc
DEPS scope_lite optimizer_lite target_wrapper_host model_parser_lite program_lite
${ops_lite} ${host_kernels}
......@@ -24,7 +26,13 @@ message(STATUS "get ops ${ops_lite}")
message(STATUS "get Host kernels ${host_kernels}")
message(STATUS "get ARM kernels ${arm_kernels}")
lite_cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} ${host_kernels} program_lite)
lite_cc_library(cxx_api_lite
SRCS cxx_api.cc
DEPS ${cxx_api_lite_deps} ${ops_lite} ${host_kernels} program_lite
X86_DEPS ${x86_kernels} operator
ARM_DEPS ${arm_kernels}
CL_DEPS ${opencl_kenrels}
)
lite_cc_library(light_api_lite SRCS light_api.cc
DEPS scope_lite target_wrapper_host model_parser_lite
......@@ -32,6 +40,7 @@ lite_cc_library(light_api_lite SRCS light_api.cc
CUDA_DEPS target_wrapper_cuda
X86_DEPS ${x86_kernels} operator
ARM_DEPS ${arm_kernels}
CL_DEPS ${opencl_kenrels}
)
include(ExternalProject)
......@@ -91,6 +100,16 @@ lite_cc_test(test_apis_lite SRCS apis_test.cc
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
lite_cc_library(cxx_api_impl_lite SRCS cxx_api_impl.cc DEPS cxx_api_lite)
lite_cc_library(light_api_impl_lite SRCS light_api_impl.cc DEPS light_api_lite)
lite_cc_library(paddle_api_lite SRCS paddle_api.cc DEPS cxx_api_impl_lite light_api_impl_lite)
lite_cc_test(test_paddle_api_lite SRCS paddle_api_test.cc DEPS cxx_api_lite light_api_lite paddle_api_lite
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model SERIAL)
if (WITH_TESTING)
add_dependencies(test_paddle_api_lite test_apis_lite)
endif()
#lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
#X86_DEPS operator
#DEPS light_api_lite model_parser_lite target_wrapper_host mir_passes
......
......@@ -38,7 +38,7 @@ lite::Tensor *Predictor::GetInput(size_t offset) {
return &feed_list->at(offset);
}
const lite::Tensor *Predictor::GetOutput(size_t offset) {
const lite::Tensor *Predictor::GetOutput(size_t offset) const {
auto *_fetch_list = program_->exec_scope()->FindVar("fetch");
CHECK(_fetch_list) << "no fatch variable in exec_scope";
auto &fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
......
......@@ -17,6 +17,7 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
......@@ -53,7 +54,7 @@ class Predictor {
lite::Tensor* GetInput(size_t offset);
// Get offset-th col of fetch results.
const lite::Tensor* GetOutput(size_t offset);
const lite::Tensor* GetOutput(size_t offset) const;
const framework::proto::ProgramDesc& program_desc() const;
const lite::Tensor* GetTensor(const std::string& name) const;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/paddle_api.h"
namespace paddle {
namespace lite {
class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
public:
CxxPaddleApiImpl();
/// Create a new predictor from a config.
void Init(const lite_api::CxxConfig &config);
std::unique_ptr<lite_api::Tensor> GetInput(int i) override;
std::unique_ptr<const lite_api::Tensor> GetOutput(int i) const override;
void Run() override;
std::unique_ptr<const lite_api::Tensor> GetTensor(
const std::string &name) const override;
void SaveOptimizedModel(const std::string &model_dir) override;
private:
Predictor raw_predictor_;
};
CxxPaddleApiImpl::CxxPaddleApiImpl() {}
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
auto places = config.valid_places();
places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
raw_predictor_.Build(config.model_dir(), config.preferred_place(), places);
}
std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
auto *x = raw_predictor_.GetInput(i);
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
}
std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetOutput(
int i) const {
const auto *x = raw_predictor_.GetOutput(i);
return std::unique_ptr<lite_api::Tensor>(new lite_api::Tensor(x));
}
void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); }
std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
const std::string &name) const {
auto *x = raw_predictor_.GetTensor(name);
return std::unique_ptr<const lite_api::Tensor>(new lite_api::Tensor(x));
}
void CxxPaddleApiImpl::SaveOptimizedModel(const std::string &model_dir) {
raw_predictor_.SaveModel(model_dir);
}
} // namespace lite
namespace lite_api {
template <>
std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(
const CxxConfig &config) {
auto x = std::make_shared<lite::CxxPaddleApiImpl>();
x->Init(config);
return x;
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/light_api.h"
#include "paddle/fluid/lite/api/paddle_api.h"
namespace paddle {
namespace lite_api {
class LightPredictorImpl : public PaddlePredictor {
public:
LightPredictorImpl() = default;
std::unique_ptr<Tensor> GetInput(int i) override;
std::unique_ptr<const Tensor> GetOutput(int i) const override;
void Run() override;
std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const override;
void Init(const MobileConfig& config);
private:
std::unique_ptr<lite::LightPredictor> raw_predictor_;
};
void LightPredictorImpl::Init(const MobileConfig& config) {
raw_predictor_.reset(new lite::LightPredictor(config.model_dir()));
}
std::unique_ptr<Tensor> LightPredictorImpl::GetInput(int i) {
return std::unique_ptr<Tensor>(new Tensor(raw_predictor_->GetInput(i)));
}
std::unique_ptr<const Tensor> LightPredictorImpl::GetOutput(int i) const {
return std::unique_ptr<Tensor>(new Tensor(raw_predictor_->GetOutput(i)));
}
void LightPredictorImpl::Run() { raw_predictor_->Run(); }
std::unique_ptr<const Tensor> LightPredictorImpl::GetTensor(
const std::string& name) const {
return std::unique_ptr<const Tensor>(
new Tensor(raw_predictor_->GetTensor(name)));
}
template <>
std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(
const MobileConfig& config) {
auto x = std::make_shared<LightPredictorImpl>();
x->Init(config);
return x;
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/paddle_api.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/light_api.h"
namespace paddle {
namespace lite_api {
Tensor::Tensor(void *raw) : raw_tensor_(raw) {}
// TODO(Superjomn) refine this by using another `const void* const_raw`;
Tensor::Tensor(const void *raw) { raw_tensor_ = const_cast<void *>(raw); }
lite::Tensor *tensor(void *x) { return static_cast<lite::Tensor *>(x); }
const lite::Tensor *ctensor(void *x) {
return static_cast<const lite::Tensor *>(x);
}
void Tensor::Resize(const shape_t &shape) {
tensor(raw_tensor_)->Resize(shape);
}
template <>
const float *Tensor::data() const {
return ctensor(raw_tensor_)->data<float>();
}
template <>
const int8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int8_t>();
}
template <>
float *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<float>();
}
template <>
int8_t *Tensor::mutable_data() const {
return tensor(raw_tensor_)->mutable_data<int8_t>();
}
shape_t Tensor::shape() const {
return ctensor(raw_tensor_)->dims().Vectorize();
}
void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir) {
LOG(ERROR)
<< "The SaveOptimizedModel API is only supported by CxxConfig predictor.";
}
template <typename ConfigT>
std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT &) {
return std::shared_ptr<PaddlePredictor>();
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines PaddlePredictor, the api for lite. It supports multiple
* hardware including ARM, X86, OpenCL, CUDA and so on.
*/
#ifndef PADDLE_LITE_API_H_ // NOLINT
#define PADDLE_LITE_API_H_
#include <memory>
#include <string>
#include <vector>
#include "place.h" // NOLINT
namespace paddle {
namespace lite_api {
using shape_t = std::vector<int64_t>;
struct Tensor {
explicit Tensor(void* raw);
explicit Tensor(const void* raw);
void Resize(const shape_t& shape);
/// Readonly data.
template <typename T>
const T* data() const;
template <typename T>
T* mutable_data() const;
/// Shape of the tensor.
shape_t shape() const;
private:
void* raw_tensor_;
};
/// The PaddlePredictor defines the basic interfaces for different kinds of
/// predictors.
class PaddlePredictor {
public:
PaddlePredictor() = default;
/// Get i-th input.
virtual std::unique_ptr<Tensor> GetInput(int i) = 0;
/// Get i-th output.
virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0;
virtual void Run() = 0;
/// Get a readonly tensor, return null if no one called `name` exists.
virtual std::unique_ptr<const Tensor> GetTensor(
const std::string& name) const = 0;
/// Persist the optimized model to disk. This API is only supported by
/// CxxConfig, and the persisted model can be reused for MobileConfig.
virtual void SaveOptimizedModel(const std::string& model_dir);
virtual ~PaddlePredictor() = default;
};
/// Base class for all the configs.
class ConfigBase {
std::string model_dir_;
public:
void set_model_dir(const std::string& x) { model_dir_ = x; }
const std::string& model_dir() const { return model_dir_; }
};
/// CxxConfig is the config for the Full feature predictor.
class CxxConfig : public ConfigBase {
Place preferred_place_;
std::vector<Place> valid_places_;
public:
void set_preferred_place(const Place& x) { preferred_place_ = x; }
void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
const Place& preferred_place() const { return preferred_place_; }
const std::vector<Place>& valid_places() const { return valid_places_; }
};
/// MobileConfig is the config for the light weight predictor, it will skip
/// IR optimization or other unnecessary stages.
class MobileConfig : public ConfigBase {};
template <typename ConfigT>
std::shared_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT&);
} // namespace lite_api
} // namespace paddle
#endif // NOLINT
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/paddle_api.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/kernels/use_kernels.h"
#include "paddle/fluid/lite/operators/use_ops.h"
DEFINE_string(model_dir, "", "");
namespace paddle {
namespace lite_api {
TEST(CxxApi, run) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)});
config.set_valid_places({Place{TARGET(kX86), PRECISION(kFloat)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(std::vector<int64_t>({100, 100}));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor->Run();
auto output = predictor->GetOutput(0);
auto* out = output->data<float>();
LOG(INFO) << out[0];
LOG(INFO) << out[1];
EXPECT_NEAR(out[0], 50.2132, 1e-3);
EXPECT_NEAR(out[1], -28.8729, 1e-3);
predictor->SaveOptimizedModel(FLAGS_model_dir + ".opt2");
}
TEST(LightApi, run) {
lite_api::MobileConfig config;
config.set_model_dir(FLAGS_model_dir + ".opt2");
auto predictor = lite_api::CreatePaddlePredictor(config);
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(std::vector<int64_t>({100, 100}));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor->Run();
auto output = predictor->GetOutput(0);
auto* out = output->data<float>();
LOG(INFO) << out[0];
LOG(INFO) << out[1];
EXPECT_NEAR(out[0], 50.2132, 1e-3);
EXPECT_NEAR(out[1], -28.8729, 1e-3);
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/place.h"
#include <glog/logging.h>
#include "paddle/fluid/lite/utils/hash.h"
namespace paddle {
namespace lite_api {
size_t Place::hash() const {
std::hash<int> h;
size_t hash = h(static_cast<int>(target));
hash = lite::hash_combine(hash, static_cast<int>(precision));
hash = lite::hash_combine(hash, static_cast<int>(layout));
hash = lite::hash_combine(hash, static_cast<int>(device));
return hash;
}
bool operator<(const Place& a, const Place& b) {
if (a.target != b.target) return a.target < b.target;
if (a.precision != b.precision) return a.precision < b.precision;
if (a.layout != b.layout) return a.layout < b.layout;
if (a.device != b.device) return a.device < b.device;
return false;
}
std::string Place::DebugString() const {
std::stringstream os;
os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/"
<< DataLayoutToStr(layout);
return os.str();
}
const std::string& TargetToStr(TargetType target) {
static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"arm", "opencl", "any"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
}
const std::string& PrecisionToStr(PrecisionType precision) {
static const std::string precision2string[] = {"unk", "float", "int8_t",
"any"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
}
const std::string& DataLayoutToStr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
}
const std::string& TargetRepr(TargetType target) {
static const std::string target2string[] = {
"kUnk", "kHost", "kX86", "kCUDA", "kARM", "kOpenCL", "kAny"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
}
const std::string& PrecisionRepr(PrecisionType precision) {
static const std::string precision2string[] = {"kUnk", "kFloat", "kInt8",
"kInt32", "kAny"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
}
const std::string& DataLayoutRepr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"kUnk", "kNCHW", "kAny"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
}
} // namespace lite_api
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite_api {
enum class TargetType : int {
kUnk = 0,
kHost,
kX86,
kCUDA,
kARM,
kOpenCL,
kAny, // any target
NUM, // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
kFloat,
kInt8,
kInt32,
kAny, // any precision
NUM, // number of fields.
};
enum class DataLayoutType : int {
kUnk = 0,
kNCHW,
kAny, // any data layout
NUM, // number of fields.
};
static size_t PrecisionTypeLength(PrecisionType type) {
switch (type) {
case PrecisionType::kFloat:
return 4;
case PrecisionType::kInt8:
return 1;
case PrecisionType::kInt32:
return 4;
default:
return 4;
}
}
#define TARGET(item__) paddle::lite_api::TargetType::item__
#define PRECISION(item__) paddle::lite_api::PrecisionType::item__
#define DATALAYOUT(item__) paddle::lite_api::DataLayoutType::item__
const std::string& TargetToStr(TargetType target);
const std::string& PrecisionToStr(PrecisionType precision);
const std::string& DataLayoutToStr(DataLayoutType layout);
const std::string& TargetRepr(TargetType target);
const std::string& PrecisionRepr(PrecisionType precision);
const std::string& DataLayoutRepr(DataLayoutType layout);
/*
* Place specifies the execution context of a Kernel or input/output for a
* kernel. It is used to make the analysis of the MIR more clear and accurate.
*/
struct Place {
TargetType target{TARGET(kUnk)};
PrecisionType precision{PRECISION(kUnk)};
DataLayoutType layout{DATALAYOUT(kUnk)};
int16_t device{0}; // device ID
Place() = default;
Place(TargetType target, PrecisionType precision,
DataLayoutType layout = DATALAYOUT(kNCHW), int16_t device = 0)
: target(target), precision(precision), layout(layout), device(device) {}
bool is_valid() const {
return target != TARGET(kUnk) && precision != PRECISION(kUnk) &&
layout != DATALAYOUT(kUnk);
}
size_t hash() const;
bool operator==(const Place& other) const {
return target == other.target && precision == other.precision &&
layout == other.layout && device == other.device;
}
bool operator!=(const Place& other) const { return !(*this == other); }
friend bool operator<(const Place& a, const Place& b);
friend std::ostream& operator<<(std::ostream& os, const Place& other) {
os << other.DebugString();
return os;
}
std::string DebugString() const;
};
} // namespace lite_api
} // namespace paddle
......@@ -2,7 +2,7 @@ if (WITH_TESTING)
cc_library(lite_gtest_main SRCS lite_gtest_main.cc DEPS gtest gflags)
endif()
lite_cc_library(target_wrapper_lite SRCS target_wrapper.cc
DEPS target_wrapper_host
DEPS target_wrapper_host place_lite
X86_DEPS target_wrapper_x86
CUDA_DEPS target_wrapper_cuda)
lite_cc_library(memory_lite SRCS memory.cc DEPS target_wrapper_lite)
......
......@@ -17,31 +17,5 @@
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
size_t Place::hash() const {
std::hash<int> h;
size_t hash = h(static_cast<int>(target));
hash = hash_combine(hash, static_cast<int>(precision));
hash = hash_combine(hash, static_cast<int>(layout));
hash = hash_combine(hash, static_cast<int>(device));
return hash;
}
bool operator<(const Place &a, const Place &b) {
if (a.target != b.target) return a.target < b.target;
if (a.precision != b.precision) return a.precision < b.precision;
if (a.layout != b.layout) return a.layout < b.layout;
if (a.device != b.device) return a.device < b.device;
return true;
}
std::string Place::DebugString() const {
std::stringstream os;
os << TargetToStr(target) << "/" << PrecisionToStr(precision) << "/"
<< DataLayoutToStr(layout);
return os.str();
}
} // namespace lite
namespace lite {} // namespace lite
} // namespace paddle
......@@ -16,7 +16,9 @@
#include <iostream>
#include <sstream>
#include <string>
#include "paddle/fluid/lite/api/place.h"
#include "paddle/fluid/lite/utils/cp_logging.h"
#ifdef LITE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
......@@ -25,134 +27,17 @@
namespace paddle {
namespace lite {
enum class TargetType : int {
kUnk = 0,
kHost,
kX86,
kCUDA,
kARM,
kOpenCL,
kAny, // any target
NUM, // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
kFloat,
kInt8,
kInt32,
kAny, // any precision
NUM, // number of fields.
};
enum class DataLayoutType : int {
kUnk = 0,
kNCHW,
kAny, // any data layout
NUM, // number of fields.
};
static size_t PrecisionTypeLength(PrecisionType type) {
switch (type) {
case PrecisionType::kFloat:
return 4;
case PrecisionType::kInt8:
return 1;
case PrecisionType::kInt32:
return 4;
default:
return 4;
}
}
// Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__
// Some helper macro to get a specific PrecisionType.
#define PRECISION(item__) paddle::lite::PrecisionType::item__
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
static const std::string& TargetToStr(TargetType target) {
static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"arm", "opencl", "any"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
}
static const std::string& PrecisionToStr(PrecisionType precision) {
static const std::string precision2string[] = {"unk", "float", "int8_t",
"any"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
}
static const std::string& DataLayoutToStr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
}
static const std::string& TargetRepr(TargetType target) {
static const std::string target2string[] = {
"kUnk", "kHost", "kX86", "kCUDA", "kARM", "kOpenCL", "kAny"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
}
static const std::string& PrecisionRepr(PrecisionType precision) {
static const std::string precision2string[] = {"kUnk", "kFloat", "kInt8",
"kInt32", "kAny"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
}
static const std::string& DataLayoutRepr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"kUnk", "kNCHW", "kAny"};
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
}
/*
* Place specifies the execution context of a Kernel or input/output for a
* kernel. It is used to make the analysis of the MIR more clear and accurate.
*/
struct Place {
TargetType target{TARGET(kUnk)};
PrecisionType precision{PRECISION(kUnk)};
DataLayoutType layout{DATALAYOUT(kUnk)};
int16_t device{0}; // device ID
Place() = default;
Place(TargetType target, PrecisionType precision,
DataLayoutType layout = DATALAYOUT(kNCHW), int16_t device = 0)
: target(target), precision(precision), layout(layout), device(device) {}
bool is_valid() const {
return target != TARGET(kUnk) && precision != PRECISION(kUnk) &&
layout != DATALAYOUT(kUnk);
}
size_t hash() const;
bool operator==(const Place& other) const {
return target == other.target && precision == other.precision &&
layout == other.layout && device == other.device;
}
bool operator!=(const Place& other) const { return !(*this == other); }
friend bool operator<(const Place& a, const Place& b);
friend std::ostream& operator<<(std::ostream& os, const Place& other) {
os << other.DebugString();
return os;
}
std::string DebugString() const;
};
using lite_api::TargetType;
using lite_api::PrecisionType;
using lite_api::DataLayoutType;
using lite_api::PrecisionTypeLength;
using lite_api::TargetToStr;
using lite_api::Place;
using lite_api::PrecisionToStr;
using lite_api::DataLayoutToStr;
using lite_api::TargetRepr;
using lite_api::PrecisionRepr;
using lite_api::DataLayoutRepr;
// Memory copy directions.
enum class IoDirection {
......
......@@ -123,7 +123,7 @@ function test_arm_android {
echo "test name: ${test_name}"
adb_work_dir="/data/local/tmp"
skip_list=("test_model_parser_lite" "test_mobilenetv1_lite" "test_mobilenetv2_lite" "test_resnet50_lite" "test_inceptionv4_lite" "test_light_api_lite" "test_apis_lite")
skip_list=("test_model_parser_lite" "test_mobilenetv1_lite" "test_mobilenetv2_lite" "test_resnet50_lite" "test_inceptionv4_lite" "test_light_api_lite" "test_apis_lite" "test_paddle_api_lite")
for skip_name in ${skip_list[@]} ; do
[[ $skip_name =~ (^|[[:space:]])$test_name($|[[:space:]]) ]] && echo "skip $test_name" && return
done
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册