// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/jit/layer.h" #include "paddle/fluid/framework/variable.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" #include "paddle/fluid/jit/compilation_unit.h" #include "paddle/fluid/jit/engine/base_engine.h" #include "paddle/fluid/jit/function.h" #include "paddle/fluid/jit/function_schema.h" namespace paddle { namespace jit { Layer::Layer(const Name2VariableMap& params_map, const Name2VariableMap& attrs_map, const Name2FunctionInfoMap& info_map, const phi::Place& place) : params_map_(params_map), attrs_map_(attrs_map), info_map_(info_map) { unit_.reset(new CompilationUnit()); } jit::Function Layer::Function(const std::string& name) const { return jit::Function(unit_->GetEngine(name).get()); } std::vector Layer::forward(const std::vector& inputs) { auto func = this->Function("forward"); return func(inputs); } std::vector Layer::forward( const std::vector& inputs) { auto func = this->Function("forward"); return func(inputs); } void Layer::to(const phi::Place& place) {} void Layer::SetEngine(const std::string& name, const std::shared_ptr& engine) { unit_->SetEngine(name, engine); } const Name2EngineMap& Layer::EngineMap() const { return unit_->EngineMap(); } const std::shared_ptr& Layer::FunctionInfo( const std::string& name) const { PADDLE_ENFORCE_EQ( info_map_.count(name), 1, phi::errors::InvalidArgument( "FuncitonInfo named %s is not exist in info_map_.", name)); return info_map_.at(name); } #define PD_SPECIALZE_ATTRIBUTE_TYPE(T) \ template <> \ T Layer::Attribute(const std::string& name) const { \ if (attrs_map_.find(name) == attrs_map_.end()) { \ PADDLE_THROW(phi::errors::NotFound( \ "Attribute can not found %s, please check if it exists.")); \ return T(); \ } \ auto var = attrs_map_.at(name); \ T ret = var->Get(); \ return ret; \ } PD_SPECIALZE_ATTRIBUTE_TYPE(int) PD_SPECIALZE_ATTRIBUTE_TYPE(float) PD_SPECIALZE_ATTRIBUTE_TYPE(std::string) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector) } // namespace jit } // namespace paddle