Commit 18959c09, authored by: S superjomn

make op pointer shared_ptr to support kernel infershape

Parent: 367a2814
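In short, the commit stops holding operators as std::unique_ptr and shares them between the optimizer graph node and the runtime instruction, so shape inference can be run per instruction right before the kernel launches. A minimal, self-contained sketch of that ownership pattern (OpLite/KernelBase here are stand-ins, and InferShape()/Run() are assumed hooks, not signatures taken from this diff):

    #include <memory>

    struct OpLite {
      virtual void InferShape() {}  // assumed hook; the real OpLite lives in the lite core headers
      virtual ~OpLite() = default;
    };
    struct KernelBase {
      virtual void Run() {}         // assumed hook
      virtual ~KernelBase() = default;
    };

    struct Instruction {
      std::shared_ptr<OpLite> op;          // co-owned with the mir::Node (was unique_ptr before)
      std::unique_ptr<KernelBase> kernel;  // picked from the node's valid_kernels
      void Launch() {
        op->InferShape();  // possible because the instruction shares ownership of the op
        kernel->Run();
      }
    };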
@@ -25,7 +25,7 @@ struct Config {};
 class Predictor {
  public:
   void Build(const std::string& model_path,
-             const std::vector<OpLite::Place>& valid_places) {
+             const std::vector<Place>& valid_places) {
     CHECK(!executor_.get()) << "duplicate build found";
     framework::proto::ProgramDesc prog;
     LoadModel(model_path, &scope_, &prog);
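With this hunk, Build takes the top-level Place type instead of the nested OpLite::Place. A hedged usage sketch (the Place constructor arguments and the model path are illustrative assumptions, not taken from this diff):

    #include <string>
    #include <vector>
    // Assumes the Predictor/Place declarations from this patch are in scope.

    void BuildPredictor() {
      lite::Predictor predictor;
      std::vector<lite::Place> valid_places{
          lite::Place{TARGET(kHost), PRECISION(kFloat)}};  // assumed Place layout
      predictor.Build("/path/to/model_dir", valid_places);  // hypothetical model path
    }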
@@ -46,6 +46,8 @@ class Node {
   // The kernel instances this Instruct contains.
   std::vector<std::unique_ptr<KernelBase>> valid_kernels;
   std::shared_ptr<OpInfo> op_info;
+  // TODO(Superjomn) make this a shared_ptr for resource safety.
+  std::shared_ptr<OpLite> op;  // we hold op to run InferShape
 };

 struct Argument {
@@ -64,9 +66,11 @@ class Node {
   Instruct& AsInstruct(const std::string& op_type,
                        std::vector<std::unique_ptr<KernelBase>>&& kernels,
+                       const std::shared_ptr<OpLite>& op,
                        const std::shared_ptr<lite::OpInfo>& op_info) {
     auto& x = AsInstruct();
     x.op_type = op_type;
+    x.op = op;
     x.valid_kernels = std::move(kernels);
     x.op_info = op_info;
     return x;
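Because the Instruct node now stores the op, a pass or the generated program can trigger shape inference per instruction. A rough sketch of that pattern (IsInstruct() and InferShape() are assumed names; only mutable_nodes() and AsInstruct() appear in this patch):

    void InferShapeForAllInstructs(SSAGraph* graph) {
      for (auto& node : graph->mutable_nodes()) {
        if (!node.IsInstruct()) continue;      // assumed predicate on mir::Node
        auto& instruct = node.AsInstruct();
        instruct.op->InferShape();             // the node keeps the op alive via shared_ptr
      }
    }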
@@ -34,13 +34,7 @@ namespace mir {
 struct Program {
   std::list<std::string> tmp_vars;
   std::list<std::string> weights;
-  std::list<std::unique_ptr<OpLite>> ops;
-  lite::Scope *scope{};
-};
-
-// Program of kernel.
-struct KernelProgram {
-  std::list<std::unique_ptr<KernelBase>> instructions;
+  std::list<std::shared_ptr<OpLite>> ops;
   lite::Scope *scope{};
 };
@@ -67,7 +61,7 @@ class SSAGraph : GraphBase {
       op->SetValidPlaces(valid_places);
       auto &new_node = node_storage_.back();
       node_storage_.back().AsInstruct(
-          op->op_type_, op->CreateKernels(valid_places), op->op_info());
+          op->op_type_, op->CreateKernels(valid_places), op, op->op_info());

       CHECK(new_node.inlinks.empty()) << "duplicate Build found";
       CHECK(new_node.outlinks.empty()) << "duplicate Build found";
@@ -122,7 +116,7 @@ class SSAGraph : GraphBase {
   const std::list<mir::Node> &nodes() const { return node_storage_; }
   std::list<mir::Node> &mutable_nodes() { return node_storage_; }

-  mir::Node *RetriveArgument(const std::string &arg) {
+  mir::Node *RetrieveArgument(const std::string &arg) {
     auto it = arguments_.find(arg);
     if (it != arguments_.end()) {
       return it->second;
@@ -63,7 +63,7 @@ class VariablePlaceInferencePass : public DebugPass {
       // kernel's declaration.
       for (auto& arg_name : arg_names) {
-        auto* node = graph->RetriveArgument(arg_name);
+        auto* node = graph->RetrieveArgument(arg_name);
         CHECK(node) << "argument " << arg_name << " not exists in the graph";
         auto& arg_node = node->AsArgument();
         if (arg_node.place.is_valid()) continue;
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
       // kernel's declaration.
      for (auto& arg_name : arg_names) {
-        auto* node = graph->RetriveArgument(arg_name);
+        auto* node = graph->RetrieveArgument(arg_name);
         CHECK(node) << "argument " << arg_name << " not exists in the graph";
         auto& arg_node = node->AsArgument();
         if (arg_node.place.is_valid()) continue;
@@ -27,7 +27,7 @@ namespace lite {
 using KernelFunc = std::function<void()>;
 using KernelFuncCreator = std::function<std::unique_ptr<KernelFunc>()>;

-class LiteOpRegistry final : public Factory<OpLite> {
+class LiteOpRegistry final : public Factory<OpLite, std::shared_ptr<OpLite>> {
  public:
   static LiteOpRegistry &Global() {
     static auto *x = new LiteOpRegistry;
@@ -51,7 +51,9 @@ class OpLiteRegistor : public Registor<OpClass> {
 };

 template <TargetType Target, PrecisionType Precision>
-using KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>;
+using KernelRegistryForTarget =
+    Factory<OpKernel<Target, Precision>,
+            std::unique_ptr<OpKernel<Target, Precision>>>;

 class KernelRegistry final {
  public:
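With the extra Factory parameter, the op registry hands out shared_ptr while the per-target kernel registries keep unique ownership. Illustrative call sites, assuming both registries expose the Factory::Create(key) interface shown in the factory.h comment later in this patch (the "mul" key is just an example):

    void CreateMulOpAndKernel() {
      std::shared_ptr<OpLite> op =
          LiteOpRegistry::Global().Create("mul");  // ops are now handed out as shared_ptr
      std::unique_ptr<OpKernel<TARGET(kHost), PRECISION(kFloat)>> kernel =
          KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)>::Global()
              .Create("mul");                      // kernels stay unique_ptr
    }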
+set(lite_kernel_deps type_system kernel_lite op_registry_lite)
 add_subdirectory(host)
 add_subdirectory(arm)
 add_subdirectory(cuda)
-cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
-cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
-cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite)
-cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite)
-cc_library(feed_compute_host SRCS feed_compute.cc DEPS tensor_lite)
+cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps})
+cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
+cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps})
+cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps})
+cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
 cc_library(host_kernels DEPS
     fc_compute_host
@@ -10,7 +10,7 @@ cc_library(host_kernels DEPS
     mul_compute_host
     scale_compute_host
     feed_compute_host
-    DEPS kernel_lite
+    DEPS ${lite_kernel_deps}
     )
 cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
@@ -68,4 +68,10 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 REGISTER_LITE_KERNEL(mul, kHost, kFloat,
                      paddle::lite::kernels::host::MulCompute)
+    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                        TARGET(kHost))})
+    .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                        TARGET(kHost))})
+    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                           TARGET(kHost))})
     .Finalize();
@@ -33,12 +33,12 @@ namespace lite {
  * // Retrive a creator.
  * auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
  */
-template <typename ItemType>
+template <typename ItemType, typename ItemTypePtr>
 class Factory {
  public:
   using item_t = ItemType;
-  using self_t = Factory<item_t>;
-  using item_ptr_t = std::unique_ptr<item_t>;
+  using self_t = Factory<item_t, ItemTypePtr>;
+  using item_ptr_t = ItemTypePtr;
   using creator_t = std::function<item_ptr_t()>;

   static Factory& Global() {
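The second template parameter makes the smart-pointer type a property of each registry instead of hard-coding std::unique_ptr. A minimal standalone illustration of the same idea (this is not the repo's Factory; the Register name and the string-keyed map are assumptions for the sketch):

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    template <typename ItemType, typename ItemTypePtr>
    class Factory {
     public:
      using item_ptr_t = ItemTypePtr;
      using creator_t = std::function<item_ptr_t()>;

      static Factory& Global() {
        static Factory instance;
        return instance;
      }
      void Register(const std::string& key, creator_t&& creator) {
        creators_[key] = std::move(creator);
      }
      item_ptr_t Create(const std::string& key) { return creators_.at(key)(); }

     private:
      std::map<std::string, creator_t> creators_;
    };

    struct OpLite {};  // stand-in for the real base class

    // The same template serves both ownership models:
    using OpFactory = Factory<OpLite, std::shared_ptr<OpLite>>;      // ops are shared
    using KernelFactory = Factory<OpLite, std::unique_ptr<OpLite>>;  // kernels stay unique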