提交 621d1522 编写于 作者: S superjomn

make io_copy kernel pick works

上级 1fb93746
...@@ -42,6 +42,7 @@ class OpDesc { ...@@ -42,6 +42,7 @@ class OpDesc {
void CopyFrom(const OpDesc &op_desc); void CopyFrom(const OpDesc &op_desc);
proto::OpDesc *Proto(); proto::OpDesc *Proto();
const proto::OpDesc &ReadonlyProto() const { return desc_; }
std::string Type() const { return desc_.type(); } std::string Type() const { return desc_.type(); }
......
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host) cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host )
cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite) cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
if(LITE_WITH_CUDA)
cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite)
endif()
...@@ -29,7 +29,7 @@ class Predictor { ...@@ -29,7 +29,7 @@ class Predictor {
public: public:
Predictor() { scope_ = std::make_shared<Scope>(); } Predictor() { scope_ = std::make_shared<Scope>(); }
void Build(const std::string& model_path, void Build(const std::string& model_path, const Place& prefer_place,
const std::vector<Place>& valid_places) { const std::vector<Place>& valid_places) {
framework::proto::ProgramDesc prog; framework::proto::ProgramDesc prog;
LoadModel(model_path, scope_.get(), &prog); LoadModel(model_path, scope_.get(), &prog);
...@@ -38,6 +38,7 @@ class Predictor { ...@@ -38,6 +38,7 @@ class Predictor {
Program program(prog_desc, scope_, valid_places); Program program(prog_desc, scope_, valid_places);
Optimizer optimizer; Optimizer optimizer;
optimizer.KernelPickPreferPlace(prefer_place);
core::KernelPickFactor factor; core::KernelPickFactor factor;
factor.ConsiderTarget(); factor.ConsiderTarget();
optimizer.Run(std::move(program), valid_places, factor); optimizer.Run(std::move(program), valid_places, factor);
......
...@@ -23,8 +23,21 @@ namespace lite { ...@@ -23,8 +23,21 @@ namespace lite {
TEST(CXXApi, test) { TEST(CXXApi, test) {
lite::Predictor predictor; lite::Predictor predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
});
#endif
predictor.Build("/home/chunwei/project2/models/model2", predictor.Build("/home/chunwei/project2/models/model2",
{Place{TARGET(kHost), PRECISION(kFloat)}}); Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100}); input_tensor->Resize({100, 100});
...@@ -54,8 +67,15 @@ USE_LITE_OP(fc); ...@@ -54,8 +67,15 @@ USE_LITE_OP(fc);
USE_LITE_OP(scale); USE_LITE_OP(scale);
USE_LITE_OP(feed); USE_LITE_OP(feed);
USE_LITE_OP(fetch); USE_LITE_OP(fetch);
USE_LITE_KERNEL(fc, kHost, kFloat, def); USE_LITE_OP(io_copy);
USE_LITE_KERNEL(mul, kHost, kFloat, def); USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kHost, kFloat, def); USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kHost, kFloat, def); USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def);
USE_LITE_KERNEL(fetch, kHost, kFloat, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
...@@ -27,5 +27,6 @@ cc_test(test_tensor_lite SRCS tensor_test.cc) ...@@ -27,5 +27,6 @@ cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels) cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels)
cc_test(test_type_system SRCS type_system_test.cc DEPS type_system) cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes) cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
add_subdirectory(mir) add_subdirectory(mir)
...@@ -17,6 +17,13 @@ ...@@ -17,6 +17,13 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
std::string KernelBase::summary() const {
std::stringstream ss;
ss << op_type() << ":" << TargetToStr(target()) << "/"
<< PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout());
return ss.str();
}
bool ParamTypeRegistry::KeyCmp::operator()( bool ParamTypeRegistry::KeyCmp::operator()(
const ParamTypeRegistry::key_t &a, const ParamTypeRegistry::key_t &a,
const ParamTypeRegistry::key_t &b) const { const ParamTypeRegistry::key_t &b) const {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <set> #include <set>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
...@@ -34,49 +35,100 @@ namespace lite { ...@@ -34,49 +35,100 @@ namespace lite {
// different targets. // different targets.
class KernelBase { class KernelBase {
public: public:
// type_infer_handler is used to inference a output type by considering the
// input types in the type system.
using type_infer_handler_t = std::function<const Type*(
const std::map<std::string, const Type*>& input_types,
const std::string& out_arg)>;
virtual void Run() = 0; virtual void Run() = 0;
void SetContext(std::unique_ptr<KernelContext>&& ctx) { void SetContext(std::unique_ptr<KernelContext>&& ctx) {
context_ = std::move(ctx); context_ = std::move(ctx);
} }
template <typename T> template <typename T>
void SetParam(T param) { void SetParam(T param) {
param_.set<T>(param); param_.set<T>(param);
} }
template <typename P> template <typename P>
P& Param() const { P& Param() const {
return param_.get<P>(); return param_.get<P>();
} }
// This is used in the kernels that takes 'kAny' places and inference the
// output place. For `ScaleCompute` and `IoCopyCompute`, their input types are
// declared as 'kAny' in some Place field, and the output is also `kAny`, but
// when in real execution, when takes some non-kAny type as input, the
// output's kAny-fields can be determained. For example, when the
// `ScaleCompute` takes `TensorFp32NCHWTy` as input, its output should be also
// `TensorFp32NCHWTy`. This type inference rule is different for each kernel,
// so we make it a virtual method.
// One can custom this handler to make a specific type inference rule for a
// kernel, or leave the default to force the kernel use the system's
// type-inference rules.
virtual std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() {
return nullptr;
}
void set_op_type(const std::string& type) { op_type_ = type; } void set_op_type(const std::string& type) { op_type_ = type; }
const std::string& op_type() const { return op_type_; } const std::string& op_type() const { return op_type_; }
void Torch() {} void Torch() {}
// Get input declaration type.
const Type* GetInputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveInArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] input argument [" << arg_name << "]"
<< " with key " << GenParamTypeKey();
return type->type;
}
// Get output declaration type.
const Type* GetOutputDeclType(const std::string& arg_name) {
CHECK(!op_type_.empty()) << "op_type should be set first";
const auto* type = ParamTypeRegistry::Global().RetrieveOutArgument(
place(), GenParamTypeKey(), arg_name);
CHECK(type) << "no type registered for kernel [" << op_type_
<< "] output argument [" << arg_name << "]";
return type->type;
}
void set_alias(const std::string& x) {
alias_ = x;
LOG(INFO) << "kernel " << op_type() << " setting alias " << alias();
}
const std::string& alias() const { return alias_; }
virtual Place place() const = 0; virtual Place place() const = 0;
virtual TargetType target() const = 0; virtual TargetType target() const = 0;
virtual PrecisionType precision() const = 0; virtual PrecisionType precision() const = 0;
virtual DataLayoutType layout() const = 0; virtual DataLayoutType layout() const = 0;
const KernelContext* context() const { return context_.get(); } const KernelContext* context() const { return context_.get(); }
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual ~KernelBase() = default; // Short human-readable document.
std::string summary() const;
// Long human-readable document.
virtual std::string doc() const { return ""; }
std::string DebugString() const { std::string GenParamTypeKey() const {
std::stringstream ss; std::stringstream ss;
ss << op_type() << ":" << TargetToStr(target()) << "/" LOG(INFO) << "alias : " << alias_;
<< PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout()); ss << op_type() << "/" << alias_;
return ss.str(); return ss.str();
} }
virtual ~KernelBase() = default;
protected: protected:
std::unique_ptr<KernelContext> context_; std::unique_ptr<KernelContext> context_;
mutable operators::param_t param_; mutable operators::param_t param_;
// The corresponding op type. // The corresponding op type.
std::string op_type_; std::string op_type_{};
std::string alias_{};
}; };
// Light-weight kernel implementation. // Light-weight kernel implementation.
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "memory.h" #include "paddle/fluid/lite/core/memory.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include "target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,9 +26,12 @@ static void* TargetMalloc(TargetType target, size_t size) { ...@@ -26,9 +26,12 @@ static void* TargetMalloc(TargetType target, size_t size) {
case TargetType::kX86: case TargetType::kX86:
data = TargetWrapper<TARGET(kHost)>::Malloc(size); data = TargetWrapper<TARGET(kHost)>::Malloc(size);
break; break;
#ifdef LITE_WITH_CUDA
case TargetType::kCUDA: case TargetType::kCUDA:
data = TargetWrapper<TARGET(kCUDA)>::Malloc(size); data =
TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>::Malloc(size);
break; break;
#endif // LITE_WITH_CUDA
default: default:
LOG(FATAL) << "Unknown supported target " << TargetToStr(target); LOG(FATAL) << "Unknown supported target " << TargetToStr(target);
} }
......
...@@ -7,9 +7,12 @@ cc_library(mir_passes ...@@ -7,9 +7,12 @@ cc_library(mir_passes
SRCS static_kernel_pick_pass.cc SRCS static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
io_complement_pass.cc io_complement_pass.cc
io_copy_kernel_pick_pass.cc
graph_visualize_pass.cc graph_visualize_pass.cc
generate_program_pass.cc generate_program_pass.cc
argument_type_display_pass.cc
demo_pass.cc demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite) DEPS mir_pass types_lite)
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes) cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
class ArgumentTypeDisplayPass : public DebugPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
LOG(INFO) << "== Argument types ==";
for (auto& node : graph->mutable_nodes()) {
if (!node.IsArgument()) continue;
auto* type = node.AsArgument().type;
if (type) {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: " << *type;
} else {
LOG(INFO) << "* ARG " << node.AsArgument().name << " type: UNK";
}
}
LOG(INFO) << "---------------------";
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(argument_type_display_pass,
paddle::lite::mir::ArgumentTypeDisplayPass);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/generate_program_pass.h" #include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h" #include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle { namespace paddle {
...@@ -20,9 +21,11 @@ namespace lite { ...@@ -20,9 +21,11 @@ namespace lite {
namespace mir { namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
LOG(INFO) << "final program \n" << Visualize(graph.get());
for (auto& item : graph->InstructTopologicalOrder()) { for (auto& item : graph->InstructTopologicalOrder()) {
if (item->IsInstruct()) { if (item->IsInstruct()) {
auto& instruct = item->AsInstruct(); auto& instruct = item->AsInstruct();
LOG(INFO) << instruct;
insts_.emplace_back(instruct.op, insts_.emplace_back(instruct.op,
std::move(instruct.valid_kernels.front())); std::move(instruct.valid_kernels.front()));
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/io_complement_pass.h" #include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h" #include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle { namespace paddle {
...@@ -21,28 +22,161 @@ namespace mir { ...@@ -21,28 +22,161 @@ namespace mir {
void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set. // Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue; nodes.push_back(&node);
auto& inst = node.AsInstruct(); }
CHECK(!valid_places_.empty());
for (auto& node : nodes) {
if (!node->IsInstruct()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
}
}
// inputs // PickIoCopyKernel(graph.get());
for (auto* in : node.inlinks) {
LOG(INFO) << "\n" << Visualize(graph.get());
}
void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
Node* in) {
// If this input is out of date.
if (inst_node->inlinks.end() ==
std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in))
return;
CHECK(inst_node->IsInstruct());
auto& inst = inst_node->AsInstruct();
CHECK(in->IsRoleSet());
CHECK(in->IsArgument()); CHECK(in->IsArgument());
auto name = in->AsArgument().name; auto in_arg_name = in->AsArgument().name;
std::string tmp; std::string tmp;
CHECK(inst.op_info->GetInputArgname(name, &tmp)); CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto type = auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, tmp);
CHECK(type) << "no param type found for " << inst.op_type << ":" << name
<< " " << inst.place;
CHECK(type->type);
CHECK(in->AsArgument().type); CHECK(in->AsArgument().type);
if (!TypeCompatible(*type->type, *in->AsArgument().type)) { if (!TypeCompatibleTo(*in->AsArgument().type, *decl_arg_type)) {
LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name; LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name
<< " for kernel " << inst.op->DebugString() << " "
<< *in->AsArgument().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArgument().type, *decl_arg_type, in->AsArgument().name,
graph, inst_node, valid_places_);
}
}
void UpdateOpdescInputName(framework::OpDesc* desc,
const std::string& old_arg_name,
const std::string& new_arg_name) {
for (auto& item : *desc->Proto()->mutable_inputs()) {
for (int i = 0; i < item.mutable_arguments()->size(); i++) {
auto* arg = item.mutable_arguments(i);
if (*arg == old_arg_name) {
*arg = new_arg_name;
}
}
}
}
void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
const std::string& var, SSAGraph* graph,
Node* inst_node,
const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Instruct Node.
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
auto* io_copy_inst = graph->NewInstructNode();
// create Op and kernels.
auto io_copy_op = LiteOpRegistry::Global().Create("io_copy");
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsInstruct().op->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
framework::OpDesc op_desc;
op_desc.SetType("io_copy");
op_desc.SetInput("Input", {var});
op_desc.SetOutput("Out", {io_copy_output_name});
op_desc.Flush();
io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsInstruct("io_copy", std::move(kernels), io_copy_op);
// Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto& inst = inst_node->AsInstruct();
auto inst_program_desc = inst.op_info()->desc();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(graph->Argument(var), io_copy_inst);
DirectedLink(io_copy_inst, io_copy_output_arg);
DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information
auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc();
UpdateInputTo(&desc_dummy, var, io_copy_output_name);
framework::OpDesc desc_fake(desc_dummy, nullptr);
inst_node->AsInstruct().op->Attach(desc_fake,
inst_node->AsInstruct().op->scope());
std::string tmp;
if (inst_node->AsInstruct().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp;
}
for (auto& kernel : inst_node->AsInstruct().valid_kernels) {
inst_node->AsInstruct().op->AttachKernel(kernel.get());
}
graph->CheckValid();
}
void IoComplementPass::PickIoCopyKernel(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsInstruct() && node.AsInstruct().op_type == "io_copy") {
auto& kernels = node.AsInstruct().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
for (auto& kernel : kernels) {
CHECK_EQ(node.inlinks.size(), 1UL);
CHECK_EQ(node.outlinks.size(), 1UL);
auto* inty = node.inlinks.front()->AsArgument().type;
auto* outy = node.outlinks.front()->AsArgument().type;
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
if (TypeCompatibleTo(*inty, *in_arg_ty)) {
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
// Both the input and output type matches, remove other kernels
// directly.
if (out_arg_ty->target() == outy->target()) {
LOG(INFO) << "get a IOCopy kernel";
auto x = std::move(kernel);
kernels.clear();
kernels.emplace_back(std::move(x));
break;
} }
} }
} }
}
}
// Check the compatiblity.
}
void IoComplementPass::SetValidPlaces(const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty());
valid_places_ = valid_places;
} }
} // namespace mir } // namespace mir
......
...@@ -15,18 +15,47 @@ ...@@ -15,18 +15,47 @@
#pragma once #pragma once
#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
static void UpdateInputTo(framework::proto::OpDesc* desc,
const std::string& from, const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : *item.mutable_arguments()) {
if (input == from) {
LOG(INFO) << "** update input argument from " << from << " to " << to;
input = to;
}
}
}
}
/* /*
* IoComplementPass complement the necessary instruction to make data * IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places. * transferring or transformation between different places.
*/ */
class IoComplementPass : public ProgramPass { class IoComplementPass : public ProgramPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph> &graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddIoCopyInst(const Type& from, const Type& to, const std::string& var,
SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places);
void SetValidPlaces(const std::vector<Place>& valid_places);
// Pick the right kernel of IoCopy considering the input and output Type.
void PickIoCopyKernel(SSAGraph* graph);
const std::vector<Place>& valid_places() const { return valid_places_; };
private:
std::vector<Place> valid_places_;
}; };
} // namespace mir } // namespace mir
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
class IoCopyKernelPickPass : public InstructionPass {
public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsInstruct()) continue;
auto& inst = node.AsInstruct();
if (inst.op_type != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsInstruct().valid_kernels;
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArgument().type;
const auto* outy = node.outlinks.front()->AsArgument().type;
LOG(INFO) << "input type " << *inty;
LOG(INFO) << "output type " << *outy;
bool is_found = false;
LOG(INFO) << "kernels size " << kernels.size();
for (auto& kernel : kernels) {
CHECK_EQ(node.inlinks.size(), 1UL);
CHECK_EQ(node.outlinks.size(), 1UL);
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
LOG(INFO) << "checking kernel candidate " << *in_arg_ty << "->"
<< *out_arg_ty;
if (inty->target() == in_arg_ty->target()) {
// Both the input and output type matches, remove other kernels
// directly.
if (out_arg_ty->target() == outy->target()) {
LOG(INFO) << "get a IOCopy kernel";
auto x = std::move(kernel);
kernels.clear();
kernels.emplace_back(std::move(x));
is_found = true;
break;
}
}
}
CHECK(is_found) << "Can't find a IoCopy kernel for IO: " << *inty << "->"
<< *outy;
}
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
paddle::lite::mir::IoCopyKernelPickPass);
...@@ -34,30 +34,43 @@ class Node { ...@@ -34,30 +34,43 @@ class Node {
Node() = default; Node() = default;
enum class Role { enum class Role {
kUnk = -1, kArgument = 0,
kArgument,
kInstruct, kInstruct,
kNumRoles /*should be last*/ kNumRoles, /*should be last*/
kUnk,
}; };
struct Instruct { struct Instruct {
std::string op_type; std::string op_type;
Place place;
// The kernel instances this Instruct contains. // The kernel instances this Instruct contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels; std::vector<std::unique_ptr<KernelBase>> valid_kernels;
std::shared_ptr<OpInfo> op_info;
// TODO(Superjomn) make this a shared_ptr for resource safety. // TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape std::shared_ptr<OpLite> op; // we hold op to run InferShape
const OpInfo* op_info() {
CHECK(op);
return op->op_info();
}
Place place() const {
CHECK(!valid_kernels.empty());
return valid_kernels.front()->place();
}
KernelBase& picked_kernel() { KernelBase& picked_kernel() {
CHECK(!valid_kernels.empty()); CHECK(!valid_kernels.empty());
return *valid_kernels.front(); return *valid_kernels.front();
} }
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
os << "Instruct " << other.op_type << " " << other.place();
return os;
}
}; };
struct Argument { struct Argument {
std::string name; std::string name;
const Type* type; const Type* type{};
// Weight is a special kind of argument, it is marked as weight explicitly // Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place. // so that some weight related optimization can take place.
bool is_weight{false}; bool is_weight{false};
...@@ -71,13 +84,11 @@ class Node { ...@@ -71,13 +84,11 @@ class Node {
Instruct& AsInstruct(const std::string& op_type, Instruct& AsInstruct(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels, std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op, const std::shared_ptr<OpLite>& op) {
const std::shared_ptr<lite::OpInfo>& op_info) {
auto& x = AsInstruct(); auto& x = AsInstruct();
x.op_type = op_type; x.op_type = op_type;
x.op = op; x.op = op;
x.valid_kernels = std::move(kernels); x.valid_kernels = std::move(kernels);
x.op_info = op_info;
return x; return x;
} }
...@@ -100,8 +111,25 @@ class Node { ...@@ -100,8 +111,25 @@ class Node {
instruct_.reset(new Instruct); instruct_.reset(new Instruct);
return *instruct_; return *instruct_;
} }
friend std::ostream& operator<<(std::ostream& os, Node& other) {
os << static_cast<int>(other.role_) << " ";
if (!other.IsRoleSet()) {
os << "Unk role node";
}
if (other.IsArgument()) {
auto& arg = other.AsArgument();
os << "Argument " << arg.name;
}
if (other.IsInstruct()) {
auto& arg = other.AsInstruct();
os << "Instruct " << arg.op_type;
}
return os;
}
// Check roles. // Check roles.
bool IsRoleSet() const { return role_ == Role::kUnk; } bool IsRoleSet() const { return role_ != Role::kUnk; }
bool IsInstruct() const { return role_ == Role::kInstruct; } bool IsInstruct() const { return role_ == Role::kInstruct; }
bool IsArgument() const { return role_ == Role::kArgument; } bool IsArgument() const { return role_ == Role::kArgument; }
......
...@@ -26,3 +26,5 @@ USE_MIR_PASS(static_kernel_pick_pass); ...@@ -26,3 +26,5 @@ USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(io_complement_pass); USE_MIR_PASS(io_complement_pass);
USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
...@@ -89,6 +89,144 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() { ...@@ -89,6 +89,144 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
return res; return res;
} }
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars) {
LOG(INFO) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArgument(name);
arguments_[name] = &new_node;
}
}
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights) {
LOG(INFO) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArgument(name);
arguments_[name] = &new_node;
}
}
Node *SSAGraph::GraphCreateInstructNode(
const Program &program, const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op);
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
return &node_storage_.back();
}
void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) {
CHECK(node_storage_.empty());
GraphCreateTmpVarNodes(program);
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
for (auto &op : program.ops) {
auto *op_node = GraphCreateInstructNode(program, op, valid_places);
LOG(INFO) << "checking op " << op->op_type_;
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
LOG(INFO) << "input " << name;
CHECK(arg->IsRoleSet());
DirectedLink(arg, op_node);
}
for (const std::string &name : op->op_info()->output_names()) {
if (!arguments_.count(name)) {
NewArgumentNode(name);
}
LOG(INFO) << "output " << name;
auto *arg = arguments_.at(name);
CHECK(arg->IsRoleSet());
DirectedLink(op_node, arg);
}
CHECK(CheckLinksRoleSet());
}
MarkArgumentWeights(program);
CheckValid();
}
mir::Node *SSAGraph::Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
return it->second;
}
std::vector<mir::Node *> SSAGraph::inputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.inlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
std::vector<mir::Node *> SSAGraph::outputs() {
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.outlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
mir::Node *SSAGraph::RetrieveArgument(const std::string &arg) {
auto it = arguments_.find(arg);
if (it != arguments_.end()) {
return it->second;
}
return nullptr;
}
bool SSAGraph::CheckNodesRoleSet() {
for (auto &node : mutable_nodes()) {
CHECK_OR_FALSE(node.IsRoleSet());
}
return true;
}
bool SSAGraph::CheckLinksRoleSet() {
for (auto &node : mutable_nodes()) {
CHECK_OR_FALSE(node.IsRoleSet());
if (!node.IsInstruct()) continue;
for (auto *x : node.inlinks) {
CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument());
}
for (auto *x : node.outlinks) {
CHECK_OR_FALSE(x->IsRoleSet());
CHECK_OR_FALSE(x->IsArgument());
}
}
return true;
}
Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
arguments_[name] = &node_storage_.back();
node_storage_.back().AsArgument(name);
return &node_storage_.back();
}
Node *SSAGraph::NewInstructNode() {
node_storage_.emplace_back();
return &node_storage_.back();
}
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -35,104 +35,44 @@ class SSAGraph : GraphBase { ...@@ -35,104 +35,44 @@ class SSAGraph : GraphBase {
public: public:
// @param program: the op program // @param program: the op program
// @param valid_places: the valid places user set for the system. // @param valid_places: the valid places user set for the system.
void Build(const Program &program, const std::vector<Place> &valid_places) { void Build(const Program &program, const std::vector<Place> &valid_places);
// create temporary nodes.
for (const auto &name : program.tmp_vars) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument();
arg.name = name;
arguments_[name] = &new_node;
}
// create weight nodes. mir::Node *Argument(const std::string &name);
for (const auto &name : program.weights) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument();
arg.name = name;
arguments_[name] = &new_node;
}
for (auto &op : program.ops) {
node_storage_.emplace_back();
// TODO(Superjomn) remove one valid_places here.
op->SetValidPlaces(valid_places);
auto &new_node = node_storage_.back();
auto kernels = op->CreateKernels(valid_places);
node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
op->op_info());
CHECK(new_node.inlinks.empty()) << "duplicate Build found";
CHECK(new_node.outlinks.empty()) << "duplicate Build found";
// collect inputs and outputs
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
new_node.inlinks.push_back(arg);
arg->outlinks.push_back(&new_node);
}
for (const std::string &name : op->op_info()->output_names()) {
if (!arguments_.count(name)) {
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
auto &arg = new_node.AsArgument(name);
arg.name = name;
arguments_.emplace(name, &new_node);
}
auto *arg = arguments_.at(name);
new_node.outlinks.push_back(arg);
arg->inlinks.push_back(&new_node);
}
}
MarkArgumentWeights(program);
}
mir::Node *Argument(const std::string &name) {
auto it = arguments_.find(name);
CHECK(it != arguments_.end()) << "no argument called " << name;
return it->second;
}
std::vector<mir::Node *> InstructTopologicalOrder(); std::vector<mir::Node *> InstructTopologicalOrder();
// The inputs of the graph. // The inputs of the graph.
std::vector<mir::Node *> inputs() { std::vector<mir::Node *> inputs();
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.inlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
// The outputs of the graph. // The outputs of the graph.
std::vector<mir::Node *> outputs() { std::vector<mir::Node *> outputs();
std::vector<mir::Node *> res;
for (auto &node : node_storage_) {
if (node.outlinks.empty()) {
res.push_back(&node);
}
}
return res;
}
const std::list<mir::Node> &nodes() const { return node_storage_; } const std::list<mir::Node> &nodes() const { return node_storage_; }
std::list<mir::Node> &mutable_nodes() { return node_storage_; } std::list<mir::Node> &mutable_nodes() { return node_storage_; }
mir::Node *RetrieveArgument(const std::string &arg) { mir::Node *RetrieveArgument(const std::string &arg);
auto it = arguments_.find(arg);
if (it != arguments_.end()) { Node *NewArgumentNode(const std::string &name);
return it->second; Node *NewInstructNode();
}
return nullptr; void CheckValid() {
CHECK(CheckBidirectionalConnection());
CHECK(CheckNodesRoleSet());
CHECK(CheckLinksRoleSet());
} }
private: private:
void GraphCreateTmpVarNodes(const Program &program);
void GraphCreateWeightVarNodes(const Program &program);
Node *GraphCreateInstructNode(const Program &program,
const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
// Check the bidirectional connection. // Check the bidirectional connection.
bool CheckBidirectionalConnection(); bool CheckBidirectionalConnection();
bool CheckNodesRoleSet();
// Check all the items's role in inlinks and outlinks is set.
bool CheckLinksRoleSet();
void MarkArgumentWeights(const Program &program) { void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) { for (const auto &name : program.weights) {
...@@ -152,6 +92,48 @@ class SSAGraph : GraphBase { ...@@ -152,6 +92,48 @@ class SSAGraph : GraphBase {
std::map<std::string, mir::Node *> arguments_; std::map<std::string, mir::Node *> arguments_;
}; };
// Remove the link between a -> b.
static void RemoveDirectedLink(Node *a, Node *b) {
auto it = std::find(b->inlinks.begin(), b->inlinks.end(), a);
if (it != b->inlinks.end()) {
b->inlinks.erase(it);
}
auto it1 = std::find(a->outlinks.begin(), a->outlinks.end(), b);
if (it1 != a->outlinks.end()) {
a->outlinks.erase((it1));
}
}
// Link a -> b.
static void DirectedLink(Node *a, Node *b) {
// Eagerly remove first, to avoid duplicate link.
RemoveDirectedLink(a, b);
a->outlinks.push_back(b);
b->inlinks.push_back(a);
}
static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
// instr -> output argument
if (a->IsInstruct() && b->IsArgument()) {
auto &inst = a->AsInstruct();
auto &output = b->AsArgument();
if (!output.type) {
output.type = inst.picked_kernel().GetOutputDeclType(arg_name);
}
}
// input argument -> instr
if (a->IsArgument() && b->IsInstruct()) {
auto &input = a->AsArgument();
auto &inst = b->AsInstruct();
if (!input.type) {
input.type = inst.picked_kernel().GetInputDeclType(arg_name);
}
}
}
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -37,7 +37,9 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -37,7 +37,9 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
auto& instruct = node.AsInstruct(); auto& instruct = node.AsInstruct();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored; std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
for (auto&& kernel : instruct.valid_kernels) { for (auto&& kernel : instruct.valid_kernels) {
scored.emplace_back(KernelGrade(*kernel), std::move(kernel)); size_t score = KernelGrade(*kernel);
LOG(INFO) << "kernel " << kernel->summary() << " " << score;
scored.emplace_back(score, std::move(kernel));
} }
std::sort(scored.begin(), scored.end(), KernelScoreCmp); std::sort(scored.begin(), scored.end(), KernelScoreCmp);
...@@ -47,7 +49,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { ...@@ -47,7 +49,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this. // TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear(); instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second)); instruct.valid_kernels.emplace_back(std::move(scored.front().second));
instruct.place = instruct.valid_kernels.front()->place();
LOG(INFO) << "pick " << instruct.valid_kernels.front()->name(); LOG(INFO) << "pick " << instruct.valid_kernels.front()->name();
} }
} }
......
...@@ -37,6 +37,7 @@ class StaticKernelPickPass : public mir::InstructionPass { ...@@ -37,6 +37,7 @@ class StaticKernelPickPass : public mir::InstructionPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph>& graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
void SetPreferPlace(const Place& place) { place_ = place; }
const Place& place() const { return place_; } const Place& place() const { return place_; }
const core::KernelPickFactor& kernel_pick_factors() const { const core::KernelPickFactor& kernel_pick_factors() const {
return kernel_pick_factors_; return kernel_pick_factors_;
...@@ -51,16 +52,32 @@ class StaticKernelPickPass : public mir::InstructionPass { ...@@ -51,16 +52,32 @@ class StaticKernelPickPass : public mir::InstructionPass {
size_t score{}; size_t score{};
const int kMax = const int kMax =
std::numeric_limits<core::KernelPickFactor::value_type>::max(); std::numeric_limits<core::KernelPickFactor::value_type>::max();
// The more important factor comes first
if (kernel_pick_factors_.IsTargetConsidered() && if (kernel_pick_factors_.IsTargetConsidered() &&
place().target == kernel.target()) { (place().target == kernel.target() || kernel.target() == TARGET(kAny) ||
place().target == TARGET(kAny))) {
score += score +=
kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst); kMax / static_cast<int>(core::KernelPickFactor::Factor::TargetFirst);
} }
if (kernel_pick_factors_.IsPrecisionConsidered() && if (kernel_pick_factors_.IsPrecisionConsidered() &&
place().precision == kernel.precision()) { (place().precision == kernel.precision() ||
kernel.precision() == PRECISION(kAny) ||
place().precision == PRECISION(kAny))) {
score += kMax / score += kMax /
static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst); static_cast<int>(core::KernelPickFactor::Factor::PrecisionFirst);
} }
if (kernel_pick_factors_.IsDataLayoutConsidered() &&
(place().layout == kernel.layout() ||
kernel.layout() == DATALAYOUT(kAny) ||
place().layout == DATALAYOUT(kAny))) {
score += kMax / static_cast<int>(
core::KernelPickFactor::Factor::DataLayoutFirst);
}
LOG(INFO) << "picker tactic " << kernel_pick_factors_;
LOG(INFO) << "kernel place " << kernel.place();
LOG(INFO) << "picker place " << place();
LOG(INFO) << "score " << score;
// The data layout is not considered, for the input and output arguments // The data layout is not considered, for the input and output arguments
// might have different data layout. // might have different data layout.
......
...@@ -22,8 +22,8 @@ namespace mir { ...@@ -22,8 +22,8 @@ namespace mir {
void VariablePlaceInferencePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) { void VariablePlaceInferencePass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
MarkInputPlace(graph.get()); MarkInputPlace(graph.get());
InferenceArgumentPlace(graph.get()); InferenceArgumentPlace(graph.get());
CheckAllArgumentTypeDetermined(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -31,6 +31,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -31,6 +31,7 @@ class VariablePlaceInferencePass : public DebugPass {
private: private:
// Mark the place of input arguments. // Mark the place of input arguments.
void MarkInputPlace(SSAGraph* graph) { void MarkInputPlace(SSAGraph* graph) {
CHECK(!graph->inputs().empty()) << "graph's inputs should be set";
for (const auto& v : graph->inputs()) { for (const auto& v : graph->inputs()) {
// the feed op might in the inputs // the feed op might in the inputs
if (v->IsInstruct()) { if (v->IsInstruct()) {
...@@ -39,54 +40,60 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -39,54 +40,60 @@ class VariablePlaceInferencePass : public DebugPass {
} }
// auto& arg = v->AsArgument(); // auto& arg = v->AsArgument();
// arg.place.target = argument_default_target_; // LOG(INFO) << "get graph input " << arg.name << " " << *arg.type;
// arg.type.target = argument_default_target_;
// the other place description can't be determined yet, until their first // the other place description can't be determined yet, until their first
// usage by some kernel. // usage by some kernel.
} }
} }
void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsArgument()) {
CHECK(node.AsArgument().type) << "node " << node.AsArgument().name
<< " type not determined";
}
}
}
void InferenceArgumentPlace(SSAGraph* graph) { void InferenceArgumentPlace(SSAGraph* graph) {
LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global(); LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global();
for (auto& x : graph->InstructTopologicalOrder()) { for (auto& x : graph->InstructTopologicalOrder()) {
auto& inst = x->AsInstruct(); auto& inst = x->AsInstruct();
CHECK(inst.place.is_valid()) // The IoCopyOp is a tool operator, it won't support the type inference.
<< "kernel's place should be set when loaded"; if (inst.op_type == "io_copy") continue;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs // deal with inputs
for (auto& arg_name : inst.op_info->input_argnames()) { for (auto& arg_name : inst.op_info()->input_argnames()) {
auto type = LOG(INFO) << "-- input arg_name " << arg_name;
ParamTypeRegistry::Global().Retrieve<ParamTypeRegistry::IO::kInput>(
inst.place, inst.op_type, arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->input_argument().at(arg_name);
// check if inputs's place is set, if not set, update them with the // check if inputs's place is set, if not set, update them with the
// kernel's declaration. // kernel's declaration.
auto type = inst.picked_kernel().GetInputDeclType(arg_name);
auto arg_names = inst.op_info()->input_argument().at(arg_name);
for (auto& arg_name : arg_names) { for (auto& arg_name : arg_names) {
LOG(INFO) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name); auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArgument();
if (arg_node.type) continue; if (arg_node.type) continue;
arg_node.type = type->type; arg_node.type = type;
} }
} }
for (auto& arg_name : inst.op_info->output_argnames()) { for (auto& arg_name : inst.op_info()->output_argnames()) {
auto type = ParamTypeRegistry::Global() LOG(INFO) << "-- output arg_name " << arg_name;
.Retrieve<ParamTypeRegistry::IO::kOutput>( auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
inst.place, inst.op_type, arg_name); auto arg_names = inst.op_info()->output_argument().at(arg_name);
CHECK(type) << "no param-type found for " << inst.op_type << ":"
<< arg_name << " " << inst.place.DebugString();
auto arg_names = inst.op_info->output_argument().at(arg_name);
// check if outputs's place is set, if not set, update them with the // check if outputs's place is set, if not set, update them with the
// kernel's declaration. // kernel's declaration.
for (auto& arg_name : arg_names) { for (auto& arg_name : arg_names) {
LOG(INFO) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name); auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph"; CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArgument(); auto& arg_node = node->AsArgument();
if (arg_node.type) continue; if (arg_node.type) continue;
node->AsArgument().type = type->type; node->AsArgument().type = type;
} }
} }
} }
......
...@@ -27,13 +27,15 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -27,13 +27,15 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
for (auto place : places) { for (auto place : places) {
auto ks = KernelRegistry::Global().Create( auto ks = KernelRegistry::Global().Create(
(kernel_type.empty() ? op_type_ : kernel_type), place.target, (kernel_type.empty() ? op_type_ : kernel_type), place.target,
place.precision); place.precision, place.layout);
for (auto &&it : ks) { for (auto &&it : ks) {
AttachKernel(it.get()); AttachKernel(it.get());
kernels.emplace_back(std::move(it)); kernels.emplace_back(std::move(it));
} }
} }
CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
LOG(INFO) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels; return kernels;
} }
...@@ -59,9 +61,10 @@ bool OpLite::Run() { ...@@ -59,9 +61,10 @@ bool OpLite::Run() {
} }
bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
CHECK(!op_info_) << "op_info duplicate build found"; CHECK(scope);
op_info_ = std::make_shared<OpInfo>(); scope_ = scope;
op_info_->Build(opdesc); op_info_.reset(new OpInfo); // Force clean the out-of-date infomation.
op_info_->Build(opdesc.ReadonlyProto());
return AttachImpl(opdesc, scope); return AttachImpl(opdesc, scope);
} }
...@@ -79,7 +82,8 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope, ...@@ -79,7 +82,8 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
return var->GetMutable<lite::Tensor>(); return var->GetMutable<lite::Tensor>();
} }
bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) { bool OpInfo::GetInputArgname(const std::string &value_name,
std::string *out) const {
for (auto &item : input_argument_) { for (auto &item : input_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name); auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) { if (it != item.second.end()) {
...@@ -89,7 +93,8 @@ bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) { ...@@ -89,7 +93,8 @@ bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
} }
return false; return false;
} }
bool OpInfo::GetOutputArgname(const std::string &value_name, std::string *out) { bool OpInfo::GetOutputArgname(const std::string &value_name,
std::string *out) const {
for (auto &item : output_argument_) { for (auto &item : output_argument_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name); auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) { if (it != item.second.end()) {
......
...@@ -81,19 +81,30 @@ class OpLite : public Registry { ...@@ -81,19 +81,30 @@ class OpLite : public Registry {
// Run this operator. // Run this operator.
virtual bool Run(); virtual bool Run();
// Link the external execution environ to internal context.
bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope); bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope);
const std::shared_ptr<OpInfo> &op_info() const { return op_info_; } const OpInfo *op_info() const { return op_info_.get(); }
std::shared_ptr<OpInfo> &mutable_op_info() { return op_info_; } OpInfo *mutable_op_info() { return op_info_.get(); }
// Human-readable information. // Human-readable information.
virtual std::string DebugString() const = 0; virtual std::string DebugString() const = 0;
const Place &kernel_place() const { return kernel_place_; } const Place &kernel_place() const { return kernel_place_; }
// NOTE This might be discarded.
void PickKernel(const std::vector<Place> &valid_places, void PickKernel(const std::vector<Place> &valid_places,
KernelStrategy kernel_strategy = KernelStrategy::kStatic); KernelStrategy kernel_strategy = KernelStrategy::kStatic);
// Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = "");
lite::Scope *scope() { return scope_; }
// Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0;
virtual ~OpLite() = default; virtual ~OpLite() = default;
protected: protected:
...@@ -101,9 +112,6 @@ class OpLite : public Registry { ...@@ -101,9 +112,6 @@ class OpLite : public Registry {
virtual bool AttachImpl(const framework::OpDesc &opdesc, virtual bool AttachImpl(const framework::OpDesc &opdesc,
lite::Scope *scope) = 0; lite::Scope *scope) = 0;
// Assign op param to kernel.
virtual void AttachKernel(KernelBase *kernel) = 0;
// Specify the kernel to run by default. This will specify the value of // Specify the kernel to run by default. This will specify the value of
// `kernel_place_`. // `kernel_place_`.
virtual void StaticPickKernel(const std::vector<Place> &valid_targets) { virtual void StaticPickKernel(const std::vector<Place> &valid_targets) {
...@@ -118,10 +126,6 @@ class OpLite : public Registry { ...@@ -118,10 +126,6 @@ class OpLite : public Registry {
// some inputs are ready. // some inputs are ready.
void RecordOutputEvents() {} void RecordOutputEvents() {}
// Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<Place> &places, const std::string &kernel_type = "");
const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const; const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const; Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const;
...@@ -129,11 +133,12 @@ class OpLite : public Registry { ...@@ -129,11 +133,12 @@ class OpLite : public Registry {
friend class mir::SSAGraph; friend class mir::SSAGraph;
protected: protected:
lite::Scope *scope_{};
std::unique_ptr<KernelBase> kernel_; std::unique_ptr<KernelBase> kernel_;
std::string op_type_; std::string op_type_;
std::vector<Place> valid_places_; std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::shared_ptr<OpInfo> op_info_; std::unique_ptr<OpInfo> op_info_;
}; };
/* /*
...@@ -142,22 +147,30 @@ class OpLite : public Registry { ...@@ -142,22 +147,30 @@ class OpLite : public Registry {
*/ */
class OpInfo { class OpInfo {
public: public:
void Build(const framework::OpDesc &desc) { // To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf
// message instead.
void Build(const framework::proto::OpDesc &desc) {
ExtractInputsAndOutputs(desc); ExtractInputsAndOutputs(desc);
CollectInputAndOutputArgnames(desc); CollectInputAndOutputArgnames(desc);
CollectArguments(desc); CollectArguments(desc);
desc_.reset(new framework::proto::OpDesc(desc));
} }
const framework::proto::OpDesc &desc() const {
CHECK(desc_) << "desc has't set";
return *desc_;
}
framework::proto::OpDesc *mutable_desc() { return desc_.get(); }
const std::list<std::string> &input_names() const { return input_names_; } const std::list<std::string> &input_names() const { return input_names_; }
const std::list<std::string> &output_names() const { return output_names_; } const std::list<std::string> &output_names() const { return output_names_; }
const std::map<std::string, std::list<std::string>> &input_argument() { const std::map<std::string, std::list<std::string>> &input_argument() const {
return input_argument_; return input_argument_;
} }
const std::map<std::string, std::list<std::string>> &output_argument() { const std::map<std::string, std::list<std::string>> &output_argument() const {
return output_argument_; return output_argument_;
} }
bool GetInputArgname(const std::string &value_name, std::string *out); bool GetInputArgname(const std::string &value_name, std::string *out) const;
bool GetOutputArgname(const std::string &value_name, std::string *out); bool GetOutputArgname(const std::string &value_name, std::string *out) const;
const std::list<std::string> &input_argnames() const { const std::list<std::string> &input_argnames() const {
return input_argnames_; return input_argnames_;
...@@ -167,37 +180,37 @@ class OpInfo { ...@@ -167,37 +180,37 @@ class OpInfo {
} }
private: private:
void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) { void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) { for (const auto &item : opdesc.inputs()) {
for (const auto &x : item.second) { for (const auto &x : item.arguments()) {
input_names_.push_back(x); input_names_.push_back(x);
} }
} }
for (const auto &item : opdesc.Outputs()) { for (const auto &item : opdesc.outputs()) {
for (const auto &x : item.second) { for (const auto &x : item.arguments()) {
output_names_.push_back(x); output_names_.push_back(x);
} }
} }
} }
void CollectInputAndOutputArgnames(const framework::OpDesc &opdesc) { void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.InputNames()) { for (const auto &item : opdesc.inputs()) {
input_argnames_.push_back(item); input_argnames_.push_back(item.parameter());
} }
for (const auto &item : opdesc.OutputNames()) { for (const auto &item : opdesc.outputs()) {
output_argnames_.push_back(item); output_argnames_.push_back(item.parameter());
} }
} }
void CollectArguments(const framework::OpDesc &opdesc) { void CollectArguments(const framework::proto::OpDesc &opdesc) {
for (const auto &item : opdesc.Inputs()) { for (const auto &item : opdesc.inputs()) {
for (auto &x : item.second) { for (auto &x : item.arguments()) {
input_argument_[item.first].push_back(x); input_argument_[item.parameter()].push_back(x);
} }
} }
for (const auto &item : opdesc.Outputs()) { for (const auto &item : opdesc.outputs()) {
for (auto &x : item.second) { for (auto &x : item.arguments()) {
output_argument_[item.first].push_back(x); output_argument_[item.parameter()].push_back(x);
} }
} }
} }
...@@ -209,6 +222,8 @@ class OpInfo { ...@@ -209,6 +222,8 @@ class OpInfo {
std::list<std::string> output_argnames_; std::list<std::string> output_argnames_;
std::map<std::string, std::list<std::string>> input_argument_; std::map<std::string, std::list<std::string>> input_argument_;
std::map<std::string, std::list<std::string>> output_argument_; std::map<std::string, std::list<std::string>> output_argument_;
// NOTE too heavy.
std::unique_ptr<framework::proto::OpDesc> desc_;
}; };
} // namespace lite } // namespace lite
......
...@@ -18,13 +18,33 @@ namespace paddle { ...@@ -18,13 +18,33 @@ namespace paddle {
namespace lite { namespace lite {
std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
const std::string &op_type, TargetType target, PrecisionType precision) { const std::string &op_type, TargetType target, PrecisionType precision,
DataLayoutType layout) {
Place place{target, precision, layout};
LOG(INFO) << "creating " << op_type << " kernel for " << place;
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
case DATALAYOUT(kNCHW): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kNCHW)>(op_type); \
case DATALAYOUT(kAny): \
return Create<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(kAny)>(op_type); \
default: \
LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \
}
#define CREATE_KERNEL(target__) \ #define CREATE_KERNEL(target__) \
switch (precision) { \ switch (precision) { \
case PRECISION(kFloat): \ case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \ CREATE_KERNEL1(target__, kFloat); \
case PRECISION(kInt8): \
CREATE_KERNEL1(target__, kInt8); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
default: \ default: \
CHECK(false) << "not supported kernel place yet"; \ CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
} }
switch (target) { switch (target) {
...@@ -38,7 +58,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( ...@@ -38,7 +58,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL(kCUDA); CREATE_KERNEL(kCUDA);
} break; } break;
default: default:
CHECK(false) << "not supported kernel place"; CHECK(false) << "not supported kernel target " << TargetToStr(target);
} }
#undef CREATE_KERNEL #undef CREATE_KERNEL
...@@ -46,14 +66,21 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( ...@@ -46,14 +66,21 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
} }
KernelRegistry::KernelRegistry() { KernelRegistry::KernelRegistry() {
#define INIT_FOR(target__, precision__) \ #define INIT_FOR(target__, precision__, layout__) \
registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \ registries_[KernelRegistry::GetKernelOffset<TARGET(target__), \
PRECISION(precision__)>()] \ PRECISION(precision__), \
.set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__)> \ DATALAYOUT(layout__)>()] \
*>(&KernelRegistryForTarget<TARGET(target__), \ .set<KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
PRECISION(precision__)>::Global()); DATALAYOUT(layout__)> *>( \
&KernelRegistryForTarget<TARGET(target__), PRECISION(precision__), \
DATALAYOUT(layout__)>::Global());
// Currently, just register 2 kernel targets. // Currently, just register 2 kernel targets.
INIT_FOR(kHost, kFloat); INIT_FOR(kCUDA, kFloat, kNCHW);
INIT_FOR(kCUDA, kAny, kNCHW);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kAny, kAny);
INIT_FOR(kCUDA, kAny, kAny);
#undef INIT_FOR #undef INIT_FOR
} }
......
...@@ -50,80 +50,108 @@ class OpLiteRegistor : public Registor<OpClass> { ...@@ -50,80 +50,108 @@ class OpLiteRegistor : public Registor<OpClass> {
}) {} }) {}
}; };
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
using KernelRegistryForTarget = using KernelRegistryForTarget =
Factory<OpKernel<Target, Precision>, std::unique_ptr<KernelBase>>; Factory<OpKernel<Target, Precision, Layout>, std::unique_ptr<KernelBase>>;
class KernelRegistry final { class KernelRegistry final {
public: public:
using any_kernel_registor_t = using any_kernel_registor_t =
variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat)> *, // variant<KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kFloat),
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8)> *, // DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat)> *, // KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kInt8),
KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8)> *, // DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> * // KernelRegistryForTarget<TARGET(kX86), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86), PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kCUDA), PRECISION(kAny),
DATALAYOUT(kAny)> * //
>; >;
KernelRegistry(); KernelRegistry();
static KernelRegistry &Global(); static KernelRegistry &Global();
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
void Register(const std::string &name, void Register(const std::string &name,
typename KernelRegistryForTarget<Target, Precision>::creator_t typename KernelRegistryForTarget<Target, Precision,
&&creator) { Layout>::creator_t &&creator) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>; LOG(INFO) << "register for " << TargetToStr(Target) << ":"
registries_[GetKernelOffset<Target, Precision>()] << PrecisionToStr(Precision) << "//"
.template get<kernel_registor_t *>() << GetKernelOffset<Target, Precision, Layout>();
->Register(name, std::move(creator)); using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()];
varient.template get<kernel_registor_t *>()->Register(name,
std::move(creator));
} }
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) { std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
using kernel_registor_t = KernelRegistryForTarget<Target, Precision>; using kernel_registor_t =
return registries_[GetKernelOffset<Target, Precision>()] KernelRegistryForTarget<Target, Precision, Layout>;
return registries_[GetKernelOffset<Target, Precision, Layout>()]
.template get<kernel_registor_t *>() .template get<kernel_registor_t *>()
->Creates(op_type); ->Creates(op_type);
} }
std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type, std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
TargetType target, TargetType target,
PrecisionType precision); PrecisionType precision,
DataLayoutType layout);
// Get a kernel registry offset in all the registries. // Get a kernel registry offset in all the registries.
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision, DataLayoutType Layout>
static constexpr int GetKernelOffset() { static int GetKernelOffset() {
return kNumTargets * static_cast<int>(Target) + static_cast<int>(Precision); CHECK_LT(static_cast<int>(Target), static_cast<int>(TARGET(NUM)));
CHECK_LT(static_cast<int>(Precision), static_cast<int>(PRECISION(NUM)));
CHECK_LT(static_cast<int>(Layout), static_cast<int>(DATALAYOUT(NUM)));
return static_cast<int>(Target) * static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Precision) * static_cast<int>(DATALAYOUT(NUM)) + //
static_cast<int>(Layout);
} }
std::string DebugString() const { std::string DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << "KernelCreator<host, float>:" << std::endl; ss << "KernelCreator<host, float>:" << std::endl;
ss << registries_[GetKernelOffset<TARGET(kHost), PRECISION(kFloat)>()] ss << registries_[GetKernelOffset<TARGET(kHost), PRECISION(kFloat),
.get< DATALAYOUT(kAny)>()]
KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat)> *>() .get<KernelRegistryForTarget<TARGET(kHost), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *>()
->DebugString(); ->DebugString();
ss << std::endl; ss << std::endl;
return ss.str(); return ss.str();
} }
private: private:
mutable std::array<any_kernel_registor_t, kNumTargets * kNumPrecisions> mutable std::array<any_kernel_registor_t,
static_cast<int>(TARGET(NUM)) *
static_cast<int>(PRECISION(NUM)) *
static_cast<int>(DATALAYOUT(NUM))>
registries_; registries_;
}; };
template <TargetType target, PrecisionType precision, typename KernelType> template <TargetType target, PrecisionType precision, DataLayoutType layout,
typename KernelType>
class KernelRegistor : public lite::Registor<KernelType> { class KernelRegistor : public lite::Registor<KernelType> {
public: public:
KernelRegistor(const std::string op_type) KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([&] { : Registor<KernelType>([=] {
LOG(INFO) << "Register kernel " << op_type << " for " LOG(INFO) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision); << TargetToStr(target) << " " << PrecisionToStr(precision)
KernelRegistry::Global().Register<target, precision>( << " " << DataLayoutToStr(layout) << " alias " << alias;
op_type, [&, op_type]() -> std::unique_ptr<KernelType> { KernelRegistry::Global().Register<target, precision, layout>(
op_type, [=]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType); std::unique_ptr<KernelType> x(new KernelType);
x->set_op_type(op_type); x->set_op_type(op_type);
x->set_alias(alias);
return x; return x;
}); });
}) {} }) {}
...@@ -151,35 +179,40 @@ class KernelRegistor : public lite::Registor<KernelType> { ...@@ -151,35 +179,40 @@ class KernelRegistor : public lite::Registor<KernelType> {
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \ #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__ op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ #define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__) \ layout__, alias__) \
op_type__##__##target__##__##precision__##__registor__instance__##alias__ op_type__##__##target__##__##precision__##__registor__instance__##alias__
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \ #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__) LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \ #define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__, \
alias__) \ KernelClass, alias__) \
static paddle::lite::KernelRegistor<TARGET(target__), \ static paddle::lite::KernelRegistor<TARGET(target__), \
PRECISION(precision__), KernelClass> \ PRECISION(precision__), \
DATALAYOUT(layout__), KernelClass> \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
alias__)(#op_type__); \ layout__, alias__)(#op_type__, #alias__); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \ static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
alias__); \ layout__, alias__); \
int touch_##op_type__##target__##precision__##alias__() { \ int touch_##op_type__##target__##precision__##layout__##alias__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \ LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \
.Touch(); \
return 0; \ return 0; \
} \ } \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \ static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
alias__) __attribute__((unused)) = \ layout__, alias__) \
paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \ __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
PRECISION(precision__)>( \ TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>( \
#op_type__) #op_type__ "/" #alias__)
#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \ #define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \
extern int touch_##op_type__##target__##precision__##alias__(); \ extern int touch_##op_type__##target__##precision__##layout__##alias__(); \
int op_type__##target__##precision__##alias__ __attribute__((unused)) = \ int op_type__##target__##precision__##layout__##alias__ \
touch_##op_type__##target__##precision__##alias__(); __attribute__((unused)) = \
touch_##op_type__##target__##precision__##layout__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__ #define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \ alias__) \
op_type__##target__##precision__##alias__##param_register op_type__##target__##precision__##layout__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \
alias__) \
op_type__##target__##precision__##layout__##alias__##param_register
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h" #include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace paddle { namespace paddle {
...@@ -25,5 +26,33 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) { ...@@ -25,5 +26,33 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
*pass->mutable_kernel_pick_factors() = factor; *pass->mutable_kernel_pick_factors() = factor;
} }
void Optimizer::RunPasses() {
std::vector<std::string> passes({
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_complement_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
});
for (auto& pass_type : passes) {
LOG(INFO) << ".. running pass " << pass_type;
auto* pass = mir::PassManager::Global().LookUp(pass_type);
CHECK(pass);
if (pass->name() == "io_complement_pass") {
auto* _pass = dynamic_cast<mir::IoComplementPass*>(pass);
_pass->SetValidPlaces(valid_places_);
CHECK(!_pass->valid_places().empty());
_pass->Apply(graph_);
} else {
pass->Apply(graph_);
}
}
// mir::PassManager::Global().Run(graph_);
}
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h" #include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/io_complement_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/types.h"
...@@ -33,25 +35,46 @@ class Optimizer { ...@@ -33,25 +35,46 @@ class Optimizer {
void Run(Program&& program, const std::vector<Place>& valid_places, void Run(Program&& program, const std::vector<Place>& valid_places,
core::KernelPickFactor kernel_pick_factor, core::KernelPickFactor kernel_pick_factor,
const std::vector<std::string>& passes = {}) { const std::vector<std::string>& passes = {}) {
valid_places_ = valid_places;
CHECK(!valid_places.empty()) << "At least one valid_place should be set";
CHECK(!graph_) << "duplicate optimize found"; CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph); graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places); graph_->Build(program, valid_places);
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
// InitIoComplement();
RunPasses(); RunPasses();
exec_scope_ = program.exec_scope; exec_scope_ = program.exec_scope;
} }
void KernelPickPreferPlace(const Place& place) {
auto* pass = mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
"static_kernel_pick_pass");
CHECK(pass);
pass->SetPreferPlace(place);
}
// Generate a new program based on the mir graph. // Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() { std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
LOG(INFO) << "generate program";
std::unique_ptr<Program> res; std::unique_ptr<Program> res;
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>( auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass"); "generate_program_pass");
pass->Apply(graph_);
auto program = pass->GenProgram(); auto program = pass->GenProgram();
CHECK(exec_scope_); CHECK(exec_scope_);
program->set_exec_scope(exec_scope_); program->set_exec_scope(exec_scope_);
return program; return program;
} }
void InitIoComplement() {
auto* pass = mir::PassManager::Global().LookUp<mir::IoComplementPass>(
"io_complement_pass");
CHECK(pass);
CHECK(!valid_places_.empty());
LOG(INFO) << "valid_places.size " << valid_places_.size();
pass->SetValidPlaces(valid_places_);
}
// Generate C++ code which combines the inference program, model and weights. // Generate C++ code which combines the inference program, model and weights.
void GenCode(const std::string& code_dir); void GenCode(const std::string& code_dir);
...@@ -64,13 +87,14 @@ class Optimizer { ...@@ -64,13 +87,14 @@ class Optimizer {
void SpecifyKernelPickTactic(core::KernelPickFactor factor); void SpecifyKernelPickTactic(core::KernelPickFactor factor);
// Run the default passes registered in the PassManager. // Run the default passes registered in the PassManager.
void RunPasses() { mir::PassManager::Global().Run(graph_); } void RunPasses();
// Specify the passes and run them. // Specify the passes and run them.
void RunPasses(std::vector<std::string>& passes); void RunPasses(std::vector<std::string>& passes);
private: private:
std::unique_ptr<mir::SSAGraph> graph_; std::unique_ptr<mir::SSAGraph> graph_;
std::vector<Place> valid_places_;
lite::Scope* exec_scope_{}; lite::Scope* exec_scope_{};
}; };
......
...@@ -84,13 +84,10 @@ struct Program { ...@@ -84,13 +84,10 @@ struct Program {
tmp_vars.push_back("fetch"); tmp_vars.push_back("fetch");
for (auto var_desc : program.Block(0).AllVars()) { for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) { if (!var_desc->Persistable()) {
LOG(INFO) << "get tmp var " << var_desc->Name();
tmp_vars.push_back(var_desc->Name()); tmp_vars.push_back(var_desc->Name());
auto* var = exec_scope->Var(var_desc->Name()); exec_scope->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
} else { } else {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue;
LOG(INFO) << "get weight var " << var_desc->Name();
weights.push_back(var_desc->Name()); weights.push_back(var_desc->Name());
} }
} }
...@@ -105,15 +102,19 @@ struct Instruction { ...@@ -105,15 +102,19 @@ struct Instruction {
void Run() { void Run() {
CHECK(op_); CHECK(op_);
CHECK(kernel_); CHECK(kernel_);
LOG(INFO) << "running kernel> " << kernel_->DebugString();
if (UNLIKELY(first_epoch_)) { if (UNLIKELY(first_epoch_)) {
first_epoch_ = false; first_epoch_ = false;
op_->CheckShape(); CHECK(op_->CheckShape());
} }
op_->InferShape(); op_->InferShape();
kernel_->Run(); kernel_->Run();
} }
friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
return os;
}
private: private:
std::shared_ptr<OpLite> op_; std::shared_ptr<OpLite> op_;
std::unique_ptr<KernelBase> kernel_; std::unique_ptr<KernelBase> kernel_;
...@@ -125,11 +126,16 @@ struct Instruction { ...@@ -125,11 +126,16 @@ struct Instruction {
*/ */
class RuntimeProgram { class RuntimeProgram {
public: public:
explicit RuntimeProgram(std::vector<Instruction>&& instruction) explicit RuntimeProgram(std::vector<Instruction>&& insts)
: instructions_(std::move(instruction)) {} : instructions_(std::move(insts)) {
if (insts.empty()) {
LOG(ERROR) << "no instructions";
}
}
void Run() { void Run() {
for (auto& inst : instructions_) { for (auto& inst : instructions_) {
LOG(INFO) << ">> Running kernel: " << inst;
inst.Run(); inst.Run();
} }
} }
......
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#ifdef LITE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -26,20 +30,20 @@ enum class TargetType : int { ...@@ -26,20 +30,20 @@ enum class TargetType : int {
kX86, kX86,
kCUDA, kCUDA,
kAny, // any target kAny, // any target
kLastAsPlaceHolder, NUM, // number of fields.
}; };
enum class PrecisionType : int { enum class PrecisionType : int {
kUnk = 0, kUnk = 0,
kFloat, kFloat,
kInt8, kInt8,
kAny, // any precision kAny, // any precision
kLastAsPlaceHolder, NUM, // number of fields.
}; };
enum class DataLayoutType : int { enum class DataLayoutType : int {
kUnk = 0, kUnk = 0,
kNCHW, kNCHW,
kAny, // any data layout kAny, // any data layout
kLastAsPlaceHolder, NUM, // number of fields.
}; };
// Some helper macro to get a specific TargetType. // Some helper macro to get a specific TargetType.
...@@ -50,25 +54,29 @@ enum class DataLayoutType : int { ...@@ -50,25 +54,29 @@ enum class DataLayoutType : int {
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__)) #define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__ #define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
constexpr const int kNumPrecisions = constexpr const int kNumPrecisions = PRECISION_VAL(NUM);
PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat); constexpr const int kNumTargets = TARGET_VAL(NUM);
constexpr const int kNumTargets =
TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost);
static const std::string target2string[] = {"unk", "host", "x86", "cuda", static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"any"}; "any"};
static const std::string& TargetToStr(TargetType target) { static const std::string& TargetToStr(TargetType target) {
return target2string[static_cast<int>(target)]; auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
} }
static const std::string precision2string[] = {"unk", "float", "int8", "any"}; static const std::string precision2string[] = {"unk", "float", "int8", "any"};
static const std::string& PrecisionToStr(PrecisionType precision) { static const std::string& PrecisionToStr(PrecisionType precision) {
return precision2string[static_cast<int>(precision)]; auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
} }
static const std::string datalayout2string[] = {"unk", "NCHW", "any"}; static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
static const std::string& DataLayoutToStr(DataLayoutType x) { static const std::string& DataLayoutToStr(DataLayoutType layout) {
return datalayout2string[static_cast<int>(x)]; auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x];
} }
/* /*
...@@ -187,5 +195,37 @@ class TargetWrapper<TARGET(kHost)> { ...@@ -187,5 +195,37 @@ class TargetWrapper<TARGET(kHost)> {
} }
}; };
#ifdef LITE_WITH_CUDA
// This interface should be specified by each kind of target.
template <>
class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
public:
using stream_t = cudaStream_t;
using event_t = cudaEvent_t;
static size_t num_devices() { return 0; }
static size_t maximum_stream() { return 0; }
static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {}
static void CreateEvent(event_t* event) {}
static void DestroyEvent(const event_t& event) {}
static void RecordEvent(const event_t& event) {}
static void SyncEvent(const event_t& event) {}
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size);
static void Free(void* ptr);
static void MemcpySync(void* dst, const void* src, size_t size,
IoDirection dir);
static void MemcpyAsync(void* dst, const void* src, size_t size,
IoDirection dir, const stream_t& stream);
};
#endif // LITE_WITH_CUDA
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -87,20 +87,41 @@ const Type* Type::Get<TensorAnyTy>(TargetType target) { ...@@ -87,20 +87,41 @@ const Type* Type::Get<TensorAnyTy>(TargetType target) {
} }
} }
template <TargetType Target>
const Type* GetTensorFp32NCHWTy() {
static TensorFp32NCHWTy x(Target);
return &x;
}
template <> template <>
const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) { const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
switch (target) { switch (target) {
case TargetType::kX86: case TARGET(kHost):
return Get<false, true, TargetType::kX86, PrecisionType::kFloat, return GetTensorFp32NCHWTy<TARGET(kHost)>();
DataLayoutType::kNCHW>(); case TARGET(kCUDA):
case TargetType::kHost: return GetTensorFp32NCHWTy<TARGET(kCUDA)>();
return Get<false, true, TargetType::kHost, PrecisionType::kFloat, case TARGET(kX86):
DataLayoutType::kNCHW>(); return GetTensorFp32NCHWTy<TARGET(kX86)>();
default: default:
LOG(FATAL) << "unsupported target " << TargetToStr(target); LOG(FATAL) << "unsupported target Type " << TargetToStr(target);
}
return nullptr; return nullptr;
}
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool is_tensor, Place place) {
using id_t = DataTypeBase::ID;
switch (type_id) {
case id_t::Tensor_Any:
return Type::Get<TensorAnyTy>(place.target);
case id_t::Tensor_Fp32_NCHW:
return Type::Get<TensorFp32NCHWTy>(place.target);
case id_t::TensorList_Any:
return Type::Get<TensorListAnyTy>(place.target);
default:
LOG(FATAL) << "unsupported type";
} }
return nullptr;
} }
// ------------------------- end GetType specification ------------------------ // ------------------------- end GetType specification ------------------------
......
...@@ -131,6 +131,23 @@ class Type : public DataTypeBase { ...@@ -131,6 +131,23 @@ class Type : public DataTypeBase {
bool operator==(const Type& other) { bool operator==(const Type& other) {
return id_ == other.id() && place_ == other.place(); return id_ == other.id() && place_ == other.place();
} }
friend std::ostream& operator<<(std::ostream& os, const Type& other) {
if (other.IsUnsupported()) {
os << "<Unsupported>";
return os;
}
if (other.IsVoid()) {
os << "<Void>";
return os;
}
if (other.is_tensor_) {
os << "<Tensor:";
}
os << TargetToStr(other.target()) << "/"
<< PrecisionToStr(other.precision()) << "/"
<< DataLayoutToStr(other.layout()) << ">";
return os;
}
// Can cast to another type. This is heavily used in MIR, by determine whether // Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another. // is is possible to add a instruction to transform a type to another.
...@@ -163,29 +180,33 @@ class Type : public DataTypeBase { ...@@ -163,29 +180,33 @@ class Type : public DataTypeBase {
}; };
// -------------------------------- compatible check --------------------------- // -------------------------------- compatible check ---------------------------
static bool TargetCompatible(const Type& a, const Type& b) { static bool TargetCompatibleTo(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || // return a.IsVoid() || //
a.target() == b.target(); (a.IsTensor() && b.IsTensor() && (a.target() == b.target() || //
b.target() == TARGET(kAny)));
} }
static bool DataLayoutCompatible(const Type& a, const Type& b) { static bool DataLayoutCompatibleTo(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || // return a.IsVoid() || //
(a.IsTensor() && b.IsTensor() && a.layout() == b.layout()); (a.IsTensor() && b.IsTensor() && (a.layout() == b.layout() || //
b.layout() == DATALAYOUT(kAny)));
} }
static bool PrecisionCompatible(const Type& a, const Type& b) { static bool PrecisionCompatibleTo(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || // return a.IsVoid() || //
(a.precision() == b.precision()); (a.IsTensor() && b.IsTensor() && (a.precision() == b.precision() || //
b.precision() == PRECISION(kAny)));
} }
static bool DeviceCompatible(const Type& a, const Type& b) { static bool DeviceCompatibleTo(const Type& a, const Type& b) {
return (a.IsVoid() || b.IsVoid()) || // return a.IsVoid() || //
(a.device() == b.device()); (a.IsTensor() && b.IsTensor() && (a.device() == b.device()));
} }
static bool TypeCompatible(const Type& a, const Type& b) { // Can type 'a' be passed to 'b' directly.
return TargetCompatible(a, b) && DataLayoutCompatible(a, b) && static bool TypeCompatibleTo(const Type& a, const Type& b) {
PrecisionCompatible(a, b) && DeviceCompatible(a, b); return TargetCompatibleTo(a, b) && DataLayoutCompatibleTo(a, b) &&
PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b);
} }
// -------------------------------- predefined types --------------------------- // -------------------------------- predefined types ---------------------------
...@@ -230,6 +251,9 @@ class TensorInt64NCHWTy : public Type { ...@@ -230,6 +251,9 @@ class TensorInt64NCHWTy : public Type {
: Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/, : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/,
target, PrecisionType::kInt8, DataLayoutType::kNCHW) {} target, PrecisionType::kInt8, DataLayoutType::kNCHW) {}
}; };
const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool is_tensor, Place place);
// ------------------------- end predefined types --------------------------- // ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead, and better to be used in analysis phase. // NOTE TypeSystem has some overhead, and better to be used in analysis phase.
...@@ -381,13 +405,15 @@ class ParamTypeRegistry { ...@@ -381,13 +405,15 @@ class ParamTypeRegistry {
CHECK(types_.count(key)); CHECK(types_.count(key));
} }
template <IO io> const ParamType* RetrieveInArgument(const Place& place,
const ParamType* Retrieve(const Place& place, const std::string& op_type, const std::string& op_type,
const std::string& arg_name) { const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name}; return Retrieve<IO::kInput>(place, op_type, arg_name);
auto it = types_.find(key); }
if (it == types_.end()) return nullptr; const ParamType* RetrieveOutArgument(const Place& place,
return &it->second; const std::string& op_type,
const std::string& arg_name) {
return Retrieve<IO::kOutput>(place, op_type, arg_name);
} }
static ParamTypeRegistry& Global() { static ParamTypeRegistry& Global() {
...@@ -403,6 +429,16 @@ class ParamTypeRegistry { ...@@ -403,6 +429,16 @@ class ParamTypeRegistry {
return os; return os;
} }
protected:
template <IO io>
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
}
private: private:
ParamTypeRegistry() = default; ParamTypeRegistry() = default;
......
...@@ -43,6 +43,9 @@ bool KernelPickFactor::IsTargetConsidered() const { ...@@ -43,6 +43,9 @@ bool KernelPickFactor::IsTargetConsidered() const {
bool KernelPickFactor::IsDataLayoutConsidered() const { bool KernelPickFactor::IsDataLayoutConsidered() const {
return data_ & static_cast<int>(Factor::DataLayoutFirst); return data_ & static_cast<int>(Factor::DataLayoutFirst);
} }
bool KernelPickFactor::IsDeviceConsidered() const {
return data_ & static_cast<int>(Factor::DeviceFirst);
}
} // namespace core } // namespace core
} // namespace lite } // namespace lite
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <stack>
#include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
...@@ -38,6 +39,7 @@ class KernelPickFactor { ...@@ -38,6 +39,7 @@ class KernelPickFactor {
bool AnyFactorConsidered() const { return data_; } bool AnyFactorConsidered() const { return data_; }
KernelPickFactor& ConsiderTarget(); KernelPickFactor& ConsiderTarget();
// Perfer a specific target, e.g. prefer CUDA kernels.
KernelPickFactor& ConsiderPrecision(); KernelPickFactor& ConsiderPrecision();
KernelPickFactor& ConsiderDataLayout(); KernelPickFactor& ConsiderDataLayout();
KernelPickFactor& ConsiderDevice(); KernelPickFactor& ConsiderDevice();
...@@ -45,12 +47,29 @@ class KernelPickFactor { ...@@ -45,12 +47,29 @@ class KernelPickFactor {
bool IsTargetConsidered() const; bool IsTargetConsidered() const;
bool IsPrecisionConsidered() const; bool IsPrecisionConsidered() const;
bool IsDataLayoutConsidered() const; bool IsDataLayoutConsidered() const;
bool IsDeviceConsidered() const { bool IsDeviceConsidered() const;
return data_ & static_cast<int>(Factor::DeviceFirst);
friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) {
std::stack<bool> bits;
auto data = k.data_;
while (data) {
bits.push(data % 2);
data /= 2;
}
int nbits = bits.size();
for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) {
os << 0;
}
while (!bits.empty()) {
os << bits.top();
bits.pop();
}
return os;
} }
private: private:
unsigned char data_{}; unsigned char data_{};
TargetType target_{TARGET(kUnk)};
}; };
struct dim2 { struct dim2 {
......
...@@ -12,10 +12,6 @@ ...@@ -12,10 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//
// Created by chunwei on 19-2-23.
//
#include "paddle/fluid/lite/cuda/target_wrapper.h" #include "paddle/fluid/lite/cuda/target_wrapper.h"
#include <glog/logging.h> #include <glog/logging.h>
...@@ -24,19 +20,14 @@ namespace lite { ...@@ -24,19 +20,14 @@ namespace lite {
using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>; using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
template <>
void* TargetW::Malloc(size_t size) { void* TargetW::Malloc(size_t size) {
void* ptr{}; void* ptr{};
CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size)); CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
return ptr; return ptr;
} }
template <> void TargetW::Free(void* ptr) { CHECK_EQ(cudaSuccess, cudaFree(ptr)); }
void TargetW::Free(void* ptr) {
CHECK_EQ(cudaSuccess, cudaFree(ptr));
}
template <>
void TargetW::MemcpySync(void* dst, const void* src, size_t size, void TargetW::MemcpySync(void* dst, const void* src, size_t size,
IoDirection dir) { IoDirection dir) {
switch (dir) { switch (dir) {
...@@ -55,7 +46,6 @@ void TargetW::MemcpySync(void* dst, const void* src, size_t size, ...@@ -55,7 +46,6 @@ void TargetW::MemcpySync(void* dst, const void* src, size_t size,
} }
} }
template <>
void TargetW::MemcpyAsync(void* dst, const void* src, size_t size, void TargetW::MemcpyAsync(void* dst, const void* src, size_t size,
IoDirection dir, const stream_t& stream) { IoDirection dir, const stream_t& stream) {
switch (dir) { switch (dir) {
......
cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite) nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite) cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite)
nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda)
...@@ -21,7 +21,7 @@ namespace lite { ...@@ -21,7 +21,7 @@ namespace lite {
namespace kernels { namespace kernels {
namespace cuda { namespace cuda {
using TargetW = TargetWrapper<TARGET(kHost), cudaStream_t, cudaEvent_t>; using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
// Host to CUDA memory. // Host to CUDA memory.
void CopyFromHostSync(void* target, const void* source, size_t size) { void CopyFromHostSync(void* target, const void* source, size_t size) {
...@@ -51,6 +51,25 @@ class IoCopyHostToCudaCompute ...@@ -51,6 +51,25 @@ class IoCopyHostToCudaCompute
auto* data = param.y->mutable_data(target(), param.x->memory_size()); auto* data = param.y->mutable_data(target(), param.x->memory_size());
CopyFromHostSync(data, param.x->data<void>(), param.x->memory_size()); CopyFromHostSync(data, param.x->data<void>(), param.x->memory_size());
} }
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
std::unique_ptr<type_infer_handler_t> res(new type_infer_handler_t);
*res = [](const std::map<std::string, const Type*>& inputs,
const std::string& out) -> const Type* {
CHECK(!inputs.empty());
auto* type = inputs.at("Input");
CHECK(type->target() == TARGET(kHost));
auto out_place = type->place();
out_place.target = TARGET(kCUDA);
auto* out_type = LookupType(type->id(), type->IsUnsupported(),
type->IsUnsupported(), out_place);
return out_type;
};
return res;
}
std::string doc() const override { return "Copy IO from HOST to CUDA"; }
}; };
/* /*
...@@ -65,6 +84,8 @@ class IoCopyCudaToHostCompute ...@@ -65,6 +84,8 @@ class IoCopyCudaToHostCompute
auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size()); auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size());
CopyToHostSync(data, param.x->data<void>(), param.x->memory_size()); CopyToHostSync(data, param.x->data<void>(), param.x->memory_size());
} }
std::string doc() const override { return "Copy IO from CUDA to HOST"; }
}; };
} // namespace cuda } // namespace cuda
...@@ -72,7 +93,7 @@ class IoCopyCudaToHostCompute ...@@ -72,7 +93,7 @@ class IoCopyCudaToHostCompute
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny,
paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, paddle::lite::kernels::cuda::IoCopyHostToCudaCompute,
host_to_device) host_to_device)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
...@@ -81,7 +102,7 @@ REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, ...@@ -81,7 +102,7 @@ REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny,
TARGET(kCUDA))}) TARGET(kCUDA))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny,
paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, paddle::lite::kernels::cuda::IoCopyCudaToHostCompute,
device_to_host) device_to_host)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
......
...@@ -13,3 +13,22 @@ ...@@ -13,3 +13,22 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/kernels/cuda/mul_compute.h" #include "paddle/fluid/lite/kernels/cuda/mul_compute.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW,
paddle::lite::kernels::cuda::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.Finalize();
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/cuda/blas.h" #include "paddle/fluid/lite/cuda/blas.h"
#include "paddle/fluid/lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -29,11 +30,29 @@ void mul_compute(const lite::cuda::Blas<float>& blas, const T* x, int x_h, ...@@ -29,11 +30,29 @@ void mul_compute(const lite::cuda::Blas<float>& blas, const T* x, int x_h,
nullptr, out, 0); nullptr, out, 0);
} }
class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { class MulCompute : public OpKernel<TARGET(kCUDA), PRECISION(kFloat)> {
public: public:
using param_t = operators::MulParam; using param_t = operators::MulParam;
void Run() override {} void Run() override {
CHECK(context_) << "running context should be set first";
auto& context = context_->AsCudaContext();
CHECK(context.blas_fp32) << "blas should init first";
auto& blas = *context.blas_fp32;
const auto& param = Param<operators::MulParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto* x = param.x->data<float>();
int x_h = param.x->dims()[0];
int x_w = param.x->dims()[1];
auto* y = param.y->data<float>();
int y_h = param.y->dims()[0];
int y_w = param.y->dims()[1];
auto* out = param.output->mutable_data<float>(TARGET(kCUDA));
mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out);
}
virtual ~MulCompute() = default; virtual ~MulCompute() = default;
}; };
......
...@@ -51,8 +51,8 @@ void FcCompute::Run() { ...@@ -51,8 +51,8 @@ void FcCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute, REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW,
def) paddle::lite::kernels::host::FcCompute, def)
.BindInput("Input", .BindInput("Input",
{paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))}) TARGET(kHost))})
......
...@@ -20,7 +20,8 @@ namespace lite { ...@@ -20,7 +20,8 @@ namespace lite {
namespace kernels { namespace kernels {
namespace host { namespace host {
class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { class FeedCompute
: public OpKernel<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public: public:
using param_t = operators::FeedParam; using param_t = operators::FeedParam;
...@@ -38,7 +39,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -38,7 +39,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(feed, kHost, kFloat, REGISTER_LITE_KERNEL(feed, kHost, kAny, kAny,
paddle::lite::kernels::host::FeedCompute, def) paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))}) TARGET(kHost))})
......
...@@ -20,7 +20,8 @@ namespace lite { ...@@ -20,7 +20,8 @@ namespace lite {
namespace kernels { namespace kernels {
namespace host { namespace host {
class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { class FetchCompute
: public OpKernel<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public: public:
using param_t = operators::FeedParam; using param_t = operators::FeedParam;
...@@ -41,7 +42,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -41,7 +42,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(fetch, kHost, kFloat, REGISTER_LITE_KERNEL(fetch, kHost, kAny, kAny,
paddle::lite::kernels::host::FetchCompute, def) paddle::lite::kernels::host::FetchCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))}) TARGET(kHost))})
......
...@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(mul, kHost, kFloat, REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::MulCompute, def) paddle::lite::kernels::host::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))}) TARGET(kHost))})
......
...@@ -42,6 +42,6 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -42,6 +42,6 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(relu, kHost, kFloat, REGISTER_LITE_KERNEL(relu, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::ReluCompute, def) paddle::lite::kernels::host::ReluCompute, def)
.Finalize(); .Finalize();
...@@ -50,7 +50,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -50,7 +50,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(scale, kHost, kFloat, REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::ScaleCompute, def) paddle::lite::kernels::host::ScaleCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))}) TARGET(kHost))})
......
...@@ -24,7 +24,10 @@ bool IoCopyOp::CheckShape() const { ...@@ -24,7 +24,10 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE(param_.y); CHECK_OR_FALSE(param_.y);
return true; return true;
} }
bool IoCopyOp::InferShape() const { return true; } bool IoCopyOp::InferShape() const {
param_.y->Resize(param_.x->dims());
return true;
}
bool IoCopyOp::Run() { return OpLite::Run(); } bool IoCopyOp::Run() { return OpLite::Run(); }
bool IoCopyOp::AttachImpl(const paddle::framework::OpDesc &opdesc, bool IoCopyOp::AttachImpl(const paddle::framework::OpDesc &opdesc,
paddle::lite::Scope *scope) { paddle::lite::Scope *scope) {
......
...@@ -51,9 +51,6 @@ class MulOpLite : public OpLite { ...@@ -51,9 +51,6 @@ class MulOpLite : public OpLite {
param_.x_num_col_dims = boost::get<int>(op_desc.GetAttr("x_num_col_dims")); param_.x_num_col_dims = boost::get<int>(op_desc.GetAttr("x_num_col_dims"));
param_.y_num_col_dims = boost::get<int>(op_desc.GetAttr("y_num_col_dims")); param_.y_num_col_dims = boost::get<int>(op_desc.GetAttr("y_num_col_dims"));
CHECK(kernel_);
kernel_->SetParam(param_);
return true; return true;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <glog/logging.h>
#include <iostream> #include <iostream>
#include <list> #include <list>
#include <memory> #include <memory>
...@@ -48,8 +49,6 @@ class Factory { ...@@ -48,8 +49,6 @@ class Factory {
} }
void Register(const std::string& op_type, creator_t&& creator) { void Register(const std::string& op_type, creator_t&& creator) {
CHECK(!creators_.count(op_type)) << "The op " << op_type
<< " has already registered";
creators_[op_type].emplace_back(std::move(creator)); creators_[op_type].emplace_back(std::move(creator));
} }
...@@ -58,9 +57,9 @@ class Factory { ...@@ -58,9 +57,9 @@ class Factory {
} }
std::list<item_ptr_t> Creates(const std::string& op_type) const { std::list<item_ptr_t> Creates(const std::string& op_type) const {
auto it = creators_.find(op_type);
CHECK(it != creators_.end()) << "no item called " << op_type;
std::list<item_ptr_t> res; std::list<item_ptr_t> res;
auto it = creators_.find(op_type);
if (it == creators_.end()) return res;
for (auto& c : it->second) { for (auto& c : it->second) {
res.emplace_back(c()); res.emplace_back(c());
} }
......
...@@ -99,7 +99,7 @@ struct variant { ...@@ -99,7 +99,7 @@ struct variant {
size_t type() { return type_id; } size_t type() { return type_id; }
void valid() { return (type_id != invalid_type()); } bool valid() { return (type_id != invalid_type()); }
template <typename T, typename... Args> template <typename T, typename... Args>
void set(Args&&... args) { void set(Args&&... args) {
......
...@@ -24,6 +24,8 @@ namespace utils { ...@@ -24,6 +24,8 @@ namespace utils {
TEST(varient, test) { TEST(varient, test) {
variant<int, float> a; variant<int, float> a;
// The initial state should be invalid.
ASSERT_FALSE(a.valid());
a.set<int>(1); a.set<int>(1);
ASSERT_EQ(a.get<int>(), 1); ASSERT_EQ(a.get<int>(), 1);
a.set<int>(20); a.set<int>(20);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册