// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/pten/api/include/core.h" DECLARE_bool(use_mkldnn); namespace paddle { namespace framework { class Tensor; class Variable; } // namespace framework namespace platform { class DeviceContext; } // namespace platform } // namespace paddle namespace paddle { namespace imperative { const framework::Tensor* GetTensorFromVar(const framework::Variable& var); template static void SetForwardDataTypeOfGradVar(const std::shared_ptr& var); template <> void SetForwardDataTypeOfGradVar( const std::shared_ptr& var) { if (var->HasGradVar()) { auto grad_var = var->GetGradVar(); VLOG(6) << "Set grad var (" << grad_var->Name() << ")'s forward dtype to (" << framework::DataTypeToString(var->DataType()) << ")."; grad_var->SetForwardDataType(var->DataType()); } } template <> void SetForwardDataTypeOfGradVar(const std::shared_ptr& var) { if (var->HasGradVar()) { auto& shared_var = var->SharedVar(); SetForwardDataTypeOfGradVar(shared_var); } } extern const std::shared_ptr& GetVariableWrapper( const std::shared_ptr& var); extern const std::shared_ptr& GetVariableWrapper( const std::shared_ptr& var); template std::shared_ptr> PrepareData( const framework::OperatorWithKernel& op, const NameVarMap& ins, const framework::OpKernelType& expected_kernel_key) { std::shared_ptr> tmp_ins_ptr = nullptr; for (const auto& name_pair : ins) { for (size_t i = 0; i < name_pair.second.size(); ++i) { auto& var_base = name_pair.second[i]; SetForwardDataTypeOfGradVar(var_base); const auto* tensor = GetTensorFromVar(var_base->Var()); if (tensor && tensor->IsInitialized()) { auto kernel_type_for_var = op.GetKernelTypeForVar( name_pair.first, *tensor, expected_kernel_key); if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { continue; } else { VLOG(3) << "Transform Variable " << var_base->Name() << " from " << kernel_type_for_var << " to " << expected_kernel_key; if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) { VLOG(3) << "Hit variable_wrapper cache: key=" << expected_kernel_key; std::shared_ptr cache_var = GetVariableWrapper(var_base)->getCacheValue( expected_kernel_key); if (tmp_ins_ptr == nullptr) { tmp_ins_ptr = std::make_shared>(ins); } const auto* tensor = GetTensorFromVar(cache_var->Var()); auto tmp_var = std::make_shared(var_base->Name()); tmp_var->SetType(var_base->Type()); SetTensorToVariable(cache_var->Var(), *tensor, tmp_var->MutableVar()); (*tmp_ins_ptr)[name_pair.first][i] = tmp_var; } else { framework::Tensor out; TransformData(expected_kernel_key, kernel_type_for_var, *tensor, &out); if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) { // To avoid NameVarMap copy construction overhead in general // scenarios, if inplace transformed, return original input // directly if (tmp_ins_ptr == nullptr) { tmp_ins_ptr = std::make_shared>(ins); } auto tmp_var = std::make_shared(var_base->Name()); tmp_var->SetType(var_base->Type()); SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar()); (*tmp_ins_ptr)[name_pair.first][i] = tmp_var; GetVariableWrapper(var_base)->setCacheValue( expected_kernel_key, GetVariableWrapper(tmp_var)); VLOG(3) << "Set cache to variable_wrapper: key=" << expected_kernel_key; } else { // if dtype is same, transform inplace will not change the // original // value, transform inplace to avoid multiple copy SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); } } } } } } return tmp_ins_ptr; } class PreparedOp { public: PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx); PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::KernelSignature& kernel_signature, const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); void Run(const NameVarMap& in, const NameVarMap& out, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); void Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs); const framework::OpKernelType& kernel_type() const { return kernel_type_; } private: const framework::OperatorBase& op_; const framework::RuntimeContext& ctx_; framework::OpKernelType kernel_type_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; // NOTE(chenweihang): Similar op members are used to adapt to // new pten kernel, if there is a better design in the future, // we may polish the implementation here bool run_pten_kernel_{false}; framework::KernelSignature pt_kernel_signature_; pten::Kernel pt_kernel_; }; } // namespace imperative } // namespace paddle