Commit 8b950a4f authored by superjomn

add mir implementation

Parent f41d73b3
......@@ -52,7 +52,7 @@ class Executor {
ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops_.back()->PickKernel(valid_places_);
-      ops_.back()->Attach(*op_desc, exec_scope_);
+      ops_.back()->AttachImpl(*op_desc, exec_scope_);
}
}
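// Note: the op-specific parameter binding is renamed to AttachImpl(); the new
// public OpLite::Attach() wrapper (see op_lite.h below) additionally records
// the op's input/output argument names before delegating to AttachImpl().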
......
......@@ -17,11 +17,6 @@
namespace paddle {
namespace lite {
-bool operator==(const Place &a, const Place &b) {
-  return a.target == b.target && a.precision == b.precision &&
-         a.layout == b.layout;
-}
bool operator<(const Place &a, const Place &b) {
if (a.target != b.target)
return a.target < b.target;
......
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
......@@ -51,6 +52,7 @@ class KernelBase {
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
virtual ~KernelBase() = default;
......@@ -66,17 +68,21 @@ class KernelBase {
* registered in the `TypeSystem`.
*/
struct ParamType {
// For unsupported types.
size_t element_type_hash{};
Place tensor_place{};
const Type* type_;
ParamType() = default;
ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) {}
};
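// Illustrative use of the new Type-based ParamType (a hypothetical snippet,
// not part of this commit):
//   ParamType p(Type::Get<TensorFp32NCHWTy>(TargetType::kX86));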
/*
- * The data types of kernel parameters.
+ * The data types of kernel parameters. It is used to track the types of a
+ * kernel's inputs and outputs.
*/
struct ParamTypes {
std::vector<std::vector<ParamType>> inputs;
......@@ -115,6 +121,8 @@ struct ParamTypes {
*/
class ParamTypeRegistry {
public:
enum class IO : int { kInput = 0, kOutput };
template <TargetType target, PrecisionType precision,
DataLayoutType layout = DataLayoutType::kNCHW>
/*
......@@ -130,7 +138,12 @@ class ParamTypeRegistry {
NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
NewInstance& BindInput(int offset, const ParamType& ptype) {
-    ParamTypeRegistry::Global().Register(
+    ParamTypeRegistry::Global().Register<IO::kInput>(
kernel_type_, Place{target, precision, layout}, offset, ptype);
return *this;
}
NewInstance& BindOutput(int offset, const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kOutput>(
kernel_type_, Place{target, precision, layout}, offset, ptype);
return *this;
}
......@@ -141,8 +154,12 @@ class ParamTypeRegistry {
std::string kernel_type_;
};
+  template <IO io>
   void Register(const std::string& kernel_type, const Place& place, int offset,
-                ParamType data_type) {}
+                ParamType data_type) {
+    KernelIdTy key{kernel_type, place, io, offset};
+    types_[key] = data_type;
+  }
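  // Illustrative registration (hypothetical; kernels normally reach this via
  // the NewInstance builder above or REGISTER_LITE_KERNEL, as in fc_compute.cc
  // below):
  //   ParamTypeRegistry::Global().Register<IO::kInput>(
  //       "fc", Place{TargetType::kX86, PrecisionType::kFloat,
  //                   DataLayoutType::kNCHW},
  //       0, ParamType(Type::Get<TensorFp32NCHWTy>(TargetType::kX86)));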
ParamType Retrive(const Place& place, int offset);
......@@ -155,16 +172,15 @@ class ParamTypeRegistry {
ParamTypeRegistry() = default;
public:
-  enum class IO : int { kInput = 0, kOutput };

   // Identification for a Kernel.
-  struct KernelIdT {
+  struct KernelIdTy {
     std::string kernel_type;
     Place place;
     IO io;
     int offset;
   };

-  using key_t = KernelIdT;
+  using key_t = KernelIdTy;
struct KeyCmp {
bool operator()(const key_t& a, const key_t& b) const;
};
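  // A possible lexicographic ordering for KeyCmp (a sketch only; the real
  // definition lives in the corresponding .cc file, which this diff does not
  // show):
  //   bool ParamTypeRegistry::KeyCmp::operator()(const key_t& a,
  //                                              const key_t& b) const {
  //     if (a.kernel_type != b.kernel_type) return a.kernel_type < b.kernel_type;
  //     if (!(a.place == b.place)) return a.place < b.place;
  //     if (a.io != b.io) return a.io < b.io;
  //     return a.offset < b.offset;
  //   }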
......@@ -188,6 +204,7 @@ class OpKernel : public KernelBase {
TargetType target() const override { return Target; }
PrecisionType precision() const override { return Precision; }
DataLayoutType layout() const override { return DataLayout; }
void Touch() {}
......
-cc_library(mir_pass SRCS pass.cc)
 cc_library(mir_node SRCS node.cc)
-cc_library(mir_ssa_graph SRCS ssa_graph.cc)
\ No newline at end of file
+cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
+cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
+cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph)
+cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
+cc_library(mir_demo_pass SRCS demo_pass.cc DEPS mir_pass)
+cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
class DemoPass : public mir::Pass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {}
};
bool RegisterDemoPass() {
return PassManager::Global().AddNewPass("demo", new DemoPass);
}
} // namespace mir
} // namespace lite
} // namespace paddle
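// A hedged sketch of how another pass could be added using the PassRegistry
// helper from pass_registry.h (MyPass and my_pass_reg are illustrative names,
// not part of this commit):
//   class MyPass : public mir::Pass {
//    public:
//     void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
//       // rewrite the graph in place
//     }
//   };
//   static mir::PassRegistry my_pass_reg("my_pass", new MyPass);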
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace mir {
// Node in a MIR graph.
class Node {
public:
-  // Tell is instruction.
-  bool IsInstruct() const;
-  // Tell is an argument.
-  bool IsArgument() const;
-};
std::list<Node*> inlinks;
std::list<Node*> outlinks;
Node() = default;
enum class Role {
kUnk = -1,
kArgument,
kInstruct,
kNumRoles /*should be last*/
};
struct Instruct {
std::string op_type;
Place place;
// The kernel instances this Instruct contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
};
struct Argument {
std::string name;
Place place;
};
// Set roles.
Argument& AsArgument() {
if (role_ != Role::kUnk) {
CHECK(role_ == Role::kArgument);
return *argument_;
}
role_ = Role::kArgument;
argument_.reset(new Argument);
return *argument_;
}
Instruct& AsInstruct() {
if (role_ != Role::kUnk) {
CHECK(role_ == Role::kInstruct);
return *instruct_;
}
role_ = Role::kInstruct;
instruct_.reset(new Instruct);
return *instruct_;
}
// Check roles.
  bool IsRoleSet() const { return role_ != Role::kUnk; }
bool IsInstruct() const { return role_ == Role::kInstruct; }
bool IsArgument() const { return role_ == Role::kArgument; }
private:
// Either instruct_ or argument_ is used.
std::unique_ptr<Instruct> instruct_;
std::unique_ptr<Argument> argument_;
Role role_{Role::kUnk};
};
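// Usage sketch (hypothetical): a Node starts role-less and must be
// specialized exactly once via AsArgument() or AsInstruct().
//   mir::Node node;
//   auto& instr = node.AsInstruct();  // first call fixes the role
//   instr.op_type = "fc";
//   node.AsArgument();                // would CHECK-fail: role already set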
} // namespace mir
} // namespace lite
-} // namespace paddle
\ No newline at end of file
+} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {
class Pass {
public:
virtual void Apply(std::unique_ptr<mir::SSAGraph>& graph) = 0;
const std::string& name() const { return name_; }
virtual ~Pass() = default;
private:
std::string name_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
PassManager::PassManager() {}
// Manually register here.
extern bool RegisterDemoPass();
static bool xx __attribute__((unused)) = RegisterDemoPass();
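// The unused static boolean above forces RegisterDemoPass() to run during
// static initialization, so the "demo" pass is registered before main() runs.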
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class PassManager {
public:
static PassManager& Global() {
static PassManager x;
return x;
}
PassManager();
void Run() {
for (auto& pass : passes_) {
LOG(INFO) << "Running MIR pass " << pass->name();
pass->Apply(graph_);
}
}
bool AddNewPass(const std::string& name, Pass* pass) {
passes_.emplace_back(pass);
pass_map_.emplace(name, passes_.back().get());
return true;
}
// Clear all the passes.
void Clear() { passes_.clear(); }
std::list<std::unique_ptr<mir::Pass>>::iterator passes_begin() {
return passes_.begin();
}
std::list<std::unique_ptr<mir::Pass>>::iterator passes_end() {
return passes_.end();
}
std::list<std::unique_ptr<mir::Pass>>::const_iterator passes_const_begin()
const {
return passes_.begin();
}
std::list<std::unique_ptr<mir::Pass>>::const_iterator passes_const_end()
const {
return passes_.end();
}
Pass* LookUp(const std::string& key) {
auto it = pass_map_.find(key);
CHECK(it != pass_map_.end());
return it->second;
}
private:
std::unique_ptr<mir::SSAGraph> graph_;
std::list<std::unique_ptr<mir::Pass>> passes_;
std::map<std::string, mir::Pass*> pass_map_;
};
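// Illustrative driver (hypothetical): look up a single pass or run the whole
// registered pipeline over the managed graph.
//   auto& pm = mir::PassManager::Global();
//   mir::Pass* demo = pm.LookUp("demo");  // CHECK-fails if absent
//   pm.Run();  // applies every pass to graph_ in registration order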
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
namespace mir {
TEST(PassManager, test) {
auto* pass = PassManager::Global().LookUp("demo");
LOG(INFO) << "pass: " << pass;
ASSERT_TRUE(pass != nullptr);
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
namespace paddle {
namespace lite {
namespace mir {
class PassRegistry {
public:
PassRegistry(const std::string& name, mir::Pass* pass) {
LOG(INFO) << "Registry add MIR pass " << name;
PassManager::Global().AddNewPass(name, pass);
}
bool Touch() const { return true; }
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace mir {
// A Program represents a model program; in Paddle, a program contains:
// - a main block, which is a list of OpLite operators
// - a scope, which holds all the weights
struct Program {
std::list<std::unique_ptr<OpLite>> ops;
lite::Scope *scope;
};
// A Graph for MIR. It is built from a list of Ops and a scope.
class GraphBase {};
class SSAGraph : GraphBase {
public:
  // @param program: the op program
  // @param valid_places: the valid places the user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) {
for (auto &op : program.ops) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto &new_kernel = node_storage_.back().AsInstruct();
new_kernel.valid_kernels = op->CreateKernels(valid_places);
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->input_names()) {
new_node.inlinks.push_back(arguments_.at(name));
}
for (const std::string &name : op->output_names()) {
new_node.outlinks.push_back(arguments_.at(name));
}
}
}
std::vector<mir::Node *> TopoloticalOrder() const;
private:
std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_;
};
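// Intended use, as far as this commit shows (Program assembly is assumed):
//   mir::Program program;
//   program.ops.emplace_back(LiteOpRegistry::Global().Create("fc"));
//   program.scope = &scope;
//   mir::SSAGraph graph;
//   graph.Build(program, {Place{TARGET(kHost), PRECISION(kFloat)}});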
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places,
}
}
bool OpLite::Run() {
CHECK(kernel_);
SyncInputEvents();
kernel_->Run();
RecordOutputEvents();
return true;
}
} // namespace lite
} // namespace paddle
......@@ -36,6 +36,11 @@ struct Registry {
void Touch() {}
};
namespace mir {
class Node;
class SSAGraph;
}
/**
 * The base class of a light-weight operator, currently used only in inference
 * to eliminate the overhead of some operations in the current framework.
......@@ -71,19 +76,13 @@ class OpLite : public Registry {
  // Infer the outputs' shapes.
virtual bool InferShape() const { return true; }
// Run this operator.
-  virtual bool Run() {
-    CHECK(kernel_);
-    SyncInputEvents();
-    kernel_->Run();
-    RecordOutputEvents();
-    return true;
-  }
+  virtual bool Run();

+  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
+    ExtractInputsAndOutputs(opdesc);
+    return AttachImpl(opdesc, scope);
+  }

-  // Attach it with the runtime environment.
-  virtual bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) = 0;
// Human-readable information.
virtual std::string DebugString() const = 0;
......@@ -92,9 +91,29 @@ class OpLite : public Registry {
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
virtual ~OpLite() = default;
protected:
// Attach it with the runtime environment.
virtual bool AttachImpl(const framework::OpDesc &opdesc,
lite::Scope *scope) = 0;
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (const auto &x : item.second) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (const auto &x : item.second) {
output_names_.push_back(x);
}
}
}
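  // Resulting call protocol: callers use the public Attach(), which records
  // the op's input/output argument names (consumed by the MIR graph builder)
  // and then dispatches to the concrete op's AttachImpl() to bind parameters.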
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
......@@ -113,12 +132,17 @@ class OpLite : public Registry {
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = "");
friend class mir::Node;
friend class mir::SSAGraph;
protected:
std::unique_ptr<OpContext> op_context_;
std::unique_ptr<KernelBase> kernel_;
std::string op_type_;
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::list<std::string> input_names_;
std::list<std::string> output_names_;
};
} // namespace lite
......
......@@ -13,3 +13,46 @@
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
// ------------------------- GetType specification ----------------------------
template <>
const Type*
Type::Get<false /*is_unsupported*/, false /*is_tensor*/, TargetType::kHost,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static UnsupportedTy x;
return &x;
}
template <>
const Type*
Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
PrecisionType::kFloat, DataLayoutType::kNCHW>() {
static TensorFp32NCHWTy x(TargetType::kX86);
return &x;
}
template <>
const Type* Type::Get<UnsupportedTy>(TargetType target, int device) {
return Get<false, false, TargetType::kHost, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
}
template <>
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
switch (target) {
case TargetType::kX86:
return Get<false, true, TargetType::kX86, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
return nullptr;
}
}
// ------------------------- end GetType specification ------------------------
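// Illustrative lookup (hypothetical): the specializations above hand out
// process-wide singletons, so pointer equality identifies a type.
//   const Type* t = Type::Get<TensorFp32NCHWTy>(TargetType::kX86);
//   LOG(INFO) << t->name();  // "TensorFp32NCHW"
//   CHECK_EQ(t, Type::Get<TensorFp32NCHWTy>(TargetType::kX86));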
} // namespace lite
} // namespace paddle
......@@ -82,7 +82,7 @@ class DataTypeBase {
* Datatype with device info considered.
 * NOTE A Type with a different device is treated as a different Type.
*/
-class DeviceDataType : public DataTypeBase {
+class Type : public DataTypeBase {
public:
TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; }
......@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase {
const Place& place() const { return place_; }
const std::string& name() const { return name_; }
-  bool operator==(const DeviceDataType& other) {
+  bool operator==(const Type& other) {
return id_ == other.id() && place_ == other.place();
}
  // Can cast to another type. This is heavily used in MIR to determine whether
  // it is possible to add an instruction that transforms one type to another.
-  virtual bool TypeCastable(const DeviceDataType& type) const {
-    return id_ == type.id();
-  }
+  virtual bool TypeCastable(const Type& type) const { return id_ == type.id(); }
template <bool is_unknown, bool is_tensor = true,
TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW>
// Get a type.
static const Type* Get();
template <typename TypeTy>
static const Type* Get(TargetType target = TargetType::kHost);
-  virtual ~DeviceDataType() = default;
+  virtual ~Type() = default;
protected:
-  DeviceDataType(ID id, const std::string& name, bool is_tensor,
-                 TargetType target = TargetType::kHost,
-                 PrecisionType precision = PrecisionType::kFloat,
-                 DataLayoutType layout = DataLayoutType::kNCHW)
+  Type(ID id, const std::string& name, bool is_tensor,
+       TargetType target = TargetType::kHost,
+       PrecisionType precision = PrecisionType::kFloat,
+       DataLayoutType layout = DataLayoutType::kNCHW)
: DataTypeBase(id, is_tensor),
place_{target, precision, layout},
name_(name) {}
......@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase {
};
// -------------------------------- predefined types ---------------------------
+// TODO(Superjomn) make all the Types' constructs protected to make sure there
+// is only one instance across the system.
-class Void : public DeviceDataType {
+class VoidTy : public Type {
  public:
-  Void() : DeviceDataType(ID::Void, "Void", false /*is_tensor*/) {}
+  VoidTy() : Type(ID::Void, "Void", false /*is_tensor*/) {}
+};
+
+class UnsupportedTy : public Type {
+ public:
+  UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
 };
-class TensorFp32NCHW : public DeviceDataType {
+class TensorFp32NCHWTy : public Type {
  public:
-  TensorFp32NCHW(TargetType target)
-      : DeviceDataType(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW",
-                       true /*is_tensor*/, target, PrecisionType::kFloat,
-                       DataLayoutType::kNCHW) {}
+  TensorFp32NCHWTy(TargetType target)
+      : Type(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW", true /*is_tensor*/, target,
+             PrecisionType::kFloat, DataLayoutType::kNCHW) {}
 };
-class TensorInt8NCHW : public DeviceDataType {
+class TensorInt8NCHWTy : public Type {
  public:
-  TensorInt8NCHW(TargetType target)
-      : DeviceDataType(ID::Tensor_Int8_NCHW, "TensorInt8NCHW",
-                       true /*is_tensor*/, target, PrecisionType::kInt8,
-                       DataLayoutType::kNCHW) {}
+  TensorInt8NCHWTy(TargetType target)
+      : Type(ID::Tensor_Int8_NCHW, "TensorInt8NCHW", true /*is_tensor*/, target,
+             PrecisionType::kInt8, DataLayoutType::kNCHW) {}
 };
-class TensorInt64NCHW : public DeviceDataType {
+class TensorInt64NCHWTy : public Type {
  public:
-  TensorInt64NCHW(TargetType target)
-      : DeviceDataType(ID::Tensor_Int64_NCHW, "TensorInt64NCHW",
-                       true /*is_tensor*/, target, PrecisionType::kInt8,
-                       DataLayoutType::kNCHW) {}
+  TensorInt64NCHWTy(TargetType target)
+      : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
+             target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
 };
// ------------------------- end predefined types ---------------------------
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
......@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
-    .BindInput(0, {typeid(paddle::lite::Tensor).hash_code(),
-                   paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)}})
+    .BindInput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                      TARGET(kX86))})
+    .BindOutput(0, {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+                       TARGET(kX86))})
.Finalize();
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
......@@ -44,7 +46,8 @@ class FcOpLite : public OpLite {
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const framework::OpDesc &op_desc,
+                  lite::Scope *scope) override {
auto input = op_desc.Input("Input").front();
auto W = op_desc.Input("W").front();
auto bias = op_desc.Input("Bias").front();
......
......@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) {
fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
-  fc.Attach(desc, &scope);
+  fc.AttachImpl(desc, &scope);
fc.Run();
for (int i = 0; i < 10 * 20; i++) {
......
......@@ -37,7 +37,8 @@ class MulOpLite : public OpLite {
bool InferShape() const override;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const framework::OpDesc &op_desc,
+                  lite::Scope *scope) override {
auto input = op_desc.Input("X").front();
auto W = op_desc.Input("Y").front();
auto out = op_desc.Output("Out").front();
......
......@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
return true;
}
-bool ReluOp::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
+bool ReluOp::AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>());
param_.output =
......
......@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool InferShape() const override;
-  bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) override;
+  bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) override;
  std::string DebugString() const override { return "relu"; }
......
......@@ -44,7 +44,8 @@ class ScaleOp : public OpLite {
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+  bool AttachImpl(const framework::OpDesc &op_desc,
+                  lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
......