Commit 621d1522 authored by superjomn

make io_copy kernel pick work

Parent 1fb93746
......@@ -42,6 +42,7 @@ class OpDesc {
void CopyFrom(const OpDesc &op_desc);
proto::OpDesc *Proto();
const proto::OpDesc &ReadonlyProto() const { return desc_; }
std::string Type() const { return desc_.type(); }
......
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host )
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
if(LITE_WITH_CUDA)
cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
endif()
......@@ -29,7 +29,7 @@ class Predictor {
public:
Predictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_path,
void Build(const std::string& model_path, const Place& prefer_place,
const std::vector<Place>& valid_places) {
framework::proto::ProgramDesc prog;
LoadModel(model_path, scope_.get(), &prog);
......@@ -38,6 +38,7 @@ class Predictor {
Program program(prog_desc, scope_, valid_places);
Optimizer optimizer;
optimizer.KernelPickPreferPlace(prefer_place);
core::KernelPickFactor factor;
factor.ConsiderTarget();
optimizer.Run(std::move(program), valid_places, factor);
......
......@@ -23,8 +23,21 @@ namespace lite {
TEST(CXXApi, test) {
lite::Predictor predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
});
#endif
predictor.Build("/home/chunwei/project2/models/model2",
{Place{TARGET(kHost), PRECISION(kFloat)}});
Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
......@@ -54,8 +67,15 @@ USE_LITE_OP(fc);
USE_LITE_OP(scale);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_KERNEL(fc, kHost, kFloat, def);
USE_LITE_KERNEL(mul, kHost, kFloat, def);
USE_LITE_KERNEL(scale, kHost, kFloat, def);
USE_LITE_KERNEL(feed, kHost, kFloat, def);
USE_LITE_KERNEL(fetch, kHost, kFloat, def);
USE_LITE_OP(io_copy);
USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
......@@ -27,5 +27,6 @@ cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
add_subdirectory(mir)
......@@ -17,6 +17,13 @@
namespace paddle {
namespace lite {
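// summary() below renders a kernel's identity as op_type:target/precision/
// layout, e.g. "fc:host/float/NCHW" for a float FC kernel on the host
// (assuming the *ToStr helpers print bare names like "host" and "float").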
std::string KernelBase::summary() const {
std::stringstream ss;
ss << op_type() << ":" << TargetToStr(target()) << "/"
<< PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout());
return ss.str();
}
bool ParamTypeRegistry::KeyCmp::operator()(
const ParamTypeRegistry::key_t &a,
const ParamTypeRegistry::key_t &b) const {
......
......@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
......@@ -34,49 +35,100 @@ namespace lite {
// different targets.
class KernelBase {
public:
// type_infer_handler is used to infer an output type by considering the
// input types in the type system.
using type_infer_handler_t = std::function<const Type*(
const std::map<std::string, const Type*>& input_types,
const std::string& out_arg)>;
virtual void Run() = 0;
void SetContext(std::unique_ptr<KernelContext>&& ctx) {
context_ = std::move(ctx);
}
template <typename T>
void SetParam(T param) {
param_.set<T>(param);
}
template <typename P>
P& Param() const {
return param_.get<P>();
}
// This is used in kernels that take 'kAny' places and infer the output
// place. For `ScaleCompute` and `IoCopyCompute`, some fields of their input
// types are declared as 'kAny' in the Place, and the output is also `kAny`;
// but in real execution, when a non-kAny type is taken as input, the
// output's kAny fields can be determined. For example, when `ScaleCompute`
// takes `TensorFp32NCHWTy` as input, its output should also be
// `TensorFp32NCHWTy`. This type inference rule differs from kernel to
// kernel, so we make it a virtual method.
// One can customize this handler to give a kernel a specific type inference
// rule, or leave the default to force the kernel to use the system's
// type-inference rules.
virtual std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() {
return nullptr;
}
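// A minimal sketch of a custom rule (hypothetical override; the input
// argument name "X" is an assumption): propagate the input type to the
// output unchanged.
//
//   std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
//     return std::unique_ptr<type_infer_handler_t>(new type_infer_handler_t(
//         [](const std::map<std::string, const Type*>& input_types,
//            const std::string& out_arg) -> const Type* {
//           return input_types.at("X");  // output type follows the input
//         }));
//   }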
void set_op_type(const std::string& type) { op_type_ = type; }
const std::string& op_type() const { return op_type_; }
void Torch() {}
// Get input declaration type.
const Type* GetInputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveInArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] input argument [" << arg_name << "]"
<< " with key " << GenParamTypeKey();
return type->type;
}
// Get output declaration type.
const Type* GetOutputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveOutArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] output argument [" << arg_name << "]";
return type->type;
}
void set_alias(const std::string& x) {
alias_ = x;
LOG(INFO) << "kernel " << op_type() << " setting alias " << alias();
}
const std::string& alias() const { return alias_; }
virtual Place place() const = 0;
virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0;
const KernelContext* context() const { return context_.get(); }
virtual std::string name() const = 0;
virtual ~KernelBase() = default;
// Short human-readable document.
std::string summary() const;
// Long human-readable document.
virtual std::string doc() const { return ""; }
std::string DebugString() const {
std::string GenParamTypeKey() const {
std::stringstream ss;
ss << op_type() << ":" << TargetToStr(target()) << "/"
<< PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout());
LOG(INFO) << "alias : " << alias_;
ss << op_type() << "/" << alias_;
return ss.str();
}
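// The key above is op_type/alias, e.g. "io_copy/host_to_device" for the CUDA
// host-to-device copy kernel.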
virtual ~KernelBase() = default;
protected:
std::unique_ptr<KernelContext> context_;
mutable operators::param_t param_;
// The corresponding op type.
std::string op_type_;
std::string op_type_{};
std::string alias_{};
};
// Light-weight kernel implementation.
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "memory.h"
#include "paddle/fluid/lite/core/memory.h"
namespace paddle {
namespace framework {
......
......@@ -14,7 +14,7 @@
#pragma once
#include <glog/logging.h>
#include "target_wrapper.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
......@@ -26,9 +26,12 @@ static void* TargetMalloc(TargetType target, size_t size) {
case TargetType::kX86:
data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break;
#ifdef LITE_WITH_CUDA
case TargetType::kCUDA:
data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
data =
TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>::Malloc(size);
break;
#endif // LITE_WITH_CUDA
default:
LOG(FATAL) << "Unknown supported target " << TargetToStr(target);
}
......
......@@ -7,9 +7,12 @@ cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc
io_complement_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite)
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
class ArgumentTypeDisplayPass : public DebugPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
LOG(INFO) << "== Argument types ==";
for (auto& node : graph->mutable_nodes()) {
if (!node.IsArgument()) continue;
auto* type = node.AsArgument().type;
if (type) {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: " << *type;
} else {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: UNK";
}
}
LOG(INFO) << "---------------------";
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(argument_type_display_pass,
paddle::lite::mir::ArgumentTypeDisplayPass);
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
......@@ -20,9 +21,11 @@ namespace lite {
namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
LOG(INFO) << "final program \n" << Visualize(graph.get());
for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) {
auto& instruct = item->AsInstruct();
LOG(INFO) << instruct;
insts_.emplace_back(instruct.op,
std::move(instruct.valid_kernels.front()));
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
......@@ -21,28 +22,161 @@ namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// Start from the inputs of the graph; those should have their places set.
std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& inst = node.AsInstruct();
nodes.push_back(&node);
}
CHECK(!valid_places_.empty());
for (auto& node : nodes) {
if (!node->IsInstruct()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
}
}
// inputs
for (auto* in : node.inlinks) {
// PickIoCopyKernel(graph.get());
LOG(INFO) << "\n" << Visualize(graph.get());
}
void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
Node* in) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
return;
CHECK(inst_node->IsInstruct());
auto& inst = inst_node->AsInstruct();
CHECK(in->IsRoleSet());
CHECK(in->IsArgument());
auto name = in->AsArgument().name;
auto in_arg_name = in->AsArgument().name;
std::string tmp;
CHECK(inst.op_info->GetInputArgname(name, &tmp));
auto type =
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, tmp);
CHECK(type) << "no param type found for " << inst.op_type << ":" << name
<< " " << inst.place;
CHECK(type->type);
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArgument().type);
if (!TypeCompatible(*type->type, *in->AsArgument().type)) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name;
if (!TypeCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name
<< " for kernel " << inst.op->DebugString() << " "
<< *in->AsArgument().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with the kernel's
// declared input type.
AddIoCopyInst(*in->AsArgument().type, *decl_arg_type, in->AsArgument().name,
graph, inst_node, valid_places_);
}
}
void UpdateOpdescInputName(framework::OpDesc* desc,
const std::string& old_arg_name,
const std::string& new_arg_name) {
for (auto& item : *desc->Proto()->mutable_inputs()) {
for (int i = 0; i < item.mutable_arguments()->size(); i++) {
auto* arg = item.mutable_arguments(i);
if (*arg == old_arg_name) {
*arg = new_arg_name;
}
}
}
}
void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
const std::string& var, SSAGraph* graph,
Node* inst_node,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node.
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
auto* io_copy_inst = graph->NewInstructNode();
// create Op and kernels.
auto io_copy_op = LiteOpRegistry::Global().Create("io_copy");
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
framework::OpDesc op_desc;
op_desc.SetType("io_copy");
op_desc.SetInput("Input", {var});
op_desc.SetOutput("Out", {io_copy_output_name});
op_desc.Flush();
io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsInstruct("io_copy", std::move(kernels), io_copy_op);
// Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto& inst = inst_node->AsInstruct();
auto inst_program_desc = inst.op_info()->desc();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(graph->Argument(var), io_copy_inst);
DirectedLink(io_copy_inst, io_copy_output_arg);
DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information
auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc();
UpdateInputTo(&desc_dummy, var, io_copy_output_name);
framework::OpDesc desc_fake(desc_dummy, nullptr);
inst_node->AsInstruct().op->Attach(desc_fake,
inst_node->AsInstruct().op->scope());
std::string tmp;
if (inst_node->AsInstruct().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp;
}
for (auto& kernel : inst_node->AsInstruct().valid_kernels) {
inst_node->AsInstruct().op->AttachKernel(kernel.get());
}
graph->CheckValid();
}
void IoComplementPass::PickIoCopyKernel(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsInstruct() && node.AsInstruct().op_type == "io_copy") {
auto& kernels = node.AsInstruct().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
for (auto& kernel : kernels) {
CHECK_EQ(node.inlinks.size(), 1UL);
CHECK_EQ(node.outlinks.size(), 1UL);
auto* inty = node.inlinks.front()->AsArgument().type;
auto* outy = node.outlinks.front()->AsArgument().type;
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatibleTo(*inty, *in_arg_ty)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
// Both the input and output types match, so remove the other kernels
// directly.
if (out_arg_ty->target() == outy->target()) {
LOG(INFO) << "get a IOCopy kernel";
auto x = std::move(kernel);
kernels.clear();
kernels.emplace_back(std::move(x));
break;
}
}
}
}
}
// Check the compatibility.
}
void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty());
valid_places_ = valid_places;
}
} // namespace mir
......
......@@ -15,18 +15,47 @@
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace mir {
static void UpdateInputTo(framework::proto::OpDesc* desc,
const std::string& from, const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : *item.mutable_arguments()) {
if (input == from) {
LOG(INFO) << "** update input argument from " << from << " to " << to;
input = to;
}
}
}
}
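// For example, after IoComplementPass inserts a copy instruction, the
// consumer's input "x" is rewritten to the io_copy output name
// "x/trans/<node_id>" produced by AddIoCopyInst.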
/*
 * IoComplementPass adds the necessary instructions to transfer or transform
 * data between different places.
 */
class IoComplementPass : public ProgramPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph> &graph) override;
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddIoCopyInst(const Type& from, const Type& to, const std::string& var,
SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
// Pick the right IoCopy kernel considering the input and output Types.
void PickIoCopyKernel(SSAGraph* graph);
const std::vector<Place>& valid_places() const { return valid_places_; }
private:
std::vector<Place> valid_places_;
};
} // namespace mir
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
class IoCopyKernelPickPass : public InstructionPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& inst = node.AsInstruct();
if (inst.op_type != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsInstruct().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArgument().type;
const auto* outy = node.outlinks.front()->AsArgument().type;
LOG(INFO) << "input type " << *inty;
LOG(INFO) << "output type " << *outy;
bool is_found = false;
LOG(INFO) << "kernels size " << kernels.size();
for (auto& kernel : kernels) {
CHECK_EQ(node.inlinks.size(), 1UL);
CHECK_EQ(node.outlinks.size(), 1UL);
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
LOG(INFO) << "checking kernel candidate " << *in_arg_ty << "->"
<< *out_arg_ty;
if (inty->target() == in_arg_ty->target()) {
// Both the input and output types match, so remove the other kernels
// directly.
if (out_arg_ty->target() == outy->target()) {
LOG(INFO) << "get a IOCopy kernel";
auto x = std::move(kernel);
kernels.clear();
kernels.emplace_back(std::move(x));
is_found = true;
break;
}
}
}
CHECK(is_found) << "Can't find an IoCopy kernel for IO: " << *inty << "->"
<< *outy;
}
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
paddle::lite::mir::IoCopyKernelPickPass);
......@@ -34,30 +34,43 @@ class Node {
Node() = default;
enum class Role {
kUnk = -1,
kArgument,
kArgument = 0,
kInstruct,
kNumRoles /*should be last*/
kNumRoles, /*should be last*/
kUnk,
};
struct Instruct {
std::string op_type;
Place place;
// The kernel instances this Instruct contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
std::shared_ptr<OpInfo> op_info;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape
const OpInfo* op_info() {
CHECK(op);
return op->op_info();
}
Place place() const {
CHECK(!valid_kernels.empty());
return valid_kernels.front()->place();
}
KernelBase& picked_kernel() {
CHECK(!valid_kernels.empty());
return *valid_kernels.front();
}
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
os << "Instruct " << other.op_type << " " << other.place();
return os;
}
};
struct Argument {
std::string name;
const Type* type;
const Type* type{};
// Weight is a special kind of argument; it is marked as a weight explicitly
// so that some weight-related optimizations can take place.
bool is_weight{false};
......@@ -71,13 +84,11 @@ class Node {
Instruct& AsInstruct(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op,
const std::shared_ptr<lite::OpInfo>& op_info) {
const std::shared_ptr<OpLite>& op) {
auto& x = AsInstruct();
x.op_type = op_type;
x.op = op;
x.valid_kernels = std::move(kernels);
x.op_info = op_info;
return x;
}
......@@ -100,8 +111,25 @@ class Node {
instruct_.reset(new Instruct);
return *instruct_;
}
friend std::ostream& operator<<(std::ostream& os, Node& other) {
os << static_cast<int>(other.role_) << " ";
if (!other.IsRoleSet()) {
os << "Unk role node";
}
if (other.IsArgument()) {
auto& arg = other.AsArgument();
os << "Argument " << arg.name;
}
if (other.IsInstruct()) {
auto& arg = other.AsInstruct();
os << "Instruct " << arg.op_type;
}
return os;
}
// Check roles.
bool IsRoleSet() const { return role_ == Role::kUnk; }
bool IsRoleSet() const { return role_ != Role::kUnk; }
bool IsInstruct() const { return role_ == Role::kInstruct; }
bool IsArgument() const { return role_ == Role::kArgument; }
......
......@@ -26,3 +26,5 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(io_complement_pass);
USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
......@@ -89,6 +89,144 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
return res;
}
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars) {
LOG(INFO) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArgument(name);
arguments_[name] = &new_node;
}
}
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights) {
LOG(INFO) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArgument(name);
arguments_[name] = &new_node;
}
}
Node *SSAGraph::GraphCreateInstructNode(
const Program &program, const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op);
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
return &node_storage_.back();
}
void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) {
CHECK(node_storage_.empty());
GraphCreateTmpVarNodes(program);
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
for (auto &op : program.ops) {
auto *op_node = GraphCreateInstructNode(program, op, valid_places);
LOG(INFO) << "checking op " << op->op_type_;
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
LOG(INFO) << "input " << name;
CHECK(arg->IsRoleSet());
DirectedLink(arg, op_node);
}
for (const std::string &name : op->op_info()->output_names()) {
if (!arguments_.count(name)) {
NewArgumentNode(name);
}
LOG(INFO) << "output " << name;
auto *arg = arguments_.at(name);
CHECK(arg->IsRoleSet());
DirectedLink(op_node, arg);
}
CHECK(CheckLinksRoleSet());
}
MarkArgumentWeights(program);
CheckValid();
}
mir::Node *SSAGraph::Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
return it->second;
}
std::vector<mir::Node *> SSAGraph::inputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.inlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
std::vector<mir::Node *> SSAGraph::outputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.outlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
mir::Node *SSAGraph::RetrieveArgument(const std::string &arg) {
auto it = arguments_.find(arg);
if (it != arguments_.end()) {
return it->second;
}
return nullptr;
}
bool SSAGraph::CheckNodesRoleSet() {
for (auto &node : mutable_nodes()) {
CHECK_OR_FALSE(node.IsRoleSet());
}
return true;
}
bool SSAGraph::CheckLinksRoleSet() {
for (auto &node : mutable_nodes()) {
CHECK_OR_FALSE(node.IsRoleSet());
if (!node.IsInstruct()) continue;
for (auto *x : node.inlinks) {
CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument());
}
for (auto *x : node.outlinks) {
CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument());
}
}
return true;
}
Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
arguments_[name] = &node_storage_.back();
node_storage_.back().AsArgument(name);
return &node_storage_.back();
}
Node *SSAGraph::NewInstructNode() {
node_storage_.emplace_back();
return &node_storage_.back();
}
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -35,104 +35,44 @@ class SSAGraph : GraphBase {
public:
// @param program: the op program
// @param valid_places: the valid places the user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) {
// create temporary nodes.
for (const auto &name : program.tmp_vars) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument();
arg.name = name;
arguments_[name] = &new_node;
}
void Build(const Program &program, const std::vector<Place> &valid_places);
// create weight nodes.
for (const auto &name : program.weights) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument();
arg.name = name;
arguments_[name] = &new_node;
}
for (auto &op : program.ops) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
op->op_info());
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
new_node.inlinks.push_back(arg);
arg->outlinks.push_back(&new_node);
}
for (const std::string &name : op->op_info()->output_names()) {
if (!arguments_.count(name)) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument(name);
arg.name = name;
arguments_.emplace(name, &new_node);
}
auto *arg = arguments_.at(name);
new_node.outlinks.push_back(arg);
arg->inlinks.push_back(&new_node);
}
}
MarkArgumentWeights(program);
}
mir::Node *Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
return it->second;
}
mir::Node *Argument(const std::string &name);
std::vector<mir::Node *> InstructTopologicalOrder();
// The inputs of the graph.
std::vector<mir::Node *> inputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.inlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
std::vector<mir::Node *> inputs();
// The outputs of the graph.
std::vector<mir::Node *> outputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.outlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
std::vector<mir::Node *> outputs();
const std::list<mir::Node> &nodes() const { return node_storage_; }
std::list<mir::Node> &mutable_nodes() { return node_storage_; }
mir::Node *RetrieveArgument(const std::string &arg) {
auto it = arguments_.find(arg);
if (it != arguments_.end()) {
return it->second;
}
return nullptr;
mir::Node *RetrieveArgument(const std::string &arg);
Node *NewArgumentNode(const std::string &name);
Node *NewInstructNode();
void CheckValid() {
CHECK(CheckBidirectionalConnection());
CHECK(CheckNodesRoleSet());
CHECK(CheckLinksRoleSet());
}
private:
void GraphCreateTmpVarNodes(const Program &program);
void GraphCreateWeightVarNodes(const Program &program);
Node *GraphCreateInstructNode(const Program &program,
const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
// Check the bidirectional connection.
bool CheckBidirectionalConnection();
bool CheckNodesRoleSet();
// Check that the role of every item in inlinks and outlinks is set.
bool CheckLinksRoleSet();
void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) {
......@@ -152,6 +92,48 @@ class SSAGraph : GraphBase {
std::map<std::string, mir::Node *> arguments_;
};
// Remove the directed link a -> b.
static void RemoveDirectedLink(Node *a, Node *b) {
auto it = std::find(b->inlinks.begin(), b->inlinks.end(), a);
if (it != b->inlinks.end()) {
b->inlinks.erase(it);
}
auto it1 = std::find(a->outlinks.begin(), a->outlinks.end(), b);
if (it1 != a->outlinks.end()) {
a->outlinks.erase((it1));
}
}
// Link a -> b.
static void DirectedLink(Node *a, Node *b) {
// Eagerly remove first, to avoid duplicate links.
RemoveDirectedLink(a, b);
a->outlinks.push_back(b);
b->inlinks.push_back(a);
}
static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
// instr -> output argument
if (a->IsInstruct() && b->IsArgument()) {
auto &inst = a->AsInstruct();
auto &output = b->AsArgument();
if (!output.type) {
output.type = inst.picked_kernel().GetOutputDeclType(arg_name);
}
}
// input argument -> instr
if (a->IsArgument() && b->IsInstruct()) {
auto &input = a->AsArgument();
auto &inst = b->AsInstruct();
if (!input.type) {
input.type = inst.picked_kernel().GetInputDeclType(arg_name);
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -37,7 +37,9 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
auto& instruct = node.AsInstruct();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
for (auto&& kernel : instruct.valid_kernels) {
scored.emplace_back(KernelGrade(*kernel), std::move(kernel));
size_t score = KernelGrade(*kernel);
LOG(INFO) << "kernel " << kernel->summary() << " " << score;
scored.emplace_back(score, std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
......@@ -47,7 +49,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second));
instruct.place = instruct.valid_kernels.front()->place();
LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
}
}
......
......@@ -37,6 +37,7 @@ class StaticKernelPickPass : public mir::InstructionPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
void SetPreferPlace(const Place& place) { place_ = place; }
const Place& place() const { return place_; }
const core::KernelPickFactor& kernel_pick_factors() const {
return kernel_pick_factors_;
......@@ -51,16 +52,32 @@ class StaticKernelPickPass : public mir::InstructionPass {
size_t score{};
const int kMax =
std::numeric_limits<core::KernelPickFactor::value_type>::max();
// The more important factor comes first
if (kernel_pick_factors_.IsTargetConsidered() &&
place().target == kernel.target()) {
(place().target == kernel.target() || kernel.target() == TARGET(kAny) ||
place().target == TARGET(kAny))) {
score +=
kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
}
if (kernel_pick_factors_.IsPrecisionConsidered() &&
place().precision == kernel.precision()) {
(place().precision == kernel.precision() ||
kernel.precision() == PRECISION(kAny) ||
place().precision == PRECISION(kAny))) {
score += kMax /
static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
}
if (kernel_pick_factors_.IsDataLayoutConsidered() &&
(place().layout == kernel.layout() ||
kernel.layout() == DATALAYOUT(kAny) ||
place().layout == DATALAYOUT(kAny))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::DataLayoutFirst);
}
LOG(INFO) << "picker tactic " << kernel_pick_factors_;
LOG(INFO) << "kernel place " << kernel.place();
LOG(INFO) << "picker place " << place();
LOG(INFO) << "score " << score;
// The data layout is only scored when its pick factor is set, for the input
// and output arguments might have different data layouts.
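// Worked example (the Factor weights are illustrative, assuming
// TargetFirst=1, PrecisionFirst=2, DataLayoutFirst=4): with prefer place
// (kCUDA, kFloat, kNCHW) and all three factors considered, a
// (kCUDA, kFloat, kNCHW) kernel scores kMax/1 + kMax/2 + kMax/4, while a
// (kHost, kFloat, kNCHW) kernel scores only kMax/2 + kMax/4, so the CUDA
// kernel wins.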
......
......@@ -22,8 +22,8 @@ namespace mir {
void VariablePlaceInferencePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
MarkInputPlace(graph.get());
InferenceArgumentPlace(graph.get());
CheckAllArgumentTypeDetermined(graph.get());
}
} // namespace mir
......
......@@ -31,6 +31,7 @@ class VariablePlaceInferencePass : public DebugPass {
private:
// Mark the place of input arguments.
void MarkInputPlace(SSAGraph* graph) {
CHECK(!graph->inputs().empty()) << "graph's inputs should be set";
for (const auto& v : graph->inputs()) {
// the feed op might be in the inputs
if (v->IsInstruct()) {
......@@ -39,54 +40,60 @@ class VariablePlaceInferencePass : public DebugPass {
}
// auto& arg = v->AsArgument();
// arg.place.target = argument_default_target_;
// LOG(INFO) << "get graph input " << arg.name << " " << *arg.type;
// arg.type.target = argument_default_target_;
// the other place descriptions can't be determined until their first
// usage by some kernel.
}
}
void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsArgument()) {
CHECK(node.AsArgument().type) << "node " << node.AsArgument().name
<< " type not determined";
}
}
}
void InferenceArgumentPlace(SSAGraph* graph) {
LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct();
CHECK(inst.place.is_valid())
<< "kernel's place should be set when loaded";
// The IoCopyOp is a utility operator; it doesn't take part in type inference.
if (inst.op_type == "io_copy") continue;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
for (auto& arg_name : inst.op_info->input_argnames()) {
auto type =
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->input_argument().at(arg_name);
for (auto& arg_name : inst.op_info()->input_argnames()) {
LOG(INFO) << "-- input arg_name " << arg_name;
// Check if each input's place is set; if not, update it with the
// kernel's declaration.
auto type = inst.picked_kernel().GetInputDeclType(arg_name);
auto arg_names = inst.op_info()->input_argument().at(arg_name);
for (auto& arg_name : arg_names) {
LOG(INFO) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.type) continue;
arg_node.type = type->type;
arg_node.type = type;
}
}
for (auto& arg_name : inst.op_info->output_argnames()) {
auto type = ParamTypeRegistry::Global()
.Retrieve<ParamTypeRegistry::IO::kOutput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->output_argument().at(arg_name);
for (auto& arg_name : inst.op_info()->output_argnames()) {
LOG(INFO) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
auto arg_names = inst.op_info()->output_argument().at(arg_name);
// Check if each output's place is set; if not, update it with the
// kernel's declaration.
for (auto& arg_name : arg_names) {
LOG(INFO) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument();
if (arg_node.type) continue;
node->AsArgument().type = type->type;
node->AsArgument().type = type;
}
}
}
......
......@@ -27,13 +27,15 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
for (auto place : places) {
auto ks = KernelRegistry::Global().Create(
(kernel_type.empty() ? op_type_ : kernel_type), place.target,
place.precision);
place.precision, place.layout);
for (auto &&it : ks) {
AttachKernel(it.get());
kernels.emplace_back(std::move(it));
}
}
CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
LOG(INFO) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels;
}
......@@ -59,9 +61,10 @@ bool OpLite::Run() {
}
bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
CHECK(!op_info_) << "op_info duplicate build found";
op_info_ = std::make_shared<OpInfo>();
op_info_->Build(opdesc);
CHECK(scope);
scope_ = scope;
op_info_.reset(new OpInfo);  // Force clean the out-of-date information.
op_info_->Build(opdesc.ReadonlyProto());
return AttachImpl(opdesc, scope);
}
......@@ -79,7 +82,8 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return var->GetMutable<lite::Tensor>();
}
bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
bool OpInfo::GetInputArgname(const std::string &value_name,
std::string *out) const {
for (auto &item : input_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
......@@ -89,7 +93,8 @@ bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
}
return false;
}
bool OpInfo::GetOutputArgname(const std::string &value_name, std::string *out) {
bool OpInfo::GetOutputArgname(const std::string &value_name,
std::string *out) const {
for (auto &item : output_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
......
......@@ -81,19 +81,30 @@ class OpLite : public Registry {
// Run this operator.
virtual bool Run();
// Link the external execution environ to internal context.
bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope);
const std::shared_ptr<OpInfo> &op_info() const { return op_info_; }
std::shared_ptr<OpInfo> &mutable_op_info() { return op_info_; }
const OpInfo *op_info() const { return op_info_.get(); }
OpInfo *mutable_op_info() { return op_info_.get(); }
// Human-readable information.
virtual std::string DebugString() const = 0;
const Place &kernel_place() const { return kernel_place_; }
// NOTE This might be discarded.
void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = "");
lite::Scope *scope() { return scope_; }
// Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0;
virtual ~OpLite() = default;
protected:
......@@ -101,9 +112,6 @@ class OpLite : public Registry {
virtual bool AttachImpl(const framework::OpDesc &opdesc,
lite::Scope *scope) = 0;
// Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
......@@ -118,10 +126,6 @@ class OpLite : public Registry {
// some inputs are ready.
void RecordOutputEvents() {}
// Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = "");
const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
......@@ -129,11 +133,12 @@ class OpLite : public Registry {
friend class mir::SSAGraph;
protected:
lite::Scope *scope_{};
std::unique_ptr<KernelBase> kernel_;
std::string op_type_;
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::shared_ptr<OpInfo> op_info_;
std::unique_ptr<OpInfo> op_info_;
};
/*
......@@ -142,22 +147,30 @@ class OpLite : public Registry {
*/
class OpInfo {
public:
void Build(const framework::OpDesc &desc) {
// To avoid the bugs from the legacy framework::OpDesc, we use the protobuf
// message instead.
void Build(const framework::proto::OpDesc &desc) {
ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc);
CollectArguments(desc);
desc_.reset(new framework::proto::OpDesc(desc));
}
const framework::proto::OpDesc &desc() const {
CHECK(desc_) << "desc has't set";
return *desc_;
}
framework::proto::OpDesc *mutable_desc() { return desc_.get(); }
const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; }
const std::map<std::string, std::list<std::string>> &input_argument() {
const std::map<std::string, std::list<std::string>> &input_argument() const {
return input_argument_;
}
const std::map<std::string, std::list<std::string>> &output_argument() {
const std::map<std::string, std::list<std::string>> &output_argument() const {
return output_argument_;
}
bool GetInputArgname(const std::string &value_name, std::string *out);
bool GetOutputArgname(const std::string &value_name, std::string *out);
bool GetInputArgname(const std::string &value_name, std::string *out) const;
bool GetOutputArgname(const std::string &value_name, std::string *out) const;
const std::list<std::string> &input_argnames() const {
return input_argnames_;
......@@ -167,37 +180,37 @@ class OpInfo {
}
private:
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (const auto &x : item.second) {
void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (const auto &x : item.arguments()) {
input_names_.push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (const auto &x : item.second) {
for (const auto &item : opdesc.outputs()) {
for (const auto &x : item.arguments()) {
output_names_.push_back(x);
}
}
}
void CollectInputAndOutputArgnames(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.InputNames()) {
input_argnames_.push_back(item);
void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
input_argnames_.push_back(item.parameter());
}
for (const auto &item : opdesc.OutputNames()) {
output_argnames_.push_back(item);
for (const auto &item : opdesc.outputs()) {
output_argnames_.push_back(item.parameter());
}
}
void CollectArguments(const framework::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) {
for (auto &x : item.second) {
input_argument_[item.first].push_back(x);
void CollectArguments(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.inputs()) {
for (auto &x : item.arguments()) {
input_argument_[item.parameter()].push_back(x);
}
}
for (const auto &item : opdesc.Outputs()) {
for (auto &x : item.second) {
output_argument_[item.first].push_back(x);
for (const auto &item : opdesc.outputs()) {
for (auto &x : item.arguments()) {
output_argument_[item.parameter()].push_back(x);
}
}
}
......@@ -209,6 +222,8 @@ class OpInfo {
std::list<std::string> output_argnames_;
std::map<std::string, std::list<std::string>> input_argument_;
std::map<std::string, std::list<std::string>> output_argument_;
// NOTE too heavy.
std::unique_ptr<framework::proto::OpDesc> desc_;
};
} // namespace lite
......
......@@ -18,13 +18,33 @@ namespace paddle {
namespace lite {
std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
const std::string &op_type, TargetType target, PrecisionType precision) {
const std::string &op_type, TargetType target, PrecisionType precision,
DataLayoutType layout) {
Place place{target, precision, layout};
LOG(INFO) << "creating " << op_type << " kernel for " << place;
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
case DATALAYOUT(kNCHW): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kNCHW)>(op_type); \
case DATALAYOUT(kAny): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kAny)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
CREATE_KERNEL1(target__, kFloat); \
case PRECISION(kInt8): \
CREATE_KERNEL1(target__, kInt8); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
default: \
CHECK(false) << "not supported kernel place yet"; \
CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
}
switch (target) {
......@@ -38,7 +58,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL(kCUDA);
} break;
default:
CHECK(false) << "not supported kernel place";
CHECK(false) << "not supported kernel target " << TargetToStr(target);
}
#undef CREATE_KERNEL
......@@ -46,14 +66,21 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
}
KernelRegistry::KernelRegistry() {
#define INIT_FOR(target__, precision__) \
#define INIT_FOR(target__, precision__, layout__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
PRECISION(precision__), \
DATALAYOUT(layout__)>()] \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, only a handful of target/precision/layout combinations are
// registered.
INIT_FOR(kHost, kFloat);
INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kAny, kAny);
INIT_FOR(kCUDA, kAny, kAny);
#undef INIT_FOR
}
......
......@@ -50,80 +50,108 @@ class OpLiteRegistor : public Registor<OpClass> {
}) {}
};
template <TargetType Target, PrecisionType Precision>
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
using KernelRegistryForTarget =
Factory<OpKernel<Target, Precision>, std::unique_ptr<KernelBase>>;
Factory<OpKernel<Target, Precision, Layout>, std::unique_ptr<KernelBase>>;
class KernelRegistry final {
public:
using any_kernel_registor_t =
variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> * //
variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kAny),
DATALAYOUT(kAny)> * //
>;
KernelRegistry();
static KernelRegistry &Global();
template <TargetType Target, PrecisionType Precision>
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
void Register(const std::string &name,
typename KernelRegistryForTarget<Target, Precision>::creator_t
&&creator) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
registries_[GetKernelOffset<Target, Precision>()]
.template get<kernel_registor_t *>()
->Register(name, std::move(creator));
typename KernelRegistryForTarget<Target, Precision,
Layout>::creator_t &&creator) {
LOG(INFO) << "register for " << TargetToStr(Target) << ":"
<< PrecisionToStr(Precision) << "//"
<< GetKernelOffset<Target, Precision, Layout>();
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()];
varient.template get<kernel_registor_t *>()->Register(name,
std::move(creator));
}
template <TargetType Target, PrecisionType Precision>
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
return registries_[GetKernelOffset<Target, Precision>()]
using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
return registries_[GetKernelOffset<Target, Precision, Layout>()]
.template get<kernel_registor_t *>()
->Creates(op_type);
}
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
TargetType target,
PrecisionType precision);
PrecisionType precision,
DataLayoutType layout);
// Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision>
static constexpr int GetKernelOffset() {
return kNumTargets * static_cast<int>(Target) + static_cast<int>(Precision);
template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
static int GetKernelOffset() {
CHECK_LT(static_cast<int>(Target), static_cast<int>(TARGET(NUM)));
CHECK_LT(static_cast<int>(Precision), static_cast<int>(PRECISION(NUM)));
CHECK_LT(static_cast<int>(Layout), static_cast<int>(DATALAYOUT(NUM)));
return static_cast<int>(Target) * static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Precision) * static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Layout);
}
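// The offset above is a row-major index into registries_. For instance, with
// hypothetical enum values kCUDA=3, kAny=2 (precision), kNCHW=1 and sizes
// PRECISION(NUM)=4, DATALAYOUT(NUM)=3, the slot of (kCUDA, kAny, kNCHW) is
// 3*4*3 + 2*3 + 1 = 43.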
std::string DebugString() const {
std::stringstream ss;
ss << "KernelCreator<host, float>:" << std::endl;
ss << registries_[GetKernelOffset<TARGET(kHost), PRECISION(kFloat)>()]
.get<
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> *>()
ss << registries_[GetKernelOffset<TARGET(kHost), PRECISION(kFloat),
DATALAYOUT(kNCHW)>()]
.get<KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *>()
->DebugString();
ss << std::endl;
return ss.str();
}
private:
mutable std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions>
mutable std::array<any_kernel_registor_t,
static_cast<int>(TARGET(NUM)) *
static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM))>
registries_;
};
template <TargetType target, PrecisionType precision, typename KernelType>
template <TargetType target, PrecisionType precision, DataLayoutType layout,
typename KernelType>
class KernelRegistor : public lite::Registor<KernelType> {
public:
KernelRegistor(const std::string op_type)
: Registor<KernelType>([&] {
KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([=] {
LOG(INFO) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision);
KernelRegistry::Global().Register<target, precision>(
op_type, [&, op_type]() -> std::unique_ptr<KernelType> {
<< TargetToStr(target) << " " << PrecisionToStr(precision)
<< " " << DataLayoutToStr(layout) << " alias " << alias;
KernelRegistry::Global().Register<target, precision, layout>(
op_type, [=]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType);
x->set_op_type(op_type);
x->set_alias(alias);
return x;
});
}) {}
......@@ -151,35 +179,40 @@ class KernelRegistor : public lite::Registor<KernelType> {
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__) \
layout__, alias__) \
op_type__##__##target__##__##precision__##__registor__instance__##alias__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, layout__, \
alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, layout__, \
alias__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
alias__) \
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__, \
KernelClass, alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \
PRECISION(precision__), \
DATALAYOUT(layout__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__)(#op_type__); \
layout__, alias__)(#op_type__, #alias__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
alias__); \
int touch_##op_type__##target__##precision__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
layout__, alias__); \
int touch_##op_type__##target__##precision__##layout__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
.Touch(); \
return 0; \
} \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
alias__) __attribute__((unused)) = \
paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
PRECISION(precision__)>( \
#op_type__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \
extern int touch_##op_type__##target__##precision__##alias__(); \
int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
touch_##op_type__##target__##precision__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__##param_register
layout__, alias__) \
__attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>( \
#op_type__ "/" #alias__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \
extern int touch_##op_type__##target__##precision__##layout__##alias__(); \
int op_type__##target__##precision__##layout__##alias__ \
__attribute__((unused)) = \
touch_##op_type__##target__##precision__##layout__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__##param_register
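// A usage sketch of the layout-aware registration (the kernel class
// paddle::lite::kernels::cuda::MulCompute is an assumption; the alias matches
// USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def) in the test above):
//
//   REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW,
//                        paddle::lite::kernels::cuda::MulCompute, def);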
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace paddle {
......@@ -25,5 +26,33 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
*pass->mutable_kernel_pick_factors() = factor;
}
void Optimizer::RunPasses() {
std::vector<std::string> passes({
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_complement_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
});
for (auto& pass_type : passes) {
LOG(INFO) << ".. running pass " << pass_type;
auto* pass = mir::PassManager::Global().LookUp(pass_type);
CHECK(pass);
if (pass->name() == "io_complement_pass") {
auto* _pass = dynamic_cast<mir::IoComplementPass*>(pass);
_pass->SetValidPlaces(valid_places_);
CHECK(!_pass->valid_places().empty());
_pass->Apply(graph_);
} else {
pass->Apply(graph_);
}
}
// mir::PassManager::Global().Run(graph_);
}
} // namespace lite
} // namespace paddle
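The ordering of this pipeline is the point of the patch: static_kernel_pick_pass chooses a kernel per op, variable_place_inference_pass propagates the picked places onto variables, io_complement_pass inserts io_copy ops wherever a variable's place disagrees with what the consuming kernel expects, and the second inference round plus io_copy_kernel_pick_pass bind kernels to the newly inserted ops. A conceptual sketch (not actual pass output) of the rewrite io_complement_pass performs:

// Before: mul's CUDA kernel would be fed a host tensor.
//   x(kHost) -> mul(kCUDA)
// After: an io_copy op bridges the place mismatch, and io_copy_kernel_pick_pass
// later binds it to the host_to_device kernel registered further below.
//   x(kHost) -> io_copy -> x'(kCUDA) -> mul(kCUDA)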
......@@ -16,8 +16,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
......@@ -33,25 +35,46 @@ class Optimizer {
void Run(Program&& program, const std::vector<Place>& valid_places,
core::KernelPickFactor kernel_pick_factor,
const std::vector<std::string>& passes = {}) {
valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimization found; Run should be called once";
graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places);
SpecifyKernelPickTactic(kernel_pick_factor);
// InitIoComplement();
RunPasses();
exec_scope_ = program.exec_scope;
}
void KernelPickPreferPlace(const Place& place) {
auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
"static_kernel_pick_pass");
CHECK(pass);
pass->SetPreferPlace(place);
}
// Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
LOG(INFO) << "generate program";
std::unique_ptr<Program> res;
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
pass->Apply(graph_);
auto program = pass->GenProgram();
CHECK(exec_scope_);
program->set_exec_scope(exec_scope_);
return program;
}
void InitIoComplement() {
auto* pass = mir::PassManager::Global().LookUp<mir::IoComplementPass>(
"io_complement_pass");
CHECK(pass);
CHECK(!valid_places_.empty());
LOG(INFO) << "valid_places.size " << valid_places_.size();
pass->SetValidPlaces(valid_places_);
}
// Generate C++ code which combines the inference program, model and weights.
void GenCode(const std::string& code_dir);
......@@ -64,13 +87,14 @@ class Optimizer {
void SpecifyKernelPickTactic(core::KernelPickFactor factor);
// Run the default passes registered in the PassManager.
void RunPasses() { mir::PassManager::Global().Run(graph_); }
void RunPasses();
// Specify the passes and run them.
void RunPasses(std::vector<std::string>& passes);
private:
std::unique_ptr<mir::SSAGraph> graph_;
std::vector<Place> valid_places_;
lite::Scope* exec_scope_{};
};
......
......@@ -84,13 +84,10 @@ struct Program {
tmp_vars.push_back("fetch");
for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) {
LOG(INFO) << "get tmp var " << var_desc->Name();
tmp_vars.push_back(var_desc->Name());
auto* var = exec_scope->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
exec_scope->Var(var_desc->Name());
} else {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
LOG(INFO) << "get weight var " << var_desc->Name();
weights.push_back(var_desc->Name());
}
}
......@@ -105,15 +102,19 @@ struct Instruction {
void Run() {
CHECK(op_);
CHECK(kernel_);
LOG(INFO) << "running kernel> " << kernel_->DebugString();
if (UNLIKELY(first_epoch_)) {
first_epoch_ = false;
op_->CheckShape();
CHECK(op_->CheckShape());
}
op_->InferShape();
kernel_->Run();
}
friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
return os;
}
private:
std::shared_ptr<OpLite> op_;
std::unique_ptr<KernelBase> kernel_;
......@@ -125,11 +126,16 @@ struct Instruction {
*/
class RuntimeProgram {
public:
explicit RuntimeProgram(std::vector<Instruction>&& instruction)
: instructions_(std::move(instruction)) {}
explicit RuntimeProgram(std::vector<Instruction>&& insts)
: instructions_(std::move(insts)) {
if (instructions_.empty()) {  // check the member; insts was moved from above
LOG(ERROR) << "no instructions";
}
}
void Run() {
for (auto& inst : instructions_) {
LOG(INFO) << ">> Running kernel: " << inst;
inst.Run();
}
}
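A minimal usage sketch, assuming an Optimizer whose passes have already run:

// Turn the optimized graph into an executable program and drive it; each
// Instruction logs its kernel via the operator<< above before running.
std::unique_ptr<RuntimeProgram> program = optimizer.GenRuntimeProgram();
program->Run();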
......
......@@ -16,6 +16,10 @@
#include <glog/logging.h>
#include <iostream>
#include <sstream>
#ifdef LITE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
namespace paddle {
namespace lite {
......@@ -26,20 +30,20 @@ enum class TargetType : int {
kX86,
kCUDA,
kAny, // any target
kLastAsPlaceHolder,
NUM, // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
kFloat,
kInt8,
kAny, // any precision
kLastAsPlaceHolder,
NUM, // number of fields.
};
enum class DataLayoutType : int {
kUnk = 0,
kNCHW,
kAny, // any data layout
kLastAsPlaceHolder,
NUM, // number of fields.
};
// Some helper macro to get a specific TargetType.
......@@ -50,25 +54,29 @@ enum class DataLayoutType : int {
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
constexpr const int kNumPrecisions =
PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat);
constexpr const int kNumTargets =
TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost);
constexpr const int kNumPrecisions = PRECISION_VAL(NUM);
constexpr const int kNumTargets = TARGET_VAL(NUM);
static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"any"};
static const std::string& TargetToStr(TargetType target) {
return target2string[static_cast<int>(target)];
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
}
static const std::string precision2string[] = {"unk", "float", "int8", "any"};
static const std::string& PrecisionToStr(PrecisionType precision) {
return precision2string[static_cast<int>(precision)];
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
}
static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
static const std::string& DataLayoutToStr(DataLayoutType x) {
return datalayout2string[static_cast<int>(x)];
static const std::string& DataLayoutToStr(DataLayoutType layout) {
auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
}
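The string tables must stay in sync with the enums above; the new CHECK_LT guards turn a stale table or an out-of-range cast into a loud failure rather than a silent out-of-bounds read. For example:

LOG(INFO) << TargetToStr(TARGET(kCUDA));         // "cuda"
LOG(INFO) << PrecisionToStr(PRECISION(kInt8));   // "int8"
LOG(INFO) << DataLayoutToStr(DATALAYOUT(kNCHW)); // "NCHW"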
/*
......@@ -187,5 +195,37 @@ class TargetWrapper<TARGET(kHost)> {
}
};
#ifdef LITE_WITH_CUDA
// This interface should be specified by each kind of target.
template <>
class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
public:
using stream_t = cudaStream_t;
using event_t = cudaEvent_t;
static size_t num_devices() { return 0; }
static size_t maximum_stream() { return 0; }
static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {}
static void CreateEvent(event_t* event) {}
static void DestroyEvent(const event_t& event) {}
static void RecordEvent(const event_t& event) {}
static void SyncEvent(const event_t& event) {}
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size);
static void Free(void* ptr);
static void MemcpySync(void* dst, const void* src, size_t size,
IoDirection dir);
static void MemcpyAsync(void* dst, const void* src, size_t size,
IoDirection dir, const stream_t& stream);
};
#endif // LITE_WITH_CUDA
} // namespace lite
} // namespace paddle
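A usage sketch for the CUDA specialization; the IoDirection member names (HtoD, DtoH) are an assumption, as the enum body is not shown in this patch:

using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
float src[16] = {1.f};
float dst[16] = {};
void* dev = TargetW::Malloc(sizeof(src));  // device buffer
TargetW::MemcpySync(dev, src, sizeof(src), IoDirection::HtoD);
TargetW::MemcpySync(dst, dev, sizeof(dst), IoDirection::DtoH);
TargetW::Free(dev);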
......@@ -87,20 +87,41 @@ const Type* Type::Get<TensorAnyTy>(TargetType target) {
}
}
template <TargetType Target>
const Type* GetTensorFp32NCHWTy() {
static TensorFp32NCHWTy x(Target);
return &x;
}
template <>
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
switch (target) {
case TargetType::kX86:
return Get<false, true, TargetType::kX86, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
case TargetType::kHost:
return Get<false, true, TargetType::kHost, PrecisionType::kFloat,
DataLayoutType::kNCHW>();
case TARGET(kHost):
return GetTensorFp32NCHWTy<TARGET(kHost)>();
case TARGET(kCUDA):
return GetTensorFp32NCHWTy<TARGET(kCUDA)>();
case TARGET(kX86):
return GetTensorFp32NCHWTy<TARGET(kX86)>();
default:
LOG(FATAL) << "unsupported target " << TargetToStr(target);
LOG(FATAL) << "unsupported target Type " << TargetToStr(target);
}
return nullptr;
}
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool is_tensor, Place place) {
using id_t = DataTypeBase::ID;
switch (type_id) {
case id_t::Tensor_Any:
return Type::Get<TensorAnyTy>(place.target);
case id_t::Tensor_Fp32_NCHW:
return Type::Get<TensorFp32NCHWTy>(place.target);
case id_t::TensorList_Any:
return Type::Get<TensorListAnyTy>(place.target);
default:
LOG(FATAL) << "unsupported type";
}
return nullptr;
}
// ------------------------- end GetType specification ------------------------
......
......@@ -131,6 +131,23 @@ class Type : public DataTypeBase {
bool operator==(const Type& other) {
return id_ == other.id() && place_ == other.place();
}
friend std::ostream& operator<<(std::ostream& os, const Type& other) {
if (other.IsUnsupported()) {
os << "<Unsupported>";
return os;
}
if (other.IsVoid()) {
os << "<Void>";
return os;
}
if (other.is_tensor_) {
os << "<Tensor:";
} else {
os << "<";
}
os << TargetToStr(other.target()) << "/"
<< PrecisionToStr(other.precision()) << "/"
<< DataLayoutToStr(other.layout()) << ">";
return os;
}
// Can cast to another type. This is heavily used in MIR to determine whether
// it is possible to add an instruction to transform one type to another.
......@@ -163,29 +180,33 @@ class Type : public DataTypeBase {
};
// -------------------------------- compatible check ---------------------------
static bool TargetCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
a.target() == b.target();
static bool TargetCompatibleTo(const Type& a, const Type& b) {
return a.IsVoid() || //
(a.IsTensor() && b.IsTensor() && (a.target() == b.target() || //
b.target() == TARGET(kAny)));
}
static bool DataLayoutCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
(a.IsTensor() && b.IsTensor() && a.layout() == b.layout());
static bool DataLayoutCompatibleTo(const Type& a, const Type& b) {
return a.IsVoid() || //
(a.IsTensor() && b.IsTensor() && (a.layout() == b.layout() || //
b.layout() == DATALAYOUT(kAny)));
}
static bool PrecisionCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
(a.precision() == b.precision());
static bool PrecisionCompatibleTo(const Type& a, const Type& b) {
return a.IsVoid() || //
(a.IsTensor() && b.IsTensor() && (a.precision() == b.precision() || //
b.precision() == PRECISION(kAny)));
}
static bool DeviceCompatible(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || //
(a.device() == b.device());
static bool DeviceCompatibleTo(const Type& a, const Type& b) {
return a.IsVoid() || //
(a.IsTensor() && b.IsTensor() && (a.device() == b.device()));
}
static bool TypeCompatible(const Type& a, const Type& b) {
return TargetCompatible(a, b) && DataLayoutCompatible(a, b) &&
PrecisionCompatible(a, b) && DeviceCompatible(a, b);
// Tells whether type 'a' can be passed to 'b' directly.
static bool TypeCompatibleTo(const Type& a, const Type& b) {
return TargetCompatibleTo(a, b) && DataLayoutCompatibleTo(a, b) &&
PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b);
}
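The renamed *CompatibleTo helpers are directional: 'a' is the value being produced, 'b' the parameter consuming it, and kAny on the 'b' side acts as a wildcard. A small sketch, assuming TensorAnyTy carries kAny precision and layout:

const Type* concrete = Type::Get<TensorFp32NCHWTy>(TARGET(kHost));
const Type* wildcard = Type::Get<TensorAnyTy>(TARGET(kHost));
CHECK(TypeCompatibleTo(*concrete, *wildcard));   // fp32/NCHW feeds any/any
CHECK(!TypeCompatibleTo(*wildcard, *concrete));  // reverse needs an io/cast op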
// -------------------------------- predefined types ---------------------------
......@@ -230,6 +251,9 @@ class TensorInt64NCHWTy : public Type {
: Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
};
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool is_tensor, Place place);
// ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead, and better to be used in analysis phase.
......@@ -381,13 +405,15 @@ class ParamTypeRegistry {
CHECK(types_.count(key));
}
template <IO io>
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const ParamType* RetrieveInArgument(const Place& place,
const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
return Retrieve<IO::kInput>(place, op_type, arg_name);
}
const ParamType* RetrieveOutArgument(const Place& place,
const std::string& op_type,
const std::string& arg_name) {
return Retrieve<IO::kOutput>(place, op_type, arg_name);
}
static ParamTypeRegistry& Global() {
......@@ -403,6 +429,16 @@ class ParamTypeRegistry {
return os;
}
protected:
template <IO io>
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
}
private:
ParamTypeRegistry() = default;
......
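Given the key format used by the registration macro above (#op_type__ "/" #alias__), a retrieval sketch; the exact "op/alias" key string is inferred from that macro, not stated elsewhere in the patch:

const Place place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)};
const auto* ty = ParamTypeRegistry::Global().RetrieveInArgument(
    place, "fc/def", "Input");
if (!ty) LOG(INFO) << "no ParamType registered for this argument";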
......@@ -43,6 +43,9 @@ bool KernelPickFactor::IsTargetConsidered() const {
bool KernelPickFactor::IsDataLayoutConsidered() const {
return data_ & static_cast<int>(Factor::DataLayoutFirst);
}
bool KernelPickFactor::IsDeviceConsidered() const {
return data_ & static_cast<int>(Factor::DeviceFirst);
}
} // namespace core
} // namespace lite
......
......@@ -14,6 +14,7 @@
#pragma once
#include <stack>
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/utils/all.h"
......@@ -38,6 +39,7 @@ class KernelPickFactor {
bool AnyFactorConsidered() const { return data_; }
KernelPickFactor& ConsiderTarget();
// Prefer a specific target, e.g. prefer CUDA kernels.
KernelPickFactor& ConsiderPrecision();
KernelPickFactor& ConsiderDataLayout();
KernelPickFactor& ConsiderDevice();
......@@ -45,12 +47,29 @@ class KernelPickFactor {
bool IsTargetConsidered() const;
bool IsPrecisionConsidered() const;
bool IsDataLayoutConsidered() const;
bool IsDeviceConsidered() const {
return data_ & static_cast<int>(Factor::DeviceFirst);
bool IsDeviceConsidered() const;
friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) {
std::stack<bool> bits;
auto data = k.data_;
while (data) {
bits.push(data % 2);
data /= 2;
}
int nbits = bits.size();
for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
os << 0;
}
while (!bits.empty()) {
os << bits.top();
bits.pop();
}
return os;
}
private:
unsigned char data_{};
TargetType target_{TARGET(kUnk)};
};
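The stream operator renders data_ as a fixed-width bit string, MSB first, padded to sizeof(unsigned char) * 8 = 8 digits. A sketch, assuming Factor::TargetFirst occupies the lowest bit:

core::KernelPickFactor factor;
factor.ConsiderTarget();
std::ostringstream os;
os << factor;  // e.g. "00000001"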
struct dim2 {
......
......@@ -12,10 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by chunwei on 19-2-23.
//
#include "paddle/fluid/lite/cuda/target_wrapper.h"
#include <glog/logging.h>
......@@ -24,19 +20,14 @@ namespace lite {
using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
template <>
void* TargetW::Malloc(size_t size) {
void* ptr{};
CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
return ptr;
}
template <>
void TargetW::Free(void* ptr) {
CHECK_EQ(cudaSuccess, cudaFree(ptr));
}
void TargetW::Free(void* ptr) { CHECK_EQ(cudaSuccess, cudaFree(ptr)); }
template <>
void TargetW::MemcpySync(void* dst, const void* src, size_t size,
IoDirection dir) {
switch (dir) {
......@@ -55,7 +46,6 @@ void TargetW::MemcpySync(void* dst, const void* src, size_t size,
}
}
template <>
void TargetW::MemcpyAsync(void* dst, const void* src, size_t size,
IoDirection dir, const stream_t& stream) {
switch (dir) {
......
cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite)
nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda)
......@@ -21,7 +21,7 @@ namespace lite {
namespace kernels {
namespace cuda {
using TargetW = TargetWrapper<TARGET(kHost), cudaStream_t, cudaEvent_t>;
using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
// Host to CUDA memory.
void CopyFromHostSync(void* target, const void* source, size_t size) {
......@@ -51,6 +51,25 @@ class IoCopyHostToCudaCompute
auto* data = param.y->mutable_data(target(), param.x->memory_size());
CopyFromHostSync(data, param.x->data<void>(), param.x->memory_size());
}
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
std::unique_ptr<type_infer_handler_t> res(new type_infer_handler_t);
*res = [](const std::map<std::string, const Type*>& inputs,
const std::string& out) -> const Type* {
CHECK(!inputs.empty());
auto* type = inputs.at("Input");
CHECK(type->target() == TARGET(kHost));
auto out_place = type->place();
out_place.target = TARGET(kCUDA);
auto* out_type = LookupType(type->id(), type->IsUnsupported(),
type->IsTensor() /*is_tensor*/, out_place);
return out_type;
};
return res;
}
std::string doc() const override { return "Copy IO from HOST to CUDA"; }
};
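A sketch of how a caller might consume the new handler (the caller side is not part of this patch, and type_infer_handler_t is assumed to be callable like a std::function): feed it the kernel's input types and it returns the matching type relocated to CUDA.

auto handler = kernel.GetTypeInferHandler();
std::map<std::string, const Type*> inputs{
    {"Input", Type::Get<paddle::lite::TensorAnyTy>(TARGET(kHost))}};
const Type* out = (*handler)(inputs, "Out");
CHECK(out->target() == TARGET(kCUDA));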
/*
......@@ -65,6 +84,8 @@ class IoCopyCudaToHostCompute
auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size());
CopyToHostSync(data, param.x->data<void>(), param.x->memory_size());
}
std::string doc() const override { return "Copy IO from CUDA to HOST"; }
};
} // namespace cuda
......@@ -72,7 +93,7 @@ class IoCopyCudaToHostCompute
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny,
paddle::lite::kernels::cuda::IoCopyHostToCudaCompute,
host_to_device)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
......@@ -81,7 +102,7 @@ REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny,
paddle::lite::kernels::cuda::IoCopyCudaToHostCompute,
device_to_host)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
......
......@@ -13,3 +13,22 @@
// limitations under the License.
#include "paddle/fluid/lite/kernels/cuda/mul_compute.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW,
paddle::lite::kernels::cuda::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.Finalize();
......@@ -16,6 +16,7 @@
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/cuda/blas.h"
#include "paddle/fluid/lite/operators/op_params.h"
namespace paddle {
namespace lite {
......@@ -29,11 +30,29 @@ void mul_compute(const lite::cuda::Blas<float>& blas, const T* x, int x_h,
nullptr, out, 0);
}
class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
class MulCompute : public OpKernel<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {}
void Run() override {
CHECK(context_) << "running context should be set first";
auto& context = context_->AsCudaContext();
CHECK(context.blas_fp32) << "blas should be initialized first";
auto& blas = *context.blas_fp32;
const auto& param = Param<operators::MulParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto* x = param.x->data<float>();
int x_h = param.x->dims()[0];
int x_w = param.x->dims()[1];
auto* y = param.y->data<float>();
int y_h = param.y->dims()[0];
int y_w = param.y->dims()[1];
auto* out = param.output->mutable_data<float>(TARGET(kCUDA));
mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out);
}
virtual ~MulCompute() = default;
};
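One note on the dimensions read in Run above: this kernel treats mul as a plain 2-D GEMM, so the usual shape constraint applies.

// Shape sketch: out[x_h, y_w] = x[x_h, x_w] * y[y_h, y_w] requires x_w == y_h;
// the x_num_col_dims/y_num_col_dims flattening is not handled by this kernel yet.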
......
......@@ -51,8 +51,8 @@ void FcCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
def)
REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::FcCompute, def)
.BindInput("Input",
{paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
......
......@@ -20,7 +20,8 @@ namespace lite {
namespace kernels {
namespace host {
class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
class FeedCompute
: public OpKernel<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::FeedParam;
......@@ -38,7 +39,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(feed, kHost, kFloat,
REGISTER_LITE_KERNEL(feed, kHost, kAny, kAny,
paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
......
......@@ -20,7 +20,8 @@ namespace lite {
namespace kernels {
namespace host {
class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
class FetchCompute
: public OpKernel<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::FeedParam;
......@@ -41,7 +42,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(fetch, kHost, kFloat,
REGISTER_LITE_KERNEL(fetch, kHost, kAny, kAny,
paddle::lite::kernels::host::FetchCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
......
......@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul, kHost, kFloat,
REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
......
......@@ -42,6 +42,6 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(relu, kHost, kFloat,
REGISTER_LITE_KERNEL(relu, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::ReluCompute, def)
.Finalize();
......@@ -50,7 +50,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(scale, kHost, kFloat,
REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::ScaleCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
......
......@@ -24,7 +24,10 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE(param_.y);
return true;
}
bool IoCopyOp::InferShape() const { return true; }
bool IoCopyOp::InferShape() const {
param_.y->Resize(param_.x->dims());
return true;
}
bool IoCopyOp::Run() { return OpLite::Run(); }
bool IoCopyOp::AttachImpl(const paddle::framework::OpDesc &opdesc,
paddle::lite::Scope *scope) {
......
......@@ -51,9 +51,6 @@ class MulOpLite : public OpLite {
param_.x_num_col_dims = boost::get<int>(op_desc.GetAttr("x_num_col_dims"));
param_.y_num_col_dims = boost::get<int>(op_desc.GetAttr("y_num_col_dims"));
CHECK(kernel_);
kernel_->SetParam(param_);
return true;
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <iostream>
#include <list>
#include <memory>
......@@ -48,8 +49,6 @@ class Factory {
}
void Register(const std::string& op_type, creator_t&& creator) {
CHECK(!creators_.count(op_type)) << "The op " << op_type
<< " has already registered";
creators_[op_type].emplace_back(std::move(creator));
}
......@@ -58,9 +57,9 @@ class Factory {
}
std::list<item_ptr_t> Creates(const std::string& op_type) const {
auto it = creators_.find(op_type);
CHECK(it != creators_.end()) << "no item called " << op_type;
std::list<item_ptr_t> res;
auto it = creators_.find(op_type);
if (it == creators_.end()) return res;
for (auto& c : it->second) {
res.emplace_back(c());
}
......
......@@ -99,7 +99,7 @@ struct variant {
size_t type() { return type_id; }
void valid() { return (type_id != invalid_type()); }
bool valid() { return (type_id != invalid_type()); }
template <typename T, typename... Args>
void set(Args&&... args) {
......
......@@ -24,6 +24,8 @@ namespace utils {
TEST(varient, test) {
variant<int, float> a;
// The initial state should be invalid.
ASSERT_FALSE(a.valid());
a.set<int>(1);
ASSERT_EQ(a.get<int>(), 1);
a.set<int>(20);
......