提交 8b950a4f 编写于 作者: S superjomn

add mir implementation

上级 f41d73b3
...@@ -52,7 +52,7 @@ class Executor { ...@@ -52,7 +52,7 @@ class Executor {
ops_.emplace_back(LiteOpRegistry::Global().Create(op_type)); ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel // pick initial kernel
ops_.back()->PickKernel(valid_places_); ops_.back()->PickKernel(valid_places_);
ops_.back()->Attach(*op_desc, exec_scope_); ops_.back()->AttachImpl(*op_desc, exec_scope_);
} }
} }
......
...@@ -17,11 +17,6 @@ ...@@ -17,11 +17,6 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
bool operator==(const Place &a, const Place &b) {
return a.target == b.target && a.precision == b.precision &&
a.layout == b.layout;
}
bool operator<(const Place &a, const Place &b) { bool operator<(const Place &a, const Place &b) {
if (a.target != b.target) if (a.target != b.target)
return a.target < b.target; return a.target < b.target;
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
...@@ -51,6 +52,7 @@ class KernelBase { ...@@ -51,6 +52,7 @@ class KernelBase {
virtual TargetType target() const = 0; virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0; virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
virtual ~KernelBase() = default; virtual ~KernelBase() = default;
...@@ -66,17 +68,21 @@ class KernelBase { ...@@ -66,17 +68,21 @@ class KernelBase {
* registered in the `TypeSystem`. * registered in the `TypeSystem`.
*/ */
struct ParamType { struct ParamType {
// For unsupported types.
size_t element_type_hash{}; size_t element_type_hash{};
Place tensor_place{}; Place tensor_place{};
const Type* type_;
ParamType() = default; ParamType() = default;
ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {} ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place) ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {} : element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) {}
}; };
/* /*
* The data types of kernel parameters. * The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/ */
struct ParamTypes { struct ParamTypes {
std::vector<std::vector<ParamType>> inputs; std::vector<std::vector<ParamType>> inputs;
...@@ -115,6 +121,8 @@ struct ParamTypes { ...@@ -115,6 +121,8 @@ struct ParamTypes {
*/ */
class ParamTypeRegistry { class ParamTypeRegistry {
public: public:
enum class IO : int { kInput = 0, kOutput };
template <TargetType target, PrecisionType precision, template <TargetType target, PrecisionType precision,
DataLayoutType layout = DataLayoutType::kNCHW> DataLayoutType layout = DataLayoutType::kNCHW>
/* /*
...@@ -130,7 +138,12 @@ class ParamTypeRegistry { ...@@ -130,7 +138,12 @@ class ParamTypeRegistry {
NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {} NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
NewInstance& BindInput(int offset, const ParamType& ptype) { NewInstance& BindInput(int offset, const ParamType& ptype) {
ParamTypeRegistry::Global().Register( ParamTypeRegistry::Global().Register<IO::kInput>(
kernel_type_, Place{target, precision, layout}, offset, ptype);
return *this;
}
NewInstance& BindOutput(int offset, const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kOutput>(
kernel_type_, Place{target, precision, layout}, offset, ptype); kernel_type_, Place{target, precision, layout}, offset, ptype);
return *this; return *this;
} }
...@@ -141,8 +154,12 @@ class ParamTypeRegistry { ...@@ -141,8 +154,12 @@ class ParamTypeRegistry {
std::string kernel_type_; std::string kernel_type_;
}; };
template <IO io>
void Register(const std::string& kernel_type, const Place& place, int offset, void Register(const std::string& kernel_type, const Place& place, int offset,
ParamType data_type) {} ParamType data_type) {
KernelIdTy key{kernel_type, place, io, offset};
types_[key] = data_type;
}
ParamType Retrive(const Place& place, int offset); ParamType Retrive(const Place& place, int offset);
...@@ -155,16 +172,15 @@ class ParamTypeRegistry { ...@@ -155,16 +172,15 @@ class ParamTypeRegistry {
ParamTypeRegistry() = default; ParamTypeRegistry() = default;
public: public:
enum class IO : int { kInput = 0, kOutput };
// Identification for a Kernel. // Identification for a Kernel.
struct KernelIdT { struct KernelIdTy {
std::string kernel_type; std::string kernel_type;
Place place; Place place;
IO io; IO io;
int offset; int offset;
}; };
using key_t = KernelIdT; using key_t = KernelIdTy;
struct KeyCmp { struct KeyCmp {
bool operator()(const key_t& a, const key_t& b) const; bool operator()(const key_t& a, const key_t& b) const;
}; };
...@@ -188,6 +204,7 @@ class OpKernel : public KernelBase { ...@@ -188,6 +204,7 @@ class OpKernel : public KernelBase {
TargetType target() const override { return Target; } TargetType target() const override { return Target; }
PrecisionType precision() const override { return Precision; } PrecisionType precision() const override { return Precision; }
DataLayoutType layout() const override { return DataLayout; }
void Touch() {} void Touch() {}
......
cc_library(mir_pass SRCS pass.cc)
cc_library(mir_node SRCS node.cc) cc_library(mir_node SRCS node.cc)
cc_library(mir_ssa_graph SRCS ssa_graph.cc) cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
\ No newline at end of file cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
cc_library(mir_demo_pass SRCS demo_pass.cc DEPS mir_pass)
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
class DemoPass : public mir::Pass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {}
};
bool RegisterDemoPass() {
return PassManager::Global().AddNewPass("demo", new DemoPass);
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/mir/node.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
// Node in a MIR graph.
class Node { class Node {
public: public:
// Tell is instruction. std::list<Node*> inlinks;
bool IsInstruct() const; std::list<Node*> outlinks;
// Tell is an argument.
bool IsArgument() const; Node() = default;
};
enum class Role {
kUnk = -1,
kArgument,
kInstruct,
kNumRoles /*should be last*/
};
struct Instruct {
std::string op_type;
Place place;
// The kernel instances this Instruct contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
};
struct Argument {
std::string name;
Place place;
};
// Set roles.
Argument& AsArgument() {
if (role_ != Role::kUnk) {
CHECK(role_ == Role::kArgument);
return *argument_;
}
role_ = Role::kArgument;
argument_.reset(new Argument);
return *argument_;
}
Instruct& AsInstruct() {
if (role_ != Role::kUnk) {
CHECK(role_ == Role::kInstruct);
return *instruct_;
}
role_ = Role::kInstruct;
instruct_.reset(new Instruct);
return *instruct_;
}
// Check roles.
bool IsRoleSet() const { return role_ == Role::kUnk; }
bool IsInstruct() const { return role_ == Role::kInstruct; }
bool IsArgument() const { return role_ == Role::kArgument; }
private:
// Either instruct_ or argument_ is used.
std::unique_ptr<Instruct> instruct_;
std::unique_ptr<Argument> argument_;
Role role_{Role::kUnk};
};
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
\ No newline at end of file
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {
class Pass {
public:
virtual void Apply(std::unique_ptr<mir::SSAGraph>& graph) = 0;
const std::string& name() const { return name_; }
virtual ~Pass() = default;
private:
std::string name_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
PassManager::PassManager() {}
// Manually register here.
extern bool RegisterDemoPass();
static bool xx __attribute__((unused)) = RegisterDemoPass();
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class PassManager {
public:
static PassManager& Global() {
static PassManager x;
return x;
}
PassManager();
void Run() {
for (auto& pass : passes_) {
LOG(INFO) << "Running MIR pass " << pass->name();
pass->Apply(graph_);
}
}
bool AddNewPass(const std::string& name, Pass* pass) {
passes_.emplace_back(pass);
pass_map_.emplace(name, passes_.back().get());
return true;
}
// Clear all the passes.
void Clear() { passes_.clear(); }
std::list<std::unique_ptr<mir::Pass>>::iterator passes_begin() {
return passes_.begin();
}
std::list<std::unique_ptr<mir::Pass>>::iterator passes_end() {
return passes_.end();
}
std::list<std::unique_ptr<mir::Pass>>::const_iterator passes_const_begin()
const {
return passes_.begin();
}
std::list<std::unique_ptr<mir::Pass>>::const_iterator passes_const_end()
const {
return passes_.end();
}
Pass* LookUp(const std::string& key) {
auto it = pass_map_.find(key);
CHECK(it != pass_map_.end());
return it->second;
}
private:
std::unique_ptr<mir::SSAGraph> graph_;
std::list<std::unique_ptr<mir::Pass>> passes_;
std::map<std::string, mir::Pass*> pass_map_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
namespace mir {
TEST(PassManager, test) {
auto* pass = PassManager::Global().LookUp("demo");
LOG(INFO) << "pass: " << pass;
ASSERT_TRUE(pass != nullptr);
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
namespace paddle {
namespace lite {
namespace mir {
class PassRegistry {
public:
PassRegistry(const std::string& name, mir::Pass* pass) {
LOG(INFO) << "Registry add MIR pass " << name;
PassManager::Global().AddNewPass(name, pass);
}
bool Touch() const { return true; }
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace mir {
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::unique_ptr<OpLite>> ops;
lite::Scope *scope;
};
// An Graph for MIR. It is built from a list of Op and a scope.
class GraphBase {};
class SSAGraph : GraphBase {
public:
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) {
for (auto &op : program.ops) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto &new_kernel = node_storage_.back().AsInstruct();
new_kernel.valid_kernels = op->CreateKernels(valid_places);
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->input_names()) {
new_node.inlinks.push_back(arguments_.at(name));
}
for (const std::string &name : op->output_names()) {
new_node.outlinks.push_back(arguments_.at(name));
}
}
}
std::vector<mir::Node *> TopoloticalOrder() const;
private:
std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places, ...@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places,
} }
} }
bool OpLite::Run() {
CHECK(kernel_);
SyncInputEvents();
kernel_->Run();
RecordOutputEvents();
return true;
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -36,6 +36,11 @@ struct Registry { ...@@ -36,6 +36,11 @@ struct Registry {
void Touch() {} void Touch() {}
}; };
namespace mir {
class Node;
class SSAGraph;
}
/** /**
* The base class of an light-weight operators, currently just used in inference * The base class of an light-weight operators, currently just used in inference
* to eliminate overhead of some operations in current framework. * to eliminate overhead of some operations in current framework.
...@@ -71,19 +76,13 @@ class OpLite : public Registry { ...@@ -71,19 +76,13 @@ class OpLite : public Registry {
// Inference the outputs' shape. // Inference the outputs' shape.
virtual bool InferShape() const { return true; } virtual bool InferShape() const { return true; }
// Run this operator. // Run this operator.
virtual bool Run() { virtual bool Run();
CHECK(kernel_);
SyncInputEvents();
kernel_->Run();
RecordOutputEvents(); bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
return true; ExtractInputsAndOutputs(opdesc);
return AttachImpl(opdesc, scope);
} }
// Attach it with the runtime environment.
virtual bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Human-readable information. // Human-readable information.
virtual std::string DebugString() const = 0; virtual std::string DebugString() const = 0;
...@@ -92,9 +91,29 @@ class OpLite : public Registry { ...@@ -92,9 +91,29 @@ class OpLite : public Registry {
void PickKernel(const std::vector<Place> &valid_places, void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic); KernelStrategy kernel_strategy = KernelStrategy::kStatic);
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
virtual ~OpLite() = default; virtual ~OpLite() = default;
protected: protected:
// Attach it with the runtime environment.
virtual bool AttachImpl(const framework::OpDesc &opdesc,
lite::Scope *scope) = 0;
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (const auto &x : item.second) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (const auto &x : item.second) {
output_names_.push_back(x);
}
}
}
// Specify the kernel to run by default. This will specify the value of // Specify the kernel to run by default. This will specify the value of
// `kernel_place_`. // `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) { virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
...@@ -113,12 +132,17 @@ class OpLite : public Registry { ...@@ -113,12 +132,17 @@ class OpLite : public Registry {
std::vector<std::unique_ptr<KernelBase>> CreateKernels( std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = ""); const std::vector<Place> &places, const std::string &kernel_type = "");
friend class mir::Node;
friend class mir::SSAGraph;
protected: protected:
std::unique_ptr<OpContext> op_context_; std::unique_ptr<OpContext> op_context_;
std::unique_ptr<KernelBase> kernel_; std::unique_ptr<KernelBase> kernel_;
std::string op_type_; std::string op_type_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::list<std::string> input_names_;
std::list<std::string> output_names_;
}; };
} // namespace lite } // namespace lite
......
...@@ -13,3 +13,46 @@ ...@@ -13,3 +13,46 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
// ------------------------- GetType specification ----------------------------
template <>
const Type*
Type::Get<false /*is_unsupported*/, false /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static UnsupportedTy x;
return &x;
}
template <>
const Type*
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static TensorFp32NCHWTy x(TargetType::kX86);
return &x;
}
template <>
const Type* Type::Get<UnsupportedTy>(TargetType target, int device) {
return Get<false, false, TargetType::kHost, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
}
template <>
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
switch (target) {
case TargetType::kX86:
return Get<false, true, TargetType::kX86, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
return nullptr;
}
}
// ------------------------- end GetType specification ------------------------
} // namespace lite
} // namespace paddle
...@@ -82,7 +82,7 @@ class DataTypeBase { ...@@ -82,7 +82,7 @@ class DataTypeBase {
* Datatype with device info considered. * Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType. * NOTE A Type with different device is treated as different DeviceDataType.
*/ */
class DeviceDataType : public DataTypeBase { class Type : public DataTypeBase {
public: public:
TargetType target() const { return place_.target; } TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; } PrecisionType precision() const { return place_.precision; }
...@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase { ...@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase {
const Place& place() const { return place_; } const Place& place() const { return place_; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
bool operator==(const DeviceDataType& other) { bool operator==(const Type& other) {
return id_ == other.id() && place_ == other.place(); return id_ == other.id() && place_ == other.place();
} }
// Can cast to another type. This is heavily used in MIR, by determine whether // Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another. // is is possible to add a instruction to transform a type to another.
virtual bool TypeCastable(const DeviceDataType& type) const { virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
return id_ == type.id();
} template <bool is_unknown, bool is_tensor = true,
TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW>
// Get a type.
static const Type* Get();
template <typename TypeTy>
static const Type* Get(TargetType target = TargetType::kHost);
virtual ~DeviceDataType() = default; virtual ~Type() = default;
protected: protected:
DeviceDataType(ID id, const std::string& name, bool is_tensor, Type(ID id, const std::string& name, bool is_tensor,
TargetType target = TargetType::kHost, TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat, PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW) DataLayoutType layout = DataLayoutType::kNCHW)
: DataTypeBase(id, is_tensor), : DataTypeBase(id, is_tensor),
place_{target, precision, layout}, place_{target, precision, layout},
name_(name) {} name_(name) {}
...@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase { ...@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase {
}; };
// -------------------------------- predefined types --------------------------- // -------------------------------- predefined types ---------------------------
class Void : public DeviceDataType { // TODO(Superjomn) make all the Types' constructs protected to make sure there
// is only one instance across the system.
class VoidTy : public Type {
public:
VoidTy() : Type(ID::Void, "Void", false /*is_tensor*/) {}
};
class UnsupportedTy : public Type {
public: public:
Void() : DeviceDataType(ID::Void, "Void", false /*is_tensor*/) {} UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
}; };
class TensorFp32NCHW : public DeviceDataType { class TensorFp32NCHWTy : public Type {
public: public:
TensorFp32NCHW(TargetType target) TensorFp32NCHWTy(TargetType target)
: DeviceDataType(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", : Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
true /*is_tensor*/, target, PrecisionType::kFloat, PrecisionType::kFloat, DataLayoutType::kNCHW) {}
DataLayoutType::kNCHW) {}
}; };
class TensorInt8NCHW : public DeviceDataType { class TensorInt8NCHWTy : public Type {
public: public:
TensorInt8NCHW(TargetType target) TensorInt8NCHWTy(TargetType target)
: DeviceDataType(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", : Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
true /*is_tensor*/, target, PrecisionType::kInt8, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
DataLayoutType::kNCHW) {}
}; };
class TensorInt64NCHW : public DeviceDataType { class TensorInt64NCHWTy : public Type {
public: public:
TensorInt64NCHW(TargetType target) TensorInt64NCHWTy(TargetType target)
: DeviceDataType(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
true /*is_tensor*/, target, PrecisionType::kInt8, target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
DataLayoutType::kNCHW) {}
}; };
// ------------------------- end predefined types --------------------------- // ------------------------- end predefined types ---------------------------
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/lite/kernels/host/fc_compute.h" #include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include <Eigen/Core> #include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } ...@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute) REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
.BindInput(0, {typeid(paddle::lite::Tensor).hash_code(), .BindInput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)}}) TARGET(kX86))})
.BindOutput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
...@@ -44,7 +46,8 @@ class FcOpLite : public OpLite { ...@@ -44,7 +46,8 @@ class FcOpLite : public OpLite {
*/ */
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const framework::OpDesc &op_desc,
lite::Scope *scope) override {
auto input = op_desc.Input("Input").front(); auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front(); auto W = op_desc.Input("W").front();
auto bias = op_desc.Input("Bias").front(); auto bias = op_desc.Input("Bias").front();
......
...@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) { ...@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) {
fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}}); fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.Attach(desc, &scope); fc.AttachImpl(desc, &scope);
fc.Run(); fc.Run();
for (int i = 0; i < 10 * 20; i++) { for (int i = 0; i < 10 * 20; i++) {
......
...@@ -37,7 +37,8 @@ class MulOpLite : public OpLite { ...@@ -37,7 +37,8 @@ class MulOpLite : public OpLite {
bool InferShape() const override; bool InferShape() const override;
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const framework::OpDesc &op_desc,
lite::Scope *scope) override {
auto input = op_desc.Input("X").front(); auto input = op_desc.Input("X").front();
auto W = op_desc.Input("Y").front(); auto W = op_desc.Input("Y").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
......
...@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const { ...@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
return true; return true;
} }
bool ReluOp::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { bool ReluOp::AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<Tensor *>( param_.input = const_cast<Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>()); &scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>());
param_.output = param_.output =
......
...@@ -32,7 +32,7 @@ class ReluOp : public OpLite { ...@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool InferShape() const override; bool InferShape() const override;
bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
std::string DebugString() const override { return "tanh"; } std::string DebugString() const override { return "tanh"; }
......
...@@ -44,7 +44,8 @@ class ScaleOp : public OpLite { ...@@ -44,7 +44,8 @@ class ScaleOp : public OpLite {
} }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const framework::OpDesc &op_desc,
lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册