// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/lite/core/op_lite.h" #include #include #include #include #include "paddle/fluid/lite/core/op_registry.h" namespace paddle { namespace lite { std::vector> OpLite::CreateKernels( const std::vector &places, const std::string &kernel_type) { std::vector> kernels; CHECK(!op_type_.empty()) << "op_type_ should be set first"; auto pick_kernel = [&](const Place &place) { auto ks = KernelRegistry::Global().Create(op_type_, place.target, place.precision, place.layout); for (auto &&it : ks) { AttachKernel(it.get()); kernels.emplace_back(std::move(it)); } }; if (!kernel_type.empty()) { Place place; std::string op_type, alias; KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); pick_kernel(place); CHECK(!kernels.empty()) << "no kernel for kernel type " << kernel_type; return kernels; } std::set place_set; for (auto place : places) { place_set.insert(place); // Pick kernels those support any Precision and any DataLayout place.precision = PRECISION(kAny); place_set.insert(place); place.layout = DATALAYOUT(kAny); place_set.insert(place); } std::set targets; for (auto place : place_set) { pick_kernel(place); targets.insert(place.target); } VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; } bool OpLite::Run() { CHECK(kernel_); SyncInputEvents(); kernel_->Launch(); RecordOutputEvents(); return true; } bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { // valid_places_.clear(); CHECK(scope != nullptr); // CHECK(!op_info_.get()); scope_ = scope; op_info_.reset( new OpInfo(opdesc)); // Force clean the out-of-date infomation. return AttachImpl(*op_info(), scope); } const Tensor *OpLite::GetTensor(lite::Scope *scope, const std::string &name) const { auto *var = scope->FindVar(name); CHECK(var) << "no variable called " << name << " found"; return &var->Get(); } Tensor *OpLite::GetMutableTensor(lite::Scope *scope, const std::string &name) const { auto *var = scope->FindVar(name); CHECK(var) << "no variable called " << name << " found"; return var->GetMutable(); } } // namespace lite } // namespace paddle