// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include <memory> #include <string> #include <utility> #include <vector> #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/pten/api/include/core.h" DECLARE_bool(use_mkldnn); namespace paddle { namespace framework { class Tensor; class Variable; } // namespace framework namespace platform { class DeviceContext; } // namespace platform } // namespace paddle namespace paddle { namespace imperative { const framework::Tensor* GetTensorFromVar(const framework::Variable& var); template <typename VarType> static void SetForwardDataTypeOfGradVar(const std::shared_ptr<VarType>& var); template <> void SetForwardDataTypeOfGradVar<VariableWrapper>( const std::shared_ptr<VariableWrapper>& var) { if (var->HasGradVar()) { auto grad_var = var->GetGradVar(); VLOG(6) << "Set grad var (" << grad_var->Name() << ")'s forward dtype to (" << framework::DataTypeToString(var->DataType()) << ")."; grad_var->SetForwardDataType(var->DataType()); } } template <> void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) { if (var->HasGradVar()) { auto& shared_var = var->SharedVar(); SetForwardDataTypeOfGradVar<VariableWrapper>(shared_var); } } extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<paddle::imperative::VarBase>& var); extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<VariableWrapper>& var); template <typename VarType> std::shared_ptr<NameVarMap<VarType>> PrepareData( const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins, const framework::OpKernelType& expected_kernel_key) { std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr; for (const auto& name_pair : ins) { for (size_t i = 0; i < name_pair.second.size(); ++i) { auto& var_base = name_pair.second[i]; SetForwardDataTypeOfGradVar(var_base); const auto* tensor = GetTensorFromVar(var_base->Var()); if (tensor && tensor->IsInitialized()) { auto kernel_type_for_var = op.GetKernelTypeForVar( name_pair.first, *tensor, expected_kernel_key); if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { continue; } else { VLOG(3) << "Transform Variable " << var_base->Name() << " from " << kernel_type_for_var << " to " << expected_kernel_key; if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) { VLOG(3) << "Hit variable_wrapper cache: key=" << expected_kernel_key; std::shared_ptr<VariableWrapper> cache_var = GetVariableWrapper(var_base)->getCacheValue( expected_kernel_key); if (tmp_ins_ptr == nullptr) { tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins); } const auto* tensor = GetTensorFromVar(cache_var->Var()); auto tmp_var = std::make_shared<VarType>(var_base->Name()); tmp_var->SetType(var_base->Type()); SetTensorToVariable(cache_var->Var(), *tensor, tmp_var->MutableVar()); (*tmp_ins_ptr)[name_pair.first][i] = tmp_var; } else { framework::Tensor out; TransformData(expected_kernel_key, kernel_type_for_var, *tensor, &out); if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) { // To avoid NameVarMap copy construction overhead in general // scenarios, if inplace transformed, return original input // directly if (tmp_ins_ptr == nullptr) { tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins); } auto tmp_var = std::make_shared<VarType>(var_base->Name()); tmp_var->SetType(var_base->Type()); SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar()); (*tmp_ins_ptr)[name_pair.first][i] = tmp_var; GetVariableWrapper(var_base)->setCacheValue( expected_kernel_key, GetVariableWrapper(tmp_var)); VLOG(3) << "Set cache to variable_wrapper: key=" << expected_kernel_key; } else { // if dtype is same, transform inplace will not change the // original // value, transform inplace to avoid multiple copy SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); } } } } } } return tmp_ins_ptr; } class PreparedOp { public: PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx); PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::KernelSignature& kernel_signature, const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); static PreparedOp Prepare(const NameVarMap<VarBase>& ins, const NameVarMap<VarBase>& outs, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins, const NameVarMap<VariableWrapper>& outs, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); void Run(const NameVarMap<VariableWrapper>& ins, const NameVarMap<VariableWrapper>& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); const framework::OpKernelType& kernel_type() const { return kernel_type_; } private: const framework::OperatorBase& op_; const framework::RuntimeContext& ctx_; framework::OpKernelType kernel_type_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; // NOTE(chenweihang): Similar op members are used to adapt to // new pten kernel, if there is a better design in the future, // we may polish the implementation here bool run_pten_kernel_{false}; framework::KernelSignature pt_kernel_signature_; pten::Kernel pt_kernel_; }; } // namespace imperative } // namespace paddle