Commit 18959c09 authored by superjomn

make op pointer shared_ptr to support kernel infershape
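The motivation, in short: a graph node and the generated runtime instruction must both keep the op alive so that InferShape can run right before each kernel launch, which unique ownership cannot express. A minimal sketch of that ownership shape (toy types, not the real Lite classes):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Simplified stand-in; OpLite here is a toy, not the real Lite class.
struct OpLite {
  std::string type;
  explicit OpLite(std::string t) : type(std::move(t)) {}
  void InferShape() const { std::cout << type << ": InferShape\n"; }
};

// Both the mir graph node and the runtime instruction hold the same op,
// so ownership must be shared rather than unique.
struct Instruction {
  std::shared_ptr<OpLite> op;
  void Run() const {
    op->InferShape();  // shape inference runs right before the kernel
  }
};

int main() {
  auto op = std::make_shared<OpLite>("mul");
  Instruction graph_node{op};    // held by the SSA graph
  Instruction runtime_inst{op};  // held by the runtime program
  runtime_inst.Run();            // op is still alive: shared, not moved
  (void)graph_node;
}
```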

Parent 367a2814
@@ -25,7 +25,7 @@ struct Config {};
 class Predictor {
  public:
   void Build(const std::string& model_path,
-             const std::vector<OpLite::Place>& valid_places) {
+             const std::vector<Place>& valid_places) {
     CHECK(!executor_.get()) << "duplicate build found";
     framework::proto::ProgramDesc prog;
     LoadModel(model_path, &scope_, &prog);
...
@@ -46,6 +46,8 @@ class Node {
     // The kernel instances this Instruct contains.
     std::vector<std::unique_ptr<KernelBase>> valid_kernels;
     std::shared_ptr<OpInfo> op_info;
+    // TODO(Superjomn) make this a shared_ptr for resource safety.
+    std::shared_ptr<OpLite> op;  // we hold op to run InferShape
   };

   struct Argument {
@@ -64,9 +66,11 @@ class Node {
   Instruct& AsInstruct(const std::string& op_type,
                        std::vector<std::unique_ptr<KernelBase>>&& kernels,
+                       const std::shared_ptr<OpLite>& op,
                        const std::shared_ptr<lite::OpInfo>& op_info) {
     auto& x = AsInstruct();
     x.op_type = op_type;
+    x.op = op;
     x.valid_kernels = std::move(kernels);
     x.op_info = op_info;
     return x;
...
@@ -34,13 +34,7 @@ namespace mir {
 struct Program {
   std::list<std::string> tmp_vars;
   std::list<std::string> weights;
   std::list<std::unique_ptr<OpLite>> ops;
   lite::Scope *scope{};
 };

-// Program of kernel.
-struct KernelProgram {
-  std::list<std::unique_ptr<KernelBase>> instructions;
-  std::list<std::shared_ptr<OpLite>> ops;
-  lite::Scope *scope{};
-};
@@ -67,7 +61,7 @@ class SSAGraph : GraphBase {
       op->SetValidPlaces(valid_places);
       auto &new_node = node_storage_.back();
       node_storage_.back().AsInstruct(
-          op->op_type_, op->CreateKernels(valid_places), op->op_info());
+          op->op_type_, op->CreateKernels(valid_places), op, op->op_info());
       CHECK(new_node.inlinks.empty()) << "duplicate Build found";
       CHECK(new_node.outlinks.empty()) << "duplicate Build found";
@@ -122,7 +116,7 @@ class SSAGraph : GraphBase {
   const std::list<mir::Node> &nodes() const { return node_storage_; }
   std::list<mir::Node> &mutable_nodes() { return node_storage_; }

-  mir::Node *RetriveArgument(const std::string &arg) {
+  mir::Node *RetrieveArgument(const std::string &arg) {
     auto it = arguments_.find(arg);
     if (it != arguments_.end()) {
       return it->second;
...
@@ -63,7 +63,7 @@ class VariablePlaceInferencePass : public DebugPass {
       // kernel's declaration.
       for (auto& arg_name : arg_names) {
-        auto* node = graph->RetriveArgument(arg_name);
+        auto* node = graph->RetrieveArgument(arg_name);
         CHECK(node) << "argument " << arg_name << " not exists in the graph";
         auto& arg_node = node->AsArgument();
         if (arg_node.place.is_valid()) continue;
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
       // kernel's declaration.
      for (auto& arg_name : arg_names) {
-        auto* node = graph->RetriveArgument(arg_name);
+        auto* node = graph->RetrieveArgument(arg_name);
         CHECK(node) << "argument " << arg_name << " not exists in the graph";
         auto& arg_node = node->AsArgument();
         if (arg_node.place.is_valid()) continue;
...
@@ -27,7 +27,7 @@ namespace lite {
 using KernelFunc = std::function<void()>;
 using KernelFuncCreator = std::function<std::unique_ptr<KernelFunc>()>;

-class LiteOpRegistry final : public Factory<OpLite> {
+class LiteOpRegistry final : public Factory<OpLite, std::shared_ptr<OpLite>> {
  public:
   static LiteOpRegistry &Global() {
     static auto *x = new LiteOpRegistry;
@@ -51,7 +51,9 @@ class OpLiteRegistor : public Registor<OpClass> {
 };

 template <TargetType Target, PrecisionType Precision>
-using KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>;
+using KernelRegistryForTarget =
+    Factory<OpKernel<Target, Precision>,
+            std::unique_ptr<OpKernel<Target, Precision>>>;

 class KernelRegistry final {
  public:
...
+set(lite_kernel_deps type_system kernel_lite op_registry_lite)
 add_subdirectory(host)
 add_subdirectory(arm)
 add_subdirectory(cuda)

-cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
-cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
-cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite)
-cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite)
-cc_library(feed_compute_host SRCS feed_compute.cc DEPS tensor_lite)
+cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps})
+cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
+cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps})
+cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps})
+cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
 cc_library(host_kernels DEPS
     fc_compute_host
@@ -10,7 +10,7 @@ cc_library(host_kernels DEPS
     mul_compute_host
     scale_compute_host
     feed_compute_host
-    DEPS kernel_lite
+    DEPS ${lite_kernel_deps}
 )

 cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
@@ -68,4 +68,10 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 REGISTER_LITE_KERNEL(mul, kHost, kFloat,
                      paddle::lite::kernels::host::MulCompute)
+    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                        TARGET(kHost))})
+    .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                        TARGET(kHost))})
+    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                           TARGET(kHost))})
     .Finalize();
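For orientation, the BindInput/BindOutput/Finalize chain above is a fluent builder: each call records the declared type of one kernel port and returns the builder itself. A toy model of the pattern (hypothetical KernelRegistrar, string-typed ports standing in for the real Type::Get descriptors):

```cpp
#include <iostream>
#include <map>
#include <string>

// Toy model of the fluent registration: each Bind* records the declared
// type of a kernel port; Finalize() commits the record.
class KernelRegistrar {
 public:
  KernelRegistrar& BindInput(const std::string& port, const std::string& ty) {
    inputs_[port] = ty;
    return *this;  // returning *this is what enables the chained calls
  }
  KernelRegistrar& BindOutput(const std::string& port, const std::string& ty) {
    outputs_[port] = ty;
    return *this;
  }
  void Finalize() {
    for (auto& [port, ty] : inputs_) std::cout << "in  " << port << ": " << ty << "\n";
    for (auto& [port, ty] : outputs_) std::cout << "out " << port << ": " << ty << "\n";
  }

 private:
  std::map<std::string, std::string> inputs_, outputs_;
};

int main() {
  KernelRegistrar{}
      .BindInput("X", "TensorFp32NCHW(kHost)")
      .BindInput("Y", "TensorFp32NCHW(kHost)")
      .BindOutput("Out", "TensorFp32NCHW(kHost)")
      .Finalize();
}
```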
@@ -33,12 +33,12 @@ namespace lite {
 * // Retrive a creator.
 * auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
 */
-template <typename ItemType>
+template <typename ItemType, typename ItemTypePtr>
 class Factory {
  public:
   using item_t = ItemType;
-  using self_t = Factory<item_t>;
-  using item_ptr_t = std::unique_ptr<item_t>;
+  using self_t = Factory<item_t, ItemTypePtr>;
+  using item_ptr_t = ItemTypePtr;
   using creator_t = std::function<item_ptr_t()>;

   static Factory& Global() {
...
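The effect of the new second template parameter: each registry now chooses the smart-pointer type its Create hands out, so the op registry can produce shared_ptr<OpLite> while kernel registries keep unique_ptr. A reduced sketch under those assumptions (simplified Factory, illustrative MulOp; not the exact Lite code):

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Reduced model of the patched Factory: the second template parameter
// decides what kind of pointer Create() returns.
template <typename ItemType, typename ItemTypePtr>
class Factory {
 public:
  using item_ptr_t = ItemTypePtr;
  using creator_t = std::function<item_ptr_t()>;

  static Factory& Global() {
    static Factory instance;
    return instance;
  }
  void Register(const std::string& key, creator_t creator) {
    creators_[key] = std::move(creator);
  }
  item_ptr_t Create(const std::string& key) { return creators_.at(key)(); }

 private:
  std::map<std::string, creator_t> creators_;
};

struct OpLite { virtual ~OpLite() = default; };
struct MulOp : OpLite {};  // illustrative op

int main() {
  // Ops come out shared, so graph nodes can co-own them for InferShape...
  using OpRegistry = Factory<OpLite, std::shared_ptr<OpLite>>;
  OpRegistry::Global().Register("mul", [] { return std::make_shared<MulOp>(); });
  std::shared_ptr<OpLite> op = OpRegistry::Global().Create("mul");

  // ...while kernels stay uniquely owned by their instruction.
  using KernelRegistry = Factory<OpLite, std::unique_ptr<OpLite>>;
  KernelRegistry::Global().Register("mul", [] { return std::make_unique<MulOp>(); });
  std::unique_ptr<OpLite> kernel = KernelRegistry::Global().Create("mul");
  (void)op;
  (void)kernel;
}
```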