From 621d15227fb557751cc29bc804510bedc0fbb5d0 Mon Sep 17 00:00:00 2001 From: superjomn Date: Sat, 27 Apr 2019 15:49:35 +0800 Subject: [PATCH] make io_copy kernel pick works --- paddle/fluid/framework/op_desc.h | 1 + paddle/fluid/lite/api/CMakeLists.txt | 8 +- paddle/fluid/lite/api/cxx_api.h | 3 +- paddle/fluid/lite/api/cxx_api_test.cc | 32 +++- paddle/fluid/lite/core/CMakeLists.txt | 1 + paddle/fluid/lite/core/kernel.cc | 7 + paddle/fluid/lite/core/kernel.h | 68 ++++++- paddle/fluid/lite/core/memory.cc | 2 +- paddle/fluid/lite/core/memory.h | 7 +- paddle/fluid/lite/core/mir/CMakeLists.txt | 3 + .../core/mir/argument_type_display_pass.cc | 45 +++++ .../lite/core/mir/generate_program_pass.cc | 3 + .../fluid/lite/core/mir/io_complement_pass.cc | 170 ++++++++++++++++-- .../fluid/lite/core/mir/io_complement_pass.h | 31 +++- .../lite/core/mir/io_copy_kernel_pick_pass.cc | 74 ++++++++ paddle/fluid/lite/core/mir/node.h | 48 +++-- paddle/fluid/lite/core/mir/passes.h | 2 + paddle/fluid/lite/core/mir/ssa_graph.cc | 138 ++++++++++++++ paddle/fluid/lite/core/mir/ssa_graph.h | 146 +++++++-------- .../lite/core/mir/static_kernel_pick_pass.cc | 5 +- .../lite/core/mir/static_kernel_pick_pass.h | 21 ++- .../core/mir/variable_place_inference_pass.cc | 2 +- .../core/mir/variable_place_inference_pass.h | 47 ++--- paddle/fluid/lite/core/op_lite.cc | 17 +- paddle/fluid/lite/core/op_lite.h | 79 ++++---- paddle/fluid/lite/core/op_registry.cc | 55 ++++-- paddle/fluid/lite/core/op_registry.h | 151 ++++++++++------ paddle/fluid/lite/core/optimizer.cc | 29 +++ paddle/fluid/lite/core/optimizer.h | 26 ++- paddle/fluid/lite/core/program.h | 22 ++- paddle/fluid/lite/core/target_wrapper.h | 62 +++++-- paddle/fluid/lite/core/type_system.cc | 37 +++- paddle/fluid/lite/core/type_system.h | 80 ++++++--- paddle/fluid/lite/core/types.cc | 3 + paddle/fluid/lite/core/types.h | 23 ++- paddle/fluid/lite/cuda/target_wrapper.cc | 12 +- paddle/fluid/lite/kernels/cuda/CMakeLists.txt | 4 +- .../lite/kernels/cuda/io_copy_compute.cc | 27 ++- paddle/fluid/lite/kernels/cuda/mul_compute.cc | 19 ++ paddle/fluid/lite/kernels/cuda/mul_compute.h | 23 ++- paddle/fluid/lite/kernels/host/fc_compute.cc | 4 +- .../fluid/lite/kernels/host/feed_compute.cc | 5 +- .../fluid/lite/kernels/host/fetch_compute.cc | 5 +- paddle/fluid/lite/kernels/host/mul_compute.cc | 2 +- paddle/fluid/lite/kernels/host/relu_compute.h | 2 +- .../fluid/lite/kernels/host/scale_compute.cc | 2 +- paddle/fluid/lite/operators/io_copy_op.cc | 5 +- paddle/fluid/lite/operators/mul_op.h | 3 - paddle/fluid/lite/utils/factory.h | 7 +- paddle/fluid/lite/utils/varient.h | 2 +- paddle/fluid/lite/utils/varient_test.cc | 2 + 51 files changed, 1218 insertions(+), 354 deletions(-) create mode 100644 paddle/fluid/lite/core/mir/argument_type_display_pass.cc create mode 100644 paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index dedaf243647..03ebc9ac0ac 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -42,6 +42,7 @@ class OpDesc { void CopyFrom(const OpDesc &op_desc); proto::OpDesc *Proto(); + const proto::OpDesc &ReadonlyProto() const { return desc_; } std::string Type() const { return desc_.type(); } diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt index 73989dda91b..ec0aab9063f 100644 --- a/paddle/fluid/lite/api/CMakeLists.txt +++ b/paddle/fluid/lite/api/CMakeLists.txt @@ -1,3 +1,7 @@ -cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host) - +cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host ) cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite) + +if(LITE_WITH_CUDA) + cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda) + nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite) +endif() diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h index d6b5f09dcbc..ea577b7211e 100644 --- a/paddle/fluid/lite/api/cxx_api.h +++ b/paddle/fluid/lite/api/cxx_api.h @@ -29,7 +29,7 @@ class Predictor { public: Predictor() { scope_ = std::make_shared(); } - void Build(const std::string& model_path, + void Build(const std::string& model_path, const Place& prefer_place, const std::vector& valid_places) { framework::proto::ProgramDesc prog; LoadModel(model_path, scope_.get(), &prog); @@ -38,6 +38,7 @@ class Predictor { Program program(prog_desc, scope_, valid_places); Optimizer optimizer; + optimizer.KernelPickPreferPlace(prefer_place); core::KernelPickFactor factor; factor.ConsiderTarget(); optimizer.Run(std::move(program), valid_places, factor); diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index 43bbc69a920..7397e837ab3 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -23,8 +23,21 @@ namespace lite { TEST(CXXApi, test) { lite::Predictor predictor; +#ifndef LITE_WITH_CUDA + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}}); +#else + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, + }); +#endif + predictor.Build("/home/chunwei/project2/models/model2", - {Place{TARGET(kHost), PRECISION(kFloat)}}); + Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places); auto* input_tensor = predictor.GetInput(0); input_tensor->Resize({100, 100}); @@ -54,8 +67,15 @@ USE_LITE_OP(fc); USE_LITE_OP(scale); USE_LITE_OP(feed); USE_LITE_OP(fetch); -USE_LITE_KERNEL(fc, kHost, kFloat, def); -USE_LITE_KERNEL(mul, kHost, kFloat, def); -USE_LITE_KERNEL(scale, kHost, kFloat, def); -USE_LITE_KERNEL(feed, kHost, kFloat, def); -USE_LITE_KERNEL(fetch, kHost, kFloat, def); +USE_LITE_OP(io_copy); +USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def); +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index c7aa67b478a..2755baf1048 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -27,5 +27,6 @@ cc_test(test_tensor_lite SRCS tensor_test.cc) cc_test(test_op_executor_lite SRCS op_executor_test.cc DEPS op_executor_lite ops_lite host_kernels) cc_test(test_type_system SRCS type_system_test.cc DEPS type_system) cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes) +cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) add_subdirectory(mir) diff --git a/paddle/fluid/lite/core/kernel.cc b/paddle/fluid/lite/core/kernel.cc index 7cea5986dd7..d297f0f1488 100644 --- a/paddle/fluid/lite/core/kernel.cc +++ b/paddle/fluid/lite/core/kernel.cc @@ -17,6 +17,13 @@ namespace paddle { namespace lite { +std::string KernelBase::summary() const { + std::stringstream ss; + ss << op_type() << ":" << TargetToStr(target()) << "/" + << PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout()); + return ss.str(); +} + bool ParamTypeRegistry::KeyCmp::operator()( const ParamTypeRegistry::key_t &a, const ParamTypeRegistry::key_t &b) const { diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 9db4874d1bd..d8999cef994 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include "paddle/fluid/framework/op_desc.h" @@ -34,49 +35,100 @@ namespace lite { // different targets. class KernelBase { public: + // type_infer_handler is used to inference a output type by considering the + // input types in the type system. + using type_infer_handler_t = std::function& input_types, + const std::string& out_arg)>; + virtual void Run() = 0; void SetContext(std::unique_ptr&& ctx) { context_ = std::move(ctx); } - template void SetParam(T param) { param_.set(param); } - template P& Param() const { return param_.get

(); } + // This is used in the kernels that takes 'kAny' places and inference the + // output place. For `ScaleCompute` and `IoCopyCompute`, their input types are + // declared as 'kAny' in some Place field, and the output is also `kAny`, but + // when in real execution, when takes some non-kAny type as input, the + // output's kAny-fields can be determained. For example, when the + // `ScaleCompute` takes `TensorFp32NCHWTy` as input, its output should be also + // `TensorFp32NCHWTy`. This type inference rule is different for each kernel, + // so we make it a virtual method. + // One can custom this handler to make a specific type inference rule for a + // kernel, or leave the default to force the kernel use the system's + // type-inference rules. + virtual std::unique_ptr GetTypeInferHandler() { + return nullptr; + } + void set_op_type(const std::string& type) { op_type_ = type; } const std::string& op_type() const { return op_type_; } void Torch() {} + // Get input declaration type. + const Type* GetInputDeclType(const std::string& arg_name) { + CHECK(!op_type_.empty()) << "op_type should be set first"; + const auto* type = ParamTypeRegistry::Global().RetrieveInArgument( + place(), GenParamTypeKey(), arg_name); + CHECK(type) << "no type registered for kernel [" << op_type_ + << "] input argument [" << arg_name << "]" + << " with key " << GenParamTypeKey(); + return type->type; + } + + // Get output declaration type. + const Type* GetOutputDeclType(const std::string& arg_name) { + CHECK(!op_type_.empty()) << "op_type should be set first"; + const auto* type = ParamTypeRegistry::Global().RetrieveOutArgument( + place(), GenParamTypeKey(), arg_name); + CHECK(type) << "no type registered for kernel [" << op_type_ + << "] output argument [" << arg_name << "]"; + return type->type; + } + + void set_alias(const std::string& x) { + alias_ = x; + LOG(INFO) << "kernel " << op_type() << " setting alias " << alias(); + } + const std::string& alias() const { return alias_; } + virtual Place place() const = 0; virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual DataLayoutType layout() const = 0; const KernelContext* context() const { return context_.get(); } - virtual std::string name() const = 0; - virtual ~KernelBase() = default; + // Short human-readable document. + std::string summary() const; + // Long human-readable document. + virtual std::string doc() const { return ""; } - std::string DebugString() const { + std::string GenParamTypeKey() const { std::stringstream ss; - ss << op_type() << ":" << TargetToStr(target()) << "/" - << PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout()); + LOG(INFO) << "alias : " << alias_; + ss << op_type() << "/" << alias_; return ss.str(); } + virtual ~KernelBase() = default; + protected: std::unique_ptr context_; mutable operators::param_t param_; // The corresponding op type. - std::string op_type_; + std::string op_type_{}; + std::string alias_{}; }; // Light-weight kernel implementation. diff --git a/paddle/fluid/lite/core/memory.cc b/paddle/fluid/lite/core/memory.cc index f84eeec2c09..205452f0398 100644 --- a/paddle/fluid/lite/core/memory.cc +++ b/paddle/fluid/lite/core/memory.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "memory.h" +#include "paddle/fluid/lite/core/memory.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h index 6cb321c637b..363156b596f 100644 --- a/paddle/fluid/lite/core/memory.h +++ b/paddle/fluid/lite/core/memory.h @@ -14,7 +14,7 @@ #pragma once #include -#include "target_wrapper.h" +#include "paddle/fluid/lite/core/target_wrapper.h" namespace paddle { namespace lite { @@ -26,9 +26,12 @@ static void* TargetMalloc(TargetType target, size_t size) { case TargetType::kX86: data = TargetWrapper::Malloc(size); break; +#ifdef LITE_WITH_CUDA case TargetType::kCUDA: - data = TargetWrapper::Malloc(size); + data = + TargetWrapper::Malloc(size); break; +#endif // LITE_WITH_CUDA default: LOG(FATAL) << "Unknown supported target " << TargetToStr(target); } diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 1b6e980927c..1fd04914732 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -7,9 +7,12 @@ cc_library(mir_passes SRCS static_kernel_pick_pass.cc variable_place_inference_pass.cc io_complement_pass.cc + io_copy_kernel_pick_pass.cc graph_visualize_pass.cc generate_program_pass.cc + argument_type_display_pass.cc demo_pass.cc + runtime_context_assign_pass.cc DEPS mir_pass types_lite) cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes) diff --git a/paddle/fluid/lite/core/mir/argument_type_display_pass.cc b/paddle/fluid/lite/core/mir/argument_type_display_pass.cc new file mode 100644 index 00000000000..77c3b65ac2b --- /dev/null +++ b/paddle/fluid/lite/core/mir/argument_type_display_pass.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pass.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ArgumentTypeDisplayPass : public DebugPass { + public: + void Apply(std::unique_ptr& graph) override { + LOG(INFO) << "== Argument types =="; + for (auto& node : graph->mutable_nodes()) { + if (!node.IsArgument()) continue; + + auto* type = node.AsArgument().type; + if (type) { + LOG(INFO) << "* ARG " << node.AsArgument().name << " type: " << *type; + } else { + LOG(INFO) << "* ARG " << node.AsArgument().name << " type: UNK"; + } + } + LOG(INFO) << "---------------------"; + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(argument_type_display_pass, + paddle::lite::mir::ArgumentTypeDisplayPass); diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc index 0e47d6b02ca..0b2d806264e 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.cc +++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/core/mir/generate_program_pass.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" #include "paddle/fluid/lite/core/mir/pass_registry.h" namespace paddle { @@ -20,9 +21,11 @@ namespace lite { namespace mir { void GenerateProgramPass::Apply(std::unique_ptr& graph) { + LOG(INFO) << "final program \n" << Visualize(graph.get()); for (auto& item : graph->InstructTopologicalOrder()) { if (item->IsInstruct()) { auto& instruct = item->AsInstruct(); + LOG(INFO) << instruct; insts_.emplace_back(instruct.op, std::move(instruct.valid_kernels.front())); } diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.cc b/paddle/fluid/lite/core/mir/io_complement_pass.cc index f122bfab689..eac6208a7f1 100644 --- a/paddle/fluid/lite/core/mir/io_complement_pass.cc +++ b/paddle/fluid/lite/core/mir/io_complement_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/core/mir/io_complement_pass.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" #include "paddle/fluid/lite/core/mir/pass_registry.h" namespace paddle { @@ -21,28 +22,161 @@ namespace mir { void IoComplementPass::Apply(std::unique_ptr& graph) { // Start from inputs of the graph, those should have place set. + std::list nodes; for (auto& node : graph->mutable_nodes()) { - if (!node.IsInstruct()) continue; - auto& inst = node.AsInstruct(); - - // inputs - for (auto* in : node.inlinks) { - CHECK(in->IsArgument()); - auto name = in->AsArgument().name; - std::string tmp; - CHECK(inst.op_info->GetInputArgname(name, &tmp)); - auto type = - ParamTypeRegistry::Global().Retrieve( - inst.place, inst.op_type, tmp); - CHECK(type) << "no param type found for " << inst.op_type << ":" << name - << " " << inst.place; - CHECK(type->type); - CHECK(in->AsArgument().type); - if (!TypeCompatible(*type->type, *in->AsArgument().type)) { - LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name; + nodes.push_back(&node); + } + + CHECK(!valid_places_.empty()); + + for (auto& node : nodes) { + if (!node->IsInstruct()) continue; + auto inlinks = node->inlinks; + for (auto* in : inlinks) { + ComplementInputs(graph.get(), node, in); + } + } + + // PickIoCopyKernel(graph.get()); + + LOG(INFO) << "\n" << Visualize(graph.get()); +} + +void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node, + Node* in) { + // If this input is out of date. + if (inst_node->inlinks.end() == + std::find(inst_node->inlinks.begin(), inst_node->inlinks.end(), in)) + return; + + CHECK(inst_node->IsInstruct()); + auto& inst = inst_node->AsInstruct(); + CHECK(in->IsRoleSet()); + CHECK(in->IsArgument()); + auto in_arg_name = in->AsArgument().name; + std::string tmp; + CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); + auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); + CHECK(in->AsArgument().type); + if (!TypeCompatibleTo(*in->AsArgument().type, *decl_arg_type)) { + LOG(INFO) << "found IO unmatched tensor: " << in->AsArgument().name + << " for kernel " << inst.op->DebugString() << " " + << *in->AsArgument().type << " -> " << *decl_arg_type; + // Add an IoCopy instruction to make the input compatible with other dist. + AddIoCopyInst(*in->AsArgument().type, *decl_arg_type, in->AsArgument().name, + graph, inst_node, valid_places_); + } +} + +void UpdateOpdescInputName(framework::OpDesc* desc, + const std::string& old_arg_name, + const std::string& new_arg_name) { + for (auto& item : *desc->Proto()->mutable_inputs()) { + for (int i = 0; i < item.mutable_arguments()->size(); i++) { + auto* arg = item.mutable_arguments(i); + if (*arg == old_arg_name) { + *arg = new_arg_name; + } + } + } +} + +void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to, + const std::string& var, SSAGraph* graph, + Node* inst_node, + const std::vector& valid_places) { + CHECK(!valid_places.empty()) << "valid_place should be set"; + // var -> new_transform_op -> new_var -> inst + // So there will be a new Argument node and a new IoCopy Instruct Node. + + auto node_id = [&] { return graph->nodes().size(); }; + auto io_copy_output_name = var + "/trans/" + std::to_string(node_id()); + auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); + auto* io_copy_inst = graph->NewInstructNode(); + + // create Op and kernels. + auto io_copy_op = LiteOpRegistry::Global().Create("io_copy"); + // CHECK(io_copy_op); + // Create the new var manually. + inst_node->AsInstruct().op->scope()->Var(io_copy_output_name); + + // Create IoCopy Instruction. + framework::OpDesc op_desc; + op_desc.SetType("io_copy"); + op_desc.SetInput("Input", {var}); + op_desc.SetOutput("Out", {io_copy_output_name}); + op_desc.Flush(); + + io_copy_op->Attach(op_desc, inst_node->AsInstruct().op->scope()); + auto kernels = io_copy_op->CreateKernels(valid_places); + io_copy_inst->AsInstruct("io_copy", std::move(kernels), io_copy_op); + + // Remove the old link + RemoveDirectedLink(graph->Argument(var), inst_node); + + // Update the original instruction OpDesc. + // Update its input to the io_copy_output_name + auto& inst = inst_node->AsInstruct(); + auto inst_program_desc = inst.op_info()->desc(); + + // Add new link, var -> new_inst, new_inst->newarg, newarg->inst + DirectedLink(graph->Argument(var), io_copy_inst); + DirectedLink(io_copy_inst, io_copy_output_arg); + DirectedLink(io_copy_output_arg, inst_node); + + // reset opdesc and update kernel information + auto desc_dummy = inst_node->AsInstruct().op->op_info()->desc(); + UpdateInputTo(&desc_dummy, var, io_copy_output_name); + + framework::OpDesc desc_fake(desc_dummy, nullptr); + inst_node->AsInstruct().op->Attach(desc_fake, + inst_node->AsInstruct().op->scope()); + + std::string tmp; + if (inst_node->AsInstruct().op_info()->GetInputArgname("a", &tmp)) { + CHECK(false) << "get old a " << tmp; + } + + for (auto& kernel : inst_node->AsInstruct().valid_kernels) { + inst_node->AsInstruct().op->AttachKernel(kernel.get()); + } + + graph->CheckValid(); +} + +void IoComplementPass::PickIoCopyKernel(SSAGraph* graph) { + for (auto& node : graph->mutable_nodes()) { + if (node.IsInstruct() && node.AsInstruct().op_type == "io_copy") { + auto& kernels = node.AsInstruct().valid_kernels; + CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; + for (auto& kernel : kernels) { + CHECK_EQ(node.inlinks.size(), 1UL); + CHECK_EQ(node.outlinks.size(), 1UL); + auto* inty = node.inlinks.front()->AsArgument().type; + auto* outy = node.outlinks.front()->AsArgument().type; + const Type* in_arg_ty = kernel->GetInputDeclType("Input"); + if (TypeCompatibleTo(*inty, *in_arg_ty)) { + const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); + // Both the input and output type matches, remove other kernels + // directly. + if (out_arg_ty->target() == outy->target()) { + LOG(INFO) << "get a IOCopy kernel"; + auto x = std::move(kernel); + kernels.clear(); + kernels.emplace_back(std::move(x)); + break; + } + } } } } + + // Check the compatiblity. +} + +void IoComplementPass::SetValidPlaces(const std::vector& valid_places) { + CHECK(!valid_places.empty()); + valid_places_ = valid_places; } } // namespace mir diff --git a/paddle/fluid/lite/core/mir/io_complement_pass.h b/paddle/fluid/lite/core/mir/io_complement_pass.h index b44fde7b450..feca16debf8 100644 --- a/paddle/fluid/lite/core/mir/io_complement_pass.h +++ b/paddle/fluid/lite/core/mir/io_complement_pass.h @@ -15,18 +15,47 @@ #pragma once #include "paddle/fluid/lite/core/mir/pass.h" +#include "paddle/fluid/lite/core/op_registry.h" namespace paddle { namespace lite { namespace mir { +static void UpdateInputTo(framework::proto::OpDesc* desc, + const std::string& from, const std::string& to) { + for (auto& item : *desc->mutable_inputs()) { + for (auto& input : *item.mutable_arguments()) { + if (input == from) { + LOG(INFO) << "** update input argument from " << from << " to " << to; + input = to; + } + } + } +} + /* * IoComplementPass complement the necessary instruction to make data * transferring or transformation between different places. */ class IoComplementPass : public ProgramPass { public: - void Apply(std::unique_ptr &graph) override; + void Apply(std::unique_ptr& graph) override; + + void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); + + void AddIoCopyInst(const Type& from, const Type& to, const std::string& var, + SSAGraph* graph, Node* inst_node, + const std::vector& valid_places); + + void SetValidPlaces(const std::vector& valid_places); + + // Pick the right kernel of IoCopy considering the input and output Type. + void PickIoCopyKernel(SSAGraph* graph); + + const std::vector& valid_places() const { return valid_places_; }; + + private: + std::vector valid_places_; }; } // namespace mir diff --git a/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc new file mode 100644 index 00000000000..78f0a9b02b9 --- /dev/null +++ b/paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pass.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +class IoCopyKernelPickPass : public InstructionPass { + public: + void Apply(std::unique_ptr& graph) override { + for (auto& node : graph->mutable_nodes()) { + if (!node.IsInstruct()) continue; + auto& inst = node.AsInstruct(); + if (inst.op_type != "io_copy") continue; + + LOG(INFO) << "....> picking a IO COPY kernel"; + + auto& kernels = node.AsInstruct().valid_kernels; + CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; + const auto* inty = node.inlinks.front()->AsArgument().type; + const auto* outy = node.outlinks.front()->AsArgument().type; + LOG(INFO) << "input type " << *inty; + LOG(INFO) << "output type " << *outy; + + bool is_found = false; + LOG(INFO) << "kernels size " << kernels.size(); + for (auto& kernel : kernels) { + CHECK_EQ(node.inlinks.size(), 1UL); + CHECK_EQ(node.outlinks.size(), 1UL); + + const Type* in_arg_ty = kernel->GetInputDeclType("Input"); + const Type* out_arg_ty = kernel->GetOutputDeclType("Out"); + LOG(INFO) << "checking kernel candidate " << *in_arg_ty << "->" + << *out_arg_ty; + if (inty->target() == in_arg_ty->target()) { + // Both the input and output type matches, remove other kernels + // directly. + if (out_arg_ty->target() == outy->target()) { + LOG(INFO) << "get a IOCopy kernel"; + auto x = std::move(kernel); + kernels.clear(); + kernels.emplace_back(std::move(x)); + is_found = true; + break; + } + } + } + + CHECK(is_found) << "Can't find a IoCopy kernel for IO: " << *inty << "->" + << *outy; + } + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(io_copy_kernel_pick_pass, + paddle::lite::mir::IoCopyKernelPickPass); diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index e16e6ffb7b3..daa6f5555a8 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -34,30 +34,43 @@ class Node { Node() = default; enum class Role { - kUnk = -1, - kArgument, + kArgument = 0, kInstruct, - kNumRoles /*should be last*/ + kNumRoles, /*should be last*/ + kUnk, }; struct Instruct { std::string op_type; - Place place; // The kernel instances this Instruct contains. std::vector> valid_kernels; - std::shared_ptr op_info; // TODO(Superjomn) make this a shared_ptr for resource safety. std::shared_ptr op; // we hold op to run InferShape + const OpInfo* op_info() { + CHECK(op); + return op->op_info(); + } + + Place place() const { + CHECK(!valid_kernels.empty()); + return valid_kernels.front()->place(); + } + KernelBase& picked_kernel() { CHECK(!valid_kernels.empty()); return *valid_kernels.front(); } + + friend std::ostream& operator<<(std::ostream& os, const Instruct& other) { + os << "Instruct " << other.op_type << " " << other.place(); + return os; + } }; struct Argument { std::string name; - const Type* type; + const Type* type{}; // Weight is a special kind of argument, it is marked as weight explicitly // so that some weight related optimization can take place. bool is_weight{false}; @@ -71,13 +84,11 @@ class Node { Instruct& AsInstruct(const std::string& op_type, std::vector>&& kernels, - const std::shared_ptr& op, - const std::shared_ptr& op_info) { + const std::shared_ptr& op) { auto& x = AsInstruct(); x.op_type = op_type; x.op = op; x.valid_kernels = std::move(kernels); - x.op_info = op_info; return x; } @@ -100,8 +111,25 @@ class Node { instruct_.reset(new Instruct); return *instruct_; } + + friend std::ostream& operator<<(std::ostream& os, Node& other) { + os << static_cast(other.role_) << " "; + if (!other.IsRoleSet()) { + os << "Unk role node"; + } + if (other.IsArgument()) { + auto& arg = other.AsArgument(); + os << "Argument " << arg.name; + } + if (other.IsInstruct()) { + auto& arg = other.AsInstruct(); + os << "Instruct " << arg.op_type; + } + return os; + } + // Check roles. - bool IsRoleSet() const { return role_ == Role::kUnk; } + bool IsRoleSet() const { return role_ != Role::kUnk; } bool IsInstruct() const { return role_ == Role::kInstruct; } bool IsArgument() const { return role_ == Role::kArgument; } diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index 50ecd267887..35eaeeef29a 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -26,3 +26,5 @@ USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(io_complement_pass); USE_MIR_PASS(generate_program_pass); +USE_MIR_PASS(io_copy_kernel_pick_pass); +USE_MIR_PASS(argument_type_display_pass); diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index 0ab203af4d8..51f9362e5c1 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -89,6 +89,144 @@ std::vector SSAGraph::InstructTopologicalOrder() { return res; } +void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { + for (const auto &name : program.tmp_vars) { + LOG(INFO) << "create arg node " << name; + node_storage_.emplace_back(); + auto &new_node = node_storage_.back(); + new_node.AsArgument(name); + arguments_[name] = &new_node; + } +} + +void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { + // create weight nodes. + for (const auto &name : program.weights) { + LOG(INFO) << "create arg node " << name; + node_storage_.emplace_back(); + auto &new_node = node_storage_.back(); + new_node.AsArgument(name); + arguments_[name] = &new_node; + } +} + +Node *SSAGraph::GraphCreateInstructNode( + const Program &program, const std::shared_ptr &op, + const std::vector &valid_places) { + node_storage_.emplace_back(); + // TODO(Superjomn) remove one valid_places here. + op->SetValidPlaces(valid_places); + auto &new_node = node_storage_.back(); + auto kernels = op->CreateKernels(valid_places); + node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op); + + CHECK(new_node.inlinks.empty()) << "duplicate Build found"; + CHECK(new_node.outlinks.empty()) << "duplicate Build found"; + return &node_storage_.back(); +} + +void SSAGraph::Build(const Program &program, + const std::vector &valid_places) { + CHECK(node_storage_.empty()); + GraphCreateTmpVarNodes(program); + GraphCreateWeightVarNodes(program); + CHECK(CheckNodesRoleSet()); + + for (auto &op : program.ops) { + auto *op_node = GraphCreateInstructNode(program, op, valid_places); + LOG(INFO) << "checking op " << op->op_type_; + for (const std::string &name : op->op_info()->input_names()) { + auto *arg = Argument(name); + LOG(INFO) << "input " << name; + CHECK(arg->IsRoleSet()); + DirectedLink(arg, op_node); + } + for (const std::string &name : op->op_info()->output_names()) { + if (!arguments_.count(name)) { + NewArgumentNode(name); + } + LOG(INFO) << "output " << name; + auto *arg = arguments_.at(name); + CHECK(arg->IsRoleSet()); + DirectedLink(op_node, arg); + } + CHECK(CheckLinksRoleSet()); + } + + MarkArgumentWeights(program); + CheckValid(); +} + +mir::Node *SSAGraph::Argument(const std::string &name) { + auto it = arguments_.find(name); + CHECK(it != arguments_.end()) << "no argument called " << name; + return it->second; +} + +std::vector SSAGraph::inputs() { + std::vector res; + for (auto &node : node_storage_) { + if (node.inlinks.empty()) { + res.push_back(&node); + } + } + return res; +} + +std::vector SSAGraph::outputs() { + std::vector res; + for (auto &node : node_storage_) { + if (node.outlinks.empty()) { + res.push_back(&node); + } + } + return res; +} + +mir::Node *SSAGraph::RetrieveArgument(const std::string &arg) { + auto it = arguments_.find(arg); + if (it != arguments_.end()) { + return it->second; + } + return nullptr; +} + +bool SSAGraph::CheckNodesRoleSet() { + for (auto &node : mutable_nodes()) { + CHECK_OR_FALSE(node.IsRoleSet()); + } + return true; +} + +bool SSAGraph::CheckLinksRoleSet() { + for (auto &node : mutable_nodes()) { + CHECK_OR_FALSE(node.IsRoleSet()); + if (!node.IsInstruct()) continue; + for (auto *x : node.inlinks) { + CHECK_OR_FALSE(x->IsRoleSet()); + CHECK_OR_FALSE(x->IsArgument()); + } + for (auto *x : node.outlinks) { + CHECK_OR_FALSE(x->IsRoleSet()); + CHECK_OR_FALSE(x->IsArgument()); + } + } + return true; +} + +Node *SSAGraph::NewArgumentNode(const std::string &name) { + node_storage_.emplace_back(); + CHECK(!arguments_.count(name)) << "duplicate argument called " << name; + arguments_[name] = &node_storage_.back(); + node_storage_.back().AsArgument(name); + return &node_storage_.back(); +} + +Node *SSAGraph::NewInstructNode() { + node_storage_.emplace_back(); + return &node_storage_.back(); +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 070747f4e57..011aed2cd91 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -35,104 +35,44 @@ class SSAGraph : GraphBase { public: // @param program: the op program // @param valid_places: the valid places user set for the system. - void Build(const Program &program, const std::vector &valid_places) { - // create temporary nodes. - for (const auto &name : program.tmp_vars) { - node_storage_.emplace_back(); - auto &new_node = node_storage_.back(); - auto &arg = new_node.AsArgument(); - arg.name = name; - arguments_[name] = &new_node; - } - - // create weight nodes. - for (const auto &name : program.weights) { - node_storage_.emplace_back(); - auto &new_node = node_storage_.back(); - auto &arg = new_node.AsArgument(); - arg.name = name; - arguments_[name] = &new_node; - } + void Build(const Program &program, const std::vector &valid_places); - for (auto &op : program.ops) { - node_storage_.emplace_back(); - // TODO(Superjomn) remove one valid_places here. - op->SetValidPlaces(valid_places); - auto &new_node = node_storage_.back(); - auto kernels = op->CreateKernels(valid_places); - node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op, - op->op_info()); - - CHECK(new_node.inlinks.empty()) << "duplicate Build found"; - CHECK(new_node.outlinks.empty()) << "duplicate Build found"; - - // collect inputs and outputs - for (const std::string &name : op->op_info()->input_names()) { - auto *arg = Argument(name); - new_node.inlinks.push_back(arg); - arg->outlinks.push_back(&new_node); - } - for (const std::string &name : op->op_info()->output_names()) { - if (!arguments_.count(name)) { - node_storage_.emplace_back(); - auto &new_node = node_storage_.back(); - auto &arg = new_node.AsArgument(name); - arg.name = name; - arguments_.emplace(name, &new_node); - } - auto *arg = arguments_.at(name); - new_node.outlinks.push_back(arg); - arg->inlinks.push_back(&new_node); - } - } - - MarkArgumentWeights(program); - } - - mir::Node *Argument(const std::string &name) { - auto it = arguments_.find(name); - CHECK(it != arguments_.end()) << "no argument called " << name; - return it->second; - } + mir::Node *Argument(const std::string &name); std::vector InstructTopologicalOrder(); // The inputs of the graph. - std::vector inputs() { - std::vector res; - for (auto &node : node_storage_) { - if (node.inlinks.empty()) { - res.push_back(&node); - } - } - return res; - } + std::vector inputs(); // The outputs of the graph. - std::vector outputs() { - std::vector res; - for (auto &node : node_storage_) { - if (node.outlinks.empty()) { - res.push_back(&node); - } - } - return res; - } + std::vector outputs(); const std::list &nodes() const { return node_storage_; } std::list &mutable_nodes() { return node_storage_; } - mir::Node *RetrieveArgument(const std::string &arg) { - auto it = arguments_.find(arg); - if (it != arguments_.end()) { - return it->second; - } - return nullptr; + mir::Node *RetrieveArgument(const std::string &arg); + + Node *NewArgumentNode(const std::string &name); + Node *NewInstructNode(); + + void CheckValid() { + CHECK(CheckBidirectionalConnection()); + CHECK(CheckNodesRoleSet()); + CHECK(CheckLinksRoleSet()); } private: + void GraphCreateTmpVarNodes(const Program &program); + void GraphCreateWeightVarNodes(const Program &program); + Node *GraphCreateInstructNode(const Program &program, + const std::shared_ptr &op, + const std::vector &valid_places); + // Check the bidirectional connection. bool CheckBidirectionalConnection(); + bool CheckNodesRoleSet(); + // Check all the items's role in inlinks and outlinks is set. + bool CheckLinksRoleSet(); void MarkArgumentWeights(const Program &program) { for (const auto &name : program.weights) { @@ -152,6 +92,48 @@ class SSAGraph : GraphBase { std::map arguments_; }; +// Remove the link between a -> b. +static void RemoveDirectedLink(Node *a, Node *b) { + auto it = std::find(b->inlinks.begin(), b->inlinks.end(), a); + if (it != b->inlinks.end()) { + b->inlinks.erase(it); + } + + auto it1 = std::find(a->outlinks.begin(), a->outlinks.end(), b); + if (it1 != a->outlinks.end()) { + a->outlinks.erase((it1)); + } +} + +// Link a -> b. +static void DirectedLink(Node *a, Node *b) { + // Eagerly remove first, to avoid duplicate link. + RemoveDirectedLink(a, b); + a->outlinks.push_back(b); + b->inlinks.push_back(a); +} + +static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) { + // instr -> output argument + if (a->IsInstruct() && b->IsArgument()) { + auto &inst = a->AsInstruct(); + auto &output = b->AsArgument(); + + if (!output.type) { + output.type = inst.picked_kernel().GetOutputDeclType(arg_name); + } + } + + // input argument -> instr + if (a->IsArgument() && b->IsInstruct()) { + auto &input = a->AsArgument(); + auto &inst = b->AsInstruct(); + if (!input.type) { + input.type = inst.picked_kernel().GetInputDeclType(arg_name); + } + } +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc index a86470f3fce..a85628f027d 100644 --- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc @@ -37,7 +37,9 @@ void StaticKernelPickPass::Apply(std::unique_ptr& graph) { auto& instruct = node.AsInstruct(); std::vector>> scored; for (auto&& kernel : instruct.valid_kernels) { - scored.emplace_back(KernelGrade(*kernel), std::move(kernel)); + size_t score = KernelGrade(*kernel); + LOG(INFO) << "kernel " << kernel->summary() << " " << score; + scored.emplace_back(score, std::move(kernel)); } std::sort(scored.begin(), scored.end(), KernelScoreCmp); @@ -47,7 +49,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr& graph) { // TODO(Superjomn) reconsider this. instruct.valid_kernels.clear(); instruct.valid_kernels.emplace_back(std::move(scored.front().second)); - instruct.place = instruct.valid_kernels.front()->place(); LOG(INFO) << "pick " << instruct.valid_kernels.front()->name(); } } diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h index 9cf3ee2d439..8e0aed7daab 100644 --- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h +++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.h @@ -37,6 +37,7 @@ class StaticKernelPickPass : public mir::InstructionPass { public: void Apply(std::unique_ptr& graph) override; + void SetPreferPlace(const Place& place) { place_ = place; } const Place& place() const { return place_; } const core::KernelPickFactor& kernel_pick_factors() const { return kernel_pick_factors_; @@ -51,16 +52,32 @@ class StaticKernelPickPass : public mir::InstructionPass { size_t score{}; const int kMax = std::numeric_limits::max(); + + // The more important factor comes first if (kernel_pick_factors_.IsTargetConsidered() && - place().target == kernel.target()) { + (place().target == kernel.target() || kernel.target() == TARGET(kAny) || + place().target == TARGET(kAny))) { score += kMax / static_cast(core::KernelPickFactor::Factor::TargetFirst); } if (kernel_pick_factors_.IsPrecisionConsidered() && - place().precision == kernel.precision()) { + (place().precision == kernel.precision() || + kernel.precision() == PRECISION(kAny) || + place().precision == PRECISION(kAny))) { score += kMax / static_cast(core::KernelPickFactor::Factor::PrecisionFirst); } + if (kernel_pick_factors_.IsDataLayoutConsidered() && + (place().layout == kernel.layout() || + kernel.layout() == DATALAYOUT(kAny) || + place().layout == DATALAYOUT(kAny))) { + score += kMax / static_cast( + core::KernelPickFactor::Factor::DataLayoutFirst); + } + LOG(INFO) << "picker tactic " << kernel_pick_factors_; + LOG(INFO) << "kernel place " << kernel.place(); + LOG(INFO) << "picker place " << place(); + LOG(INFO) << "score " << score; // The data layout is not considered, for the input and output arguments // might have different data layout. diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc index 50665066365..4998a800582 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.cc @@ -22,8 +22,8 @@ namespace mir { void VariablePlaceInferencePass::Apply(std::unique_ptr& graph) { MarkInputPlace(graph.get()); - InferenceArgumentPlace(graph.get()); + CheckAllArgumentTypeDetermined(graph.get()); } } // namespace mir diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 3a05828d53c..2a24ac6e67f 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -31,6 +31,7 @@ class VariablePlaceInferencePass : public DebugPass { private: // Mark the place of input arguments. void MarkInputPlace(SSAGraph* graph) { + CHECK(!graph->inputs().empty()) << "graph's inputs should be set"; for (const auto& v : graph->inputs()) { // the feed op might in the inputs if (v->IsInstruct()) { @@ -39,54 +40,60 @@ class VariablePlaceInferencePass : public DebugPass { } // auto& arg = v->AsArgument(); - // arg.place.target = argument_default_target_; + // LOG(INFO) << "get graph input " << arg.name << " " << *arg.type; + // arg.type.target = argument_default_target_; // the other place description can't be determined yet, until their first // usage by some kernel. } } + void CheckAllArgumentTypeDetermined(SSAGraph* graph) { + for (auto& node : graph->mutable_nodes()) { + if (node.IsArgument()) { + CHECK(node.AsArgument().type) << "node " << node.AsArgument().name + << " type not determined"; + } + } + } + void InferenceArgumentPlace(SSAGraph* graph) { LOG(INFO) << "param-type-registry:\n" << ParamTypeRegistry::Global(); for (auto& x : graph->InstructTopologicalOrder()) { auto& inst = x->AsInstruct(); - CHECK(inst.place.is_valid()) - << "kernel's place should be set when loaded"; + // The IoCopyOp is a tool operator, it won't support the type inference. + if (inst.op_type == "io_copy") continue; + // LOG(INFO) << "- inferencing type " << // deal with inputs - for (auto& arg_name : inst.op_info->input_argnames()) { - auto type = - ParamTypeRegistry::Global().Retrieve( - inst.place, inst.op_type, arg_name); - CHECK(type) << "no param-type found for " << inst.op_type << ":" - << arg_name << " " << inst.place.DebugString(); - auto arg_names = inst.op_info->input_argument().at(arg_name); + for (auto& arg_name : inst.op_info()->input_argnames()) { + LOG(INFO) << "-- input arg_name " << arg_name; // check if inputs's place is set, if not set, update them with the // kernel's declaration. + auto type = inst.picked_kernel().GetInputDeclType(arg_name); + auto arg_names = inst.op_info()->input_argument().at(arg_name); for (auto& arg_name : arg_names) { + LOG(INFO) << "--- var " << arg_name; auto* node = graph->RetrieveArgument(arg_name); CHECK(node) << "argument " << arg_name << " not exists in the graph"; auto& arg_node = node->AsArgument(); if (arg_node.type) continue; - arg_node.type = type->type; + arg_node.type = type; } } - for (auto& arg_name : inst.op_info->output_argnames()) { - auto type = ParamTypeRegistry::Global() - .Retrieve( - inst.place, inst.op_type, arg_name); - CHECK(type) << "no param-type found for " << inst.op_type << ":" - << arg_name << " " << inst.place.DebugString(); - auto arg_names = inst.op_info->output_argument().at(arg_name); + for (auto& arg_name : inst.op_info()->output_argnames()) { + LOG(INFO) << "-- output arg_name " << arg_name; + auto type = inst.picked_kernel().GetOutputDeclType(arg_name); + auto arg_names = inst.op_info()->output_argument().at(arg_name); // check if outputs's place is set, if not set, update them with the // kernel's declaration. - for (auto& arg_name : arg_names) { + LOG(INFO) << "--- var " << arg_name; auto* node = graph->RetrieveArgument(arg_name); CHECK(node) << "argument " << arg_name << " not exists in the graph"; auto& arg_node = node->AsArgument(); if (arg_node.type) continue; - node->AsArgument().type = type->type; + node->AsArgument().type = type; } } } diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index 86557ce0abc..ba6917c6210 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -27,13 +27,15 @@ std::vector> OpLite::CreateKernels( for (auto place : places) { auto ks = KernelRegistry::Global().Create( (kernel_type.empty() ? op_type_ : kernel_type), place.target, - place.precision); + place.precision, place.layout); for (auto &&it : ks) { AttachKernel(it.get()); kernels.emplace_back(std::move(it)); } } + CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; + LOG(INFO) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; } @@ -59,9 +61,10 @@ bool OpLite::Run() { } bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) { - CHECK(!op_info_) << "op_info duplicate build found"; - op_info_ = std::make_shared(); - op_info_->Build(opdesc); + CHECK(scope); + scope_ = scope; + op_info_.reset(new OpInfo); // Force clean the out-of-date infomation. + op_info_->Build(opdesc.ReadonlyProto()); return AttachImpl(opdesc, scope); } @@ -79,7 +82,8 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope, return var->GetMutable(); } -bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) { +bool OpInfo::GetInputArgname(const std::string &value_name, + std::string *out) const { for (auto &item : input_argument_) { auto it = std::find(item.second.begin(), item.second.end(), value_name); if (it != item.second.end()) { @@ -89,7 +93,8 @@ bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) { } return false; } -bool OpInfo::GetOutputArgname(const std::string &value_name, std::string *out) { +bool OpInfo::GetOutputArgname(const std::string &value_name, + std::string *out) const { for (auto &item : output_argument_) { auto it = std::find(item.second.begin(), item.second.end(), value_name); if (it != item.second.end()) { diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index f754780ad52..2092703d33a 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -81,19 +81,30 @@ class OpLite : public Registry { // Run this operator. virtual bool Run(); + // Link the external execution environ to internal context. bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope); - const std::shared_ptr &op_info() const { return op_info_; } - std::shared_ptr &mutable_op_info() { return op_info_; } + const OpInfo *op_info() const { return op_info_.get(); } + OpInfo *mutable_op_info() { return op_info_.get(); } // Human-readable information. virtual std::string DebugString() const = 0; const Place &kernel_place() const { return kernel_place_; } + // NOTE This might be discarded. void PickKernel(const std::vector &valid_places, KernelStrategy kernel_strategy = KernelStrategy::kStatic); + // Create all the kernels for the valid targets. + std::vector> CreateKernels( + const std::vector &places, const std::string &kernel_type = ""); + + lite::Scope *scope() { return scope_; } + + // Assign op param to kernel. + virtual void AttachKernel(KernelBase *kernel) = 0; + virtual ~OpLite() = default; protected: @@ -101,9 +112,6 @@ class OpLite : public Registry { virtual bool AttachImpl(const framework::OpDesc &opdesc, lite::Scope *scope) = 0; - // Assign op param to kernel. - virtual void AttachKernel(KernelBase *kernel) = 0; - // Specify the kernel to run by default. This will specify the value of // `kernel_place_`. virtual void StaticPickKernel(const std::vector &valid_targets) { @@ -118,10 +126,6 @@ class OpLite : public Registry { // some inputs are ready. void RecordOutputEvents() {} - // Create all the kernels for the valid targets. - std::vector> CreateKernels( - const std::vector &places, const std::string &kernel_type = ""); - const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const; Tensor *GetMutableTensor(lite::Scope *scope, const std::string &name) const; @@ -129,11 +133,12 @@ class OpLite : public Registry { friend class mir::SSAGraph; protected: + lite::Scope *scope_{}; std::unique_ptr kernel_; std::string op_type_; std::vector valid_places_; Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; - std::shared_ptr op_info_; + std::unique_ptr op_info_; }; /* @@ -142,22 +147,30 @@ class OpLite : public Registry { */ class OpInfo { public: - void Build(const framework::OpDesc &desc) { + // To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf + // message instead. + void Build(const framework::proto::OpDesc &desc) { ExtractInputsAndOutputs(desc); CollectInputAndOutputArgnames(desc); CollectArguments(desc); + desc_.reset(new framework::proto::OpDesc(desc)); } + const framework::proto::OpDesc &desc() const { + CHECK(desc_) << "desc has't set"; + return *desc_; + } + framework::proto::OpDesc *mutable_desc() { return desc_.get(); } const std::list &input_names() const { return input_names_; } const std::list &output_names() const { return output_names_; } - const std::map> &input_argument() { + const std::map> &input_argument() const { return input_argument_; } - const std::map> &output_argument() { + const std::map> &output_argument() const { return output_argument_; } - bool GetInputArgname(const std::string &value_name, std::string *out); - bool GetOutputArgname(const std::string &value_name, std::string *out); + bool GetInputArgname(const std::string &value_name, std::string *out) const; + bool GetOutputArgname(const std::string &value_name, std::string *out) const; const std::list &input_argnames() const { return input_argnames_; @@ -167,37 +180,37 @@ class OpInfo { } private: - void ExtractInputsAndOutputs(const framework::OpDesc &opdesc) { - for (const auto &item : opdesc.Inputs()) { - for (const auto &x : item.second) { + void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) { + for (const auto &item : opdesc.inputs()) { + for (const auto &x : item.arguments()) { input_names_.push_back(x); } } - for (const auto &item : opdesc.Outputs()) { - for (const auto &x : item.second) { + for (const auto &item : opdesc.outputs()) { + for (const auto &x : item.arguments()) { output_names_.push_back(x); } } } - void CollectInputAndOutputArgnames(const framework::OpDesc &opdesc) { - for (const auto &item : opdesc.InputNames()) { - input_argnames_.push_back(item); + void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc) { + for (const auto &item : opdesc.inputs()) { + input_argnames_.push_back(item.parameter()); } - for (const auto &item : opdesc.OutputNames()) { - output_argnames_.push_back(item); + for (const auto &item : opdesc.outputs()) { + output_argnames_.push_back(item.parameter()); } } - void CollectArguments(const framework::OpDesc &opdesc) { - for (const auto &item : opdesc.Inputs()) { - for (auto &x : item.second) { - input_argument_[item.first].push_back(x); + void CollectArguments(const framework::proto::OpDesc &opdesc) { + for (const auto &item : opdesc.inputs()) { + for (auto &x : item.arguments()) { + input_argument_[item.parameter()].push_back(x); } } - for (const auto &item : opdesc.Outputs()) { - for (auto &x : item.second) { - output_argument_[item.first].push_back(x); + for (const auto &item : opdesc.outputs()) { + for (auto &x : item.arguments()) { + output_argument_[item.parameter()].push_back(x); } } } @@ -209,6 +222,8 @@ class OpInfo { std::list output_argnames_; std::map> input_argument_; std::map> output_argument_; + // NOTE too heavy. + std::unique_ptr desc_; }; } // namespace lite diff --git a/paddle/fluid/lite/core/op_registry.cc b/paddle/fluid/lite/core/op_registry.cc index 676cbe2dfcf..b012ca3e2fd 100644 --- a/paddle/fluid/lite/core/op_registry.cc +++ b/paddle/fluid/lite/core/op_registry.cc @@ -18,13 +18,33 @@ namespace paddle { namespace lite { std::list> KernelRegistry::Create( - const std::string &op_type, TargetType target, PrecisionType precision) { -#define CREATE_KERNEL(target__) \ - switch (precision) { \ - case PRECISION(kFloat): \ - return Create(op_type); \ - default: \ - CHECK(false) << "not supported kernel place yet"; \ + const std::string &op_type, TargetType target, PrecisionType precision, + DataLayoutType layout) { + Place place{target, precision, layout}; + LOG(INFO) << "creating " << op_type << " kernel for " << place; +#define CREATE_KERNEL1(target__, precision__) \ + switch (layout) { \ + case DATALAYOUT(kNCHW): \ + return Create(op_type); \ + case DATALAYOUT(kAny): \ + return Create(op_type); \ + default: \ + LOG(FATAL) << "unsupported kernel layout " << DataLayoutToStr(layout); \ + } + +#define CREATE_KERNEL(target__) \ + switch (precision) { \ + case PRECISION(kFloat): \ + CREATE_KERNEL1(target__, kFloat); \ + case PRECISION(kInt8): \ + CREATE_KERNEL1(target__, kInt8); \ + case PRECISION(kAny): \ + CREATE_KERNEL1(target__, kAny); \ + default: \ + CHECK(false) << "not supported kernel precision " \ + << PrecisionToStr(precision); \ } switch (target) { @@ -38,7 +58,7 @@ std::list> KernelRegistry::Create( CREATE_KERNEL(kCUDA); } break; default: - CHECK(false) << "not supported kernel place"; + CHECK(false) << "not supported kernel target " << TargetToStr(target); } #undef CREATE_KERNEL @@ -46,14 +66,21 @@ std::list> KernelRegistry::Create( } KernelRegistry::KernelRegistry() { -#define INIT_FOR(target__, precision__) \ +#define INIT_FOR(target__, precision__, layout__) \ registries_[KernelRegistry::GetKernelOffset()] \ - .set \ - *>(&KernelRegistryForTarget::Global()); + PRECISION(precision__), \ + DATALAYOUT(layout__)>()] \ + .set *>( \ + &KernelRegistryForTarget::Global()); // Currently, just register 2 kernel targets. - INIT_FOR(kHost, kFloat); + INIT_FOR(kCUDA, kFloat, kNCHW); + INIT_FOR(kCUDA, kAny, kNCHW); + INIT_FOR(kHost, kFloat, kNCHW); + INIT_FOR(kHost, kAny, kNCHW); + INIT_FOR(kHost, kAny, kAny); + INIT_FOR(kCUDA, kAny, kAny); #undef INIT_FOR } diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index 749f786a71c..590ba3caca8 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -50,80 +50,108 @@ class OpLiteRegistor : public Registor { }) {} }; -template +template using KernelRegistryForTarget = - Factory, std::unique_ptr>; + Factory, std::unique_ptr>; class KernelRegistry final { public: using any_kernel_registor_t = - variant *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget *, // - KernelRegistryForTarget * // + variant *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // >; KernelRegistry(); static KernelRegistry &Global(); - template + template void Register(const std::string &name, - typename KernelRegistryForTarget::creator_t - &&creator) { - using kernel_registor_t = KernelRegistryForTarget; - registries_[GetKernelOffset()] - .template get() - ->Register(name, std::move(creator)); + typename KernelRegistryForTarget::creator_t &&creator) { + LOG(INFO) << "register for " << TargetToStr(Target) << ":" + << PrecisionToStr(Precision) << "//" + << GetKernelOffset(); + using kernel_registor_t = + KernelRegistryForTarget; + auto &varient = registries_[GetKernelOffset()]; + varient.template get()->Register(name, + std::move(creator)); } - template + template std::list> Create(const std::string &op_type) { - using kernel_registor_t = KernelRegistryForTarget; - return registries_[GetKernelOffset()] + using kernel_registor_t = + KernelRegistryForTarget; + return registries_[GetKernelOffset()] .template get() ->Creates(op_type); } std::list> Create(const std::string &op_type, TargetType target, - PrecisionType precision); + PrecisionType precision, + DataLayoutType layout); // Get a kernel registry offset in all the registries. - template - static constexpr int GetKernelOffset() { - return kNumTargets * static_cast(Target) + static_cast(Precision); + template + static int GetKernelOffset() { + CHECK_LT(static_cast(Target), static_cast(TARGET(NUM))); + CHECK_LT(static_cast(Precision), static_cast(PRECISION(NUM))); + CHECK_LT(static_cast(Layout), static_cast(DATALAYOUT(NUM))); + return static_cast(Target) * static_cast(PRECISION(NUM)) * + static_cast(DATALAYOUT(NUM)) + // + static_cast(Precision) * static_cast(DATALAYOUT(NUM)) + // + static_cast(Layout); } std::string DebugString() const { std::stringstream ss; ss << "KernelCreator:" << std::endl; - ss << registries_[GetKernelOffset()] - .get< - KernelRegistryForTarget *>() + ss << registries_[GetKernelOffset()] + .get *>() ->DebugString(); ss << std::endl; return ss.str(); } private: - mutable std::array + mutable std::array(TARGET(NUM)) * + static_cast(PRECISION(NUM)) * + static_cast(DATALAYOUT(NUM))> registries_; }; -template +template class KernelRegistor : public lite::Registor { public: - KernelRegistor(const std::string op_type) - : Registor([&] { + KernelRegistor(const std::string &op_type, const std::string &alias) + : Registor([=] { LOG(INFO) << "Register kernel " << op_type << " for " - << TargetToStr(target) << " " << PrecisionToStr(precision); - KernelRegistry::Global().Register( - op_type, [&, op_type]() -> std::unique_ptr { + << TargetToStr(target) << " " << PrecisionToStr(precision) + << " " << DataLayoutToStr(layout) << " alias " << alias; + KernelRegistry::Global().Register( + op_type, [=]() -> std::unique_ptr { std::unique_ptr x(new KernelType); x->set_op_type(op_type); + x->set_alias(alias); return x; }); }) {} @@ -151,35 +179,40 @@ class KernelRegistor : public lite::Registor { #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \ op_type__##__##target__##__##precision__##__registor__ #define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ - alias__) \ + layout__, alias__) \ op_type__##__##target__##__##precision__##__registor__instance__##alias__ #define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \ LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__) -#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \ - alias__) \ - static paddle::lite::KernelRegistor \ - LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ - alias__)(#op_type__); \ - static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \ - alias__); \ - int touch_##op_type__##target__##precision__##alias__() { \ - LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \ - return 0; \ - } \ - static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \ - alias__) __attribute__((unused)) = \ - paddle::lite::ParamTypeRegistry::NewInstance( \ - #op_type__) - -#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \ - extern int touch_##op_type__##target__##precision__##alias__(); \ - int op_type__##target__##precision__##alias__ __attribute__((unused)) = \ - touch_##op_type__##target__##precision__##alias__(); - -#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \ - op_type__##target__##precision__##alias__ -#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \ - op_type__##target__##precision__##alias__##param_register +#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, layout__, \ + KernelClass, alias__) \ + static paddle::lite::KernelRegistor \ + LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \ + layout__, alias__)(#op_type__, #alias__); \ + static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \ + layout__, alias__); \ + int touch_##op_type__##target__##precision__##layout__##alias__() { \ + LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, alias__) \ + .Touch(); \ + return 0; \ + } \ + static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \ + layout__, alias__) \ + __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \ + TARGET(target__), PRECISION(precision__), DATALAYOUT(layout__)>( \ + #op_type__ "/" #alias__) + +#define USE_LITE_KERNEL(op_type__, target__, precision__, layout__, alias__) \ + extern int touch_##op_type__##target__##precision__##layout__##alias__(); \ + int op_type__##target__##precision__##layout__##alias__ \ + __attribute__((unused)) = \ + touch_##op_type__##target__##precision__##layout__##alias__(); + +#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, layout__, \ + alias__) \ + op_type__##target__##precision__##layout__##alias__ +#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, layout__, \ + alias__) \ + op_type__##target__##precision__##layout__##alias__##param_register diff --git a/paddle/fluid/lite/core/optimizer.cc b/paddle/fluid/lite/core/optimizer.cc index cd80f786793..c3be12d22f5 100644 --- a/paddle/fluid/lite/core/optimizer.cc +++ b/paddle/fluid/lite/core/optimizer.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/lite/core/optimizer.h" +#include "paddle/fluid/lite/core/mir/io_complement_pass.h" #include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" namespace paddle { @@ -25,5 +26,33 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) { *pass->mutable_kernel_pick_factors() = factor; } + +void Optimizer::RunPasses() { + std::vector passes({ + "static_kernel_pick_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "io_complement_pass", // + "argument_type_display_pass", // + "variable_place_inference_pass", // + "argument_type_display_pass", // + "io_copy_kernel_pick_pass", // + "variable_place_inference_pass", // + }); + for (auto& pass_type : passes) { + LOG(INFO) << ".. running pass " << pass_type; + auto* pass = mir::PassManager::Global().LookUp(pass_type); + CHECK(pass); + if (pass->name() == "io_complement_pass") { + auto* _pass = dynamic_cast(pass); + _pass->SetValidPlaces(valid_places_); + CHECK(!_pass->valid_places().empty()); + _pass->Apply(graph_); + } else { + pass->Apply(graph_); + } + } + // mir::PassManager::Global().Run(graph_); +} } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index 7bd6a2476bd..977d1ab7efa 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -16,8 +16,10 @@ #include #include #include "paddle/fluid/lite/core/mir/generate_program_pass.h" +#include "paddle/fluid/lite/core/mir/io_complement_pass.h" #include "paddle/fluid/lite/core/mir/pass_manager.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h" +#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h" #include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/types.h" @@ -33,25 +35,46 @@ class Optimizer { void Run(Program&& program, const std::vector& valid_places, core::KernelPickFactor kernel_pick_factor, const std::vector& passes = {}) { + valid_places_ = valid_places; + CHECK(!valid_places.empty()) << "At least one valid_place should be set"; CHECK(!graph_) << "duplicate optimize found"; graph_.reset(new mir::SSAGraph); graph_->Build(program, valid_places); SpecifyKernelPickTactic(kernel_pick_factor); + // InitIoComplement(); RunPasses(); exec_scope_ = program.exec_scope; } + void KernelPickPreferPlace(const Place& place) { + auto* pass = mir::PassManager::Global().LookUp( + "static_kernel_pick_pass"); + CHECK(pass); + pass->SetPreferPlace(place); + } + // Generate a new program based on the mir graph. std::unique_ptr GenRuntimeProgram() { + LOG(INFO) << "generate program"; std::unique_ptr res; auto pass = mir::PassManager::Global().LookUp( "generate_program_pass"); + pass->Apply(graph_); auto program = pass->GenProgram(); CHECK(exec_scope_); program->set_exec_scope(exec_scope_); return program; } + void InitIoComplement() { + auto* pass = mir::PassManager::Global().LookUp( + "io_complement_pass"); + CHECK(pass); + CHECK(!valid_places_.empty()); + LOG(INFO) << "valid_places.size " << valid_places_.size(); + pass->SetValidPlaces(valid_places_); + } + // Generate C++ code which combines the inference program, model and weights. void GenCode(const std::string& code_dir); @@ -64,13 +87,14 @@ class Optimizer { void SpecifyKernelPickTactic(core::KernelPickFactor factor); // Run the default passes registered in the PassManager. - void RunPasses() { mir::PassManager::Global().Run(graph_); } + void RunPasses(); // Specify the passes and run them. void RunPasses(std::vector& passes); private: std::unique_ptr graph_; + std::vector valid_places_; lite::Scope* exec_scope_{}; }; diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index a31a0eb0d5f..6f945c06124 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -84,13 +84,10 @@ struct Program { tmp_vars.push_back("fetch"); for (auto var_desc : program.Block(0).AllVars()) { if (!var_desc->Persistable()) { - LOG(INFO) << "get tmp var " << var_desc->Name(); tmp_vars.push_back(var_desc->Name()); - auto* var = exec_scope->Var(var_desc->Name()); - LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var; + exec_scope->Var(var_desc->Name()); } else { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") continue; - LOG(INFO) << "get weight var " << var_desc->Name(); weights.push_back(var_desc->Name()); } } @@ -105,15 +102,19 @@ struct Instruction { void Run() { CHECK(op_); CHECK(kernel_); - LOG(INFO) << "running kernel> " << kernel_->DebugString(); if (UNLIKELY(first_epoch_)) { first_epoch_ = false; - op_->CheckShape(); + CHECK(op_->CheckShape()); } op_->InferShape(); kernel_->Run(); } + friend std::ostream& operator<<(std::ostream& os, const Instruction& other) { + os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")"; + return os; + } + private: std::shared_ptr op_; std::unique_ptr kernel_; @@ -125,11 +126,16 @@ struct Instruction { */ class RuntimeProgram { public: - explicit RuntimeProgram(std::vector&& instruction) - : instructions_(std::move(instruction)) {} + explicit RuntimeProgram(std::vector&& insts) + : instructions_(std::move(insts)) { + if (insts.empty()) { + LOG(ERROR) << "no instructions"; + } + } void Run() { for (auto& inst : instructions_) { + LOG(INFO) << ">> Running kernel: " << inst; inst.Run(); } } diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h index 2852f582572..fd754a09ec0 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -16,6 +16,10 @@ #include #include #include +#ifdef LITE_WITH_CUDA +#include +#include +#endif namespace paddle { namespace lite { @@ -26,20 +30,20 @@ enum class TargetType : int { kX86, kCUDA, kAny, // any target - kLastAsPlaceHolder, + NUM, // number of fields. }; enum class PrecisionType : int { kUnk = 0, kFloat, kInt8, kAny, // any precision - kLastAsPlaceHolder, + NUM, // number of fields. }; enum class DataLayoutType : int { kUnk = 0, kNCHW, kAny, // any data layout - kLastAsPlaceHolder, + NUM, // number of fields. }; // Some helper macro to get a specific TargetType. @@ -50,25 +54,29 @@ enum class DataLayoutType : int { #define PRECISION_VAL(item__) static_cast(PRECISION(item__)) #define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__ -constexpr const int kNumPrecisions = - PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat); -constexpr const int kNumTargets = - TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost); +constexpr const int kNumPrecisions = PRECISION_VAL(NUM); +constexpr const int kNumTargets = TARGET_VAL(NUM); static const std::string target2string[] = {"unk", "host", "x86", "cuda", "any"}; static const std::string& TargetToStr(TargetType target) { - return target2string[static_cast(target)]; + auto x = static_cast(target); + CHECK_LT(x, static_cast(TARGET(NUM))); + return target2string[x]; } static const std::string precision2string[] = {"unk", "float", "int8", "any"}; static const std::string& PrecisionToStr(PrecisionType precision) { - return precision2string[static_cast(precision)]; + auto x = static_cast(precision); + CHECK_LT(x, static_cast(PRECISION(NUM))); + return precision2string[x]; } static const std::string datalayout2string[] = {"unk", "NCHW", "any"}; -static const std::string& DataLayoutToStr(DataLayoutType x) { - return datalayout2string[static_cast(x)]; +static const std::string& DataLayoutToStr(DataLayoutType layout) { + auto x = static_cast(layout); + CHECK_LT(x, static_cast(DATALAYOUT(NUM))); + return datalayout2string[x]; } /* @@ -187,5 +195,37 @@ class TargetWrapper { } }; +#ifdef LITE_WITH_CUDA +// This interface should be specified by each kind of target. +template <> +class TargetWrapper { + public: + using stream_t = cudaStream_t; + using event_t = cudaEvent_t; + + static size_t num_devices() { return 0; } + static size_t maximum_stream() { return 0; } + + static void CreateStream(stream_t* stream) {} + static void DestroyStream(const stream_t& stream) {} + + static void CreateEvent(event_t* event) {} + static void DestroyEvent(const event_t& event) {} + + static void RecordEvent(const event_t& event) {} + static void SyncEvent(const event_t& event) {} + + static void StreamSync(const stream_t& stream) {} + + static void* Malloc(size_t size); + static void Free(void* ptr); + + static void MemcpySync(void* dst, const void* src, size_t size, + IoDirection dir); + static void MemcpyAsync(void* dst, const void* src, size_t size, + IoDirection dir, const stream_t& stream); +}; +#endif // LITE_WITH_CUDA + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/type_system.cc b/paddle/fluid/lite/core/type_system.cc index b396a4e147c..a51bdb42891 100644 --- a/paddle/fluid/lite/core/type_system.cc +++ b/paddle/fluid/lite/core/type_system.cc @@ -87,20 +87,41 @@ const Type* Type::Get(TargetType target) { } } +template +const Type* GetTensorFp32NCHWTy() { + static TensorFp32NCHWTy x(Target); + return &x; +} + template <> const Type* Type::Get(TargetType target) { switch (target) { - case TargetType::kX86: - return Get(); - case TargetType::kHost: - return Get(); + case TARGET(kHost): + return GetTensorFp32NCHWTy(); + case TARGET(kCUDA): + return GetTensorFp32NCHWTy(); + case TARGET(kX86): + return GetTensorFp32NCHWTy(); + default: + LOG(FATAL) << "unsupported target Type " << TargetToStr(target); + } + return nullptr; +} +const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, + bool is_tensor, Place place) { + using id_t = DataTypeBase::ID; + switch (type_id) { + case id_t::Tensor_Any: + return Type::Get(place.target); + case id_t::Tensor_Fp32_NCHW: + return Type::Get(place.target); + case id_t::TensorList_Any: + return Type::Get(place.target); default: - LOG(FATAL) << "unsupported target " << TargetToStr(target); - return nullptr; + LOG(FATAL) << "unsupported type"; } + return nullptr; } // ------------------------- end GetType specification ------------------------ diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index c8c0c0f5c12..33181eeace2 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -131,6 +131,23 @@ class Type : public DataTypeBase { bool operator==(const Type& other) { return id_ == other.id() && place_ == other.place(); } + friend std::ostream& operator<<(std::ostream& os, const Type& other) { + if (other.IsUnsupported()) { + os << ""; + return os; + } + if (other.IsVoid()) { + os << ""; + return os; + } + if (other.is_tensor_) { + os << ""; + return os; + } // Can cast to another type. This is heavily used in MIR, by determine whether // is is possible to add a instruction to transform a type to another. @@ -163,29 +180,33 @@ class Type : public DataTypeBase { }; // -------------------------------- compatible check --------------------------- -static bool TargetCompatible(const Type& a, const Type& b) { - return (a.IsVoid() || b.IsVoid()) || // - a.target() == b.target(); +static bool TargetCompatibleTo(const Type& a, const Type& b) { + return a.IsVoid() || // + (a.IsTensor() && b.IsTensor() && (a.target() == b.target() || // + b.target() == TARGET(kAny))); } -static bool DataLayoutCompatible(const Type& a, const Type& b) { - return (a.IsVoid() || b.IsVoid()) || // - (a.IsTensor() && b.IsTensor() && a.layout() == b.layout()); +static bool DataLayoutCompatibleTo(const Type& a, const Type& b) { + return a.IsVoid() || // + (a.IsTensor() && b.IsTensor() && (a.layout() == b.layout() || // + b.layout() == DATALAYOUT(kAny))); } -static bool PrecisionCompatible(const Type& a, const Type& b) { - return (a.IsVoid() || b.IsVoid()) || // - (a.precision() == b.precision()); +static bool PrecisionCompatibleTo(const Type& a, const Type& b) { + return a.IsVoid() || // + (a.IsTensor() && b.IsTensor() && (a.precision() == b.precision() || // + b.precision() == PRECISION(kAny))); } -static bool DeviceCompatible(const Type& a, const Type& b) { - return (a.IsVoid() || b.IsVoid()) || // - (a.device() == b.device()); +static bool DeviceCompatibleTo(const Type& a, const Type& b) { + return a.IsVoid() || // + (a.IsTensor() && b.IsTensor() && (a.device() == b.device())); } -static bool TypeCompatible(const Type& a, const Type& b) { - return TargetCompatible(a, b) && DataLayoutCompatible(a, b) && - PrecisionCompatible(a, b) && DeviceCompatible(a, b); +// Can type 'a' be passed to 'b' directly. +static bool TypeCompatibleTo(const Type& a, const Type& b) { + return TargetCompatibleTo(a, b) && DataLayoutCompatibleTo(a, b) && + PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b); } // -------------------------------- predefined types --------------------------- @@ -230,6 +251,9 @@ class TensorInt64NCHWTy : public Type { : Type(ID::Tensor_Int64_NCHW, "TensorInt64NCHW", true /*is_tensor*/, target, PrecisionType::kInt8, DataLayoutType::kNCHW) {} }; + +const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown, + bool is_tensor, Place place); // ------------------------- end predefined types --------------------------- // NOTE TypeSystem has some overhead, and better to be used in analysis phase. @@ -381,13 +405,15 @@ class ParamTypeRegistry { CHECK(types_.count(key)); } - template - const ParamType* Retrieve(const Place& place, const std::string& op_type, - const std::string& arg_name) { - KernelIdTy key{op_type, place, io, arg_name}; - auto it = types_.find(key); - if (it == types_.end()) return nullptr; - return &it->second; + const ParamType* RetrieveInArgument(const Place& place, + const std::string& op_type, + const std::string& arg_name) { + return Retrieve(place, op_type, arg_name); + } + const ParamType* RetrieveOutArgument(const Place& place, + const std::string& op_type, + const std::string& arg_name) { + return Retrieve(place, op_type, arg_name); } static ParamTypeRegistry& Global() { @@ -403,6 +429,16 @@ class ParamTypeRegistry { return os; } + protected: + template + const ParamType* Retrieve(const Place& place, const std::string& op_type, + const std::string& arg_name) { + KernelIdTy key{op_type, place, io, arg_name}; + auto it = types_.find(key); + if (it == types_.end()) return nullptr; + return &it->second; + } + private: ParamTypeRegistry() = default; diff --git a/paddle/fluid/lite/core/types.cc b/paddle/fluid/lite/core/types.cc index f616a7d7f49..b94ac8c5c96 100644 --- a/paddle/fluid/lite/core/types.cc +++ b/paddle/fluid/lite/core/types.cc @@ -43,6 +43,9 @@ bool KernelPickFactor::IsTargetConsidered() const { bool KernelPickFactor::IsDataLayoutConsidered() const { return data_ & static_cast(Factor::DataLayoutFirst); } +bool KernelPickFactor::IsDeviceConsidered() const { + return data_ & static_cast(Factor::DeviceFirst); +} } // namespace core } // namespace lite diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h index 7562a911c8d..4b542814cdb 100644 --- a/paddle/fluid/lite/core/types.h +++ b/paddle/fluid/lite/core/types.h @@ -14,6 +14,7 @@ #pragma once +#include #include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/utils/all.h" @@ -38,6 +39,7 @@ class KernelPickFactor { bool AnyFactorConsidered() const { return data_; } KernelPickFactor& ConsiderTarget(); + // Perfer a specific target, e.g. prefer CUDA kernels. KernelPickFactor& ConsiderPrecision(); KernelPickFactor& ConsiderDataLayout(); KernelPickFactor& ConsiderDevice(); @@ -45,12 +47,29 @@ class KernelPickFactor { bool IsTargetConsidered() const; bool IsPrecisionConsidered() const; bool IsDataLayoutConsidered() const; - bool IsDeviceConsidered() const { - return data_ & static_cast(Factor::DeviceFirst); + bool IsDeviceConsidered() const; + + friend std::ostream& operator<<(std::ostream& os, const KernelPickFactor& k) { + std::stack bits; + auto data = k.data_; + while (data) { + bits.push(data % 2); + data /= 2; + } + int nbits = bits.size(); + for (size_t i = 0; i < sizeof(data) * 8 - nbits; i++) { + os << 0; + } + while (!bits.empty()) { + os << bits.top(); + bits.pop(); + } + return os; } private: unsigned char data_{}; + TargetType target_{TARGET(kUnk)}; }; struct dim2 { diff --git a/paddle/fluid/lite/cuda/target_wrapper.cc b/paddle/fluid/lite/cuda/target_wrapper.cc index cca9a95d8e0..7e6936d613c 100644 --- a/paddle/fluid/lite/cuda/target_wrapper.cc +++ b/paddle/fluid/lite/cuda/target_wrapper.cc @@ -12,10 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// -// Created by chunwei on 19-2-23. -// - #include "paddle/fluid/lite/cuda/target_wrapper.h" #include @@ -24,19 +20,14 @@ namespace lite { using TargetW = TargetWrapper; -template <> void* TargetW::Malloc(size_t size) { void* ptr{}; CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size)); return ptr; } -template <> -void TargetW::Free(void* ptr) { - CHECK_EQ(cudaSuccess, cudaFree(ptr)); -} +void TargetW::Free(void* ptr) { CHECK_EQ(cudaSuccess, cudaFree(ptr)); } -template <> void TargetW::MemcpySync(void* dst, const void* src, size_t size, IoDirection dir) { switch (dir) { @@ -55,7 +46,6 @@ void TargetW::MemcpySync(void* dst, const void* src, size_t size, } } -template <> void TargetW::MemcpyAsync(void* dst, const void* src, size_t size, IoDirection dir, const stream_t& stream) { switch (dir) { diff --git a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt index 9a435e45e7f..64ea90b0afe 100644 --- a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt @@ -1,2 +1,4 @@ -cc_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite) +nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite) cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite) + +nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda) diff --git a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc index 9503d41f796..5705feae922 100644 --- a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc +++ b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc @@ -21,7 +21,7 @@ namespace lite { namespace kernels { namespace cuda { -using TargetW = TargetWrapper; +using TargetW = TargetWrapper; // Host to CUDA memory. void CopyFromHostSync(void* target, const void* source, size_t size) { @@ -51,6 +51,25 @@ class IoCopyHostToCudaCompute auto* data = param.y->mutable_data(target(), param.x->memory_size()); CopyFromHostSync(data, param.x->data(), param.x->memory_size()); } + + std::unique_ptr GetTypeInferHandler() override { + std::unique_ptr res(new type_infer_handler_t); + *res = [](const std::map& inputs, + const std::string& out) -> const Type* { + CHECK(!inputs.empty()); + auto* type = inputs.at("Input"); + CHECK(type->target() == TARGET(kHost)); + + auto out_place = type->place(); + out_place.target = TARGET(kCUDA); + auto* out_type = LookupType(type->id(), type->IsUnsupported(), + type->IsUnsupported(), out_place); + return out_type; + }; + return res; + } + + std::string doc() const override { return "Copy IO from HOST to CUDA"; } }; /* @@ -65,6 +84,8 @@ class IoCopyCudaToHostCompute auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size()); CopyToHostSync(data, param.x->data(), param.x->memory_size()); } + + std::string doc() const override { return "Copy IO from CUDA to HOST"; } }; } // namespace cuda @@ -72,7 +93,7 @@ class IoCopyCudaToHostCompute } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, +REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, host_to_device) .BindInput("Input", {paddle::lite::Type::Get( @@ -81,7 +102,7 @@ REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, TARGET(kCUDA))}) .Finalize(); -REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, +REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, device_to_host) .BindInput("Input", {paddle::lite::Type::Get( diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.cc b/paddle/fluid/lite/kernels/cuda/mul_compute.cc index c80851bf649..f5081d2baa9 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.cc +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.cc @@ -13,3 +13,22 @@ // limitations under the License. #include "paddle/fluid/lite/kernels/cuda/mul_compute.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda {} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, + paddle::lite::kernels::cuda::MulCompute, def) + .BindInput("X", {paddle::lite::Type::Get( + TARGET(kCUDA))}) + .BindInput("Y", {paddle::lite::Type::Get( + TARGET(kCUDA))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kCUDA))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h index e8d65a93daa..ad39e2eae8a 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.h +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h @@ -16,6 +16,7 @@ #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/cuda/blas.h" +#include "paddle/fluid/lite/operators/op_params.h" namespace paddle { namespace lite { @@ -29,11 +30,29 @@ void mul_compute(const lite::cuda::Blas& blas, const T* x, int x_h, nullptr, out, 0); } -class MulCompute : public OpKernel { +class MulCompute : public OpKernel { public: using param_t = operators::MulParam; - void Run() override {} + void Run() override { + CHECK(context_) << "running context should be set first"; + auto& context = context_->AsCudaContext(); + CHECK(context.blas_fp32) << "blas should init first"; + auto& blas = *context.blas_fp32; + const auto& param = Param(); + CHECK(param.x->target() == TARGET(kCUDA)); + auto* x = param.x->data(); + int x_h = param.x->dims()[0]; + int x_w = param.x->dims()[1]; + + auto* y = param.y->data(); + int y_h = param.y->dims()[0]; + int y_w = param.y->dims()[1]; + + auto* out = param.output->mutable_data(TARGET(kCUDA)); + + mul_compute(blas, x, x_h, x_w, y, y_h, y_w, out); + } virtual ~MulCompute() = default; }; diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc index ac9a4ccc0a7..f63ee3958c9 100644 --- a/paddle/fluid/lite/kernels/host/fc_compute.cc +++ b/paddle/fluid/lite/kernels/host/fc_compute.cc @@ -51,8 +51,8 @@ void FcCompute::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute, - def) +REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW, + paddle::lite::kernels::host::FcCompute, def) .BindInput("Input", {paddle::lite::Type::Get( TARGET(kHost))}) diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index d727f0c22ac..b1cef73276b 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -20,7 +20,8 @@ namespace lite { namespace kernels { namespace host { -class FeedCompute : public OpKernel { +class FeedCompute + : public OpKernel { public: using param_t = operators::FeedParam; @@ -38,7 +39,7 @@ class FeedCompute : public OpKernel { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(feed, kHost, kFloat, +REGISTER_LITE_KERNEL(feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def) .BindInput("X", {paddle::lite::Type::Get( TARGET(kHost))}) diff --git a/paddle/fluid/lite/kernels/host/fetch_compute.cc b/paddle/fluid/lite/kernels/host/fetch_compute.cc index b3193a01942..4d640feff37 100644 --- a/paddle/fluid/lite/kernels/host/fetch_compute.cc +++ b/paddle/fluid/lite/kernels/host/fetch_compute.cc @@ -20,7 +20,8 @@ namespace lite { namespace kernels { namespace host { -class FetchCompute : public OpKernel { +class FetchCompute + : public OpKernel { public: using param_t = operators::FeedParam; @@ -41,7 +42,7 @@ class FetchCompute : public OpKernel { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(fetch, kHost, kFloat, +REGISTER_LITE_KERNEL(fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def) .BindInput("X", {paddle::lite::Type::Get( TARGET(kHost))}) diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc index ee7b168503a..2f6a95c83ba 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/host/mul_compute.cc @@ -67,7 +67,7 @@ class MulCompute : public OpKernel { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(mul, kHost, kFloat, +REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW, paddle::lite::kernels::host::MulCompute, def) .BindInput("X", {paddle::lite::Type::Get( TARGET(kHost))}) diff --git a/paddle/fluid/lite/kernels/host/relu_compute.h b/paddle/fluid/lite/kernels/host/relu_compute.h index b8176377dcc..8afe1819184 100644 --- a/paddle/fluid/lite/kernels/host/relu_compute.h +++ b/paddle/fluid/lite/kernels/host/relu_compute.h @@ -42,6 +42,6 @@ class ReluCompute : public OpKernel { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(relu, kHost, kFloat, +REGISTER_LITE_KERNEL(relu, kHost, kFloat, kNCHW, paddle::lite::kernels::host::ReluCompute, def) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc index 7ad6ab28818..4b260aecf87 100644 --- a/paddle/fluid/lite/kernels/host/scale_compute.cc +++ b/paddle/fluid/lite/kernels/host/scale_compute.cc @@ -50,7 +50,7 @@ class ScaleCompute : public OpKernel { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(scale, kHost, kFloat, +REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW, paddle::lite::kernels::host::ScaleCompute, def) .BindInput("X", {paddle::lite::Type::Get( TARGET(kHost))}) diff --git a/paddle/fluid/lite/operators/io_copy_op.cc b/paddle/fluid/lite/operators/io_copy_op.cc index 0169d85ba14..c9a71160731 100644 --- a/paddle/fluid/lite/operators/io_copy_op.cc +++ b/paddle/fluid/lite/operators/io_copy_op.cc @@ -24,7 +24,10 @@ bool IoCopyOp::CheckShape() const { CHECK_OR_FALSE(param_.y); return true; } -bool IoCopyOp::InferShape() const { return true; } +bool IoCopyOp::InferShape() const { + param_.y->Resize(param_.x->dims()); + return true; +} bool IoCopyOp::Run() { return OpLite::Run(); } bool IoCopyOp::AttachImpl(const paddle::framework::OpDesc &opdesc, paddle::lite::Scope *scope) { diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 36770cb4d56..a0e91ba9865 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -51,9 +51,6 @@ class MulOpLite : public OpLite { param_.x_num_col_dims = boost::get(op_desc.GetAttr("x_num_col_dims")); param_.y_num_col_dims = boost::get(op_desc.GetAttr("y_num_col_dims")); - CHECK(kernel_); - kernel_->SetParam(param_); - return true; } diff --git a/paddle/fluid/lite/utils/factory.h b/paddle/fluid/lite/utils/factory.h index cc00e42651b..f8386c08de5 100644 --- a/paddle/fluid/lite/utils/factory.h +++ b/paddle/fluid/lite/utils/factory.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include #include @@ -48,8 +49,6 @@ class Factory { } void Register(const std::string& op_type, creator_t&& creator) { - CHECK(!creators_.count(op_type)) << "The op " << op_type - << " has already registered"; creators_[op_type].emplace_back(std::move(creator)); } @@ -58,9 +57,9 @@ class Factory { } std::list Creates(const std::string& op_type) const { - auto it = creators_.find(op_type); - CHECK(it != creators_.end()) << "no item called " << op_type; std::list res; + auto it = creators_.find(op_type); + if (it == creators_.end()) return res; for (auto& c : it->second) { res.emplace_back(c()); } diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h index ab13a647ec3..6a4b33a4fc6 100644 --- a/paddle/fluid/lite/utils/varient.h +++ b/paddle/fluid/lite/utils/varient.h @@ -99,7 +99,7 @@ struct variant { size_t type() { return type_id; } - void valid() { return (type_id != invalid_type()); } + bool valid() { return (type_id != invalid_type()); } template void set(Args&&... args) { diff --git a/paddle/fluid/lite/utils/varient_test.cc b/paddle/fluid/lite/utils/varient_test.cc index d9a98448dec..50eb632c5cb 100644 --- a/paddle/fluid/lite/utils/varient_test.cc +++ b/paddle/fluid/lite/utils/varient_test.cc @@ -24,6 +24,8 @@ namespace utils { TEST(varient, test) { variant a; + // The initial state should be invalid. + ASSERT_FALSE(a.valid()); a.set(1); ASSERT_EQ(a.get(), 1); a.set(20); -- GitLab