// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/imperative/prepared_operator.h" #include #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/imperative/infer_shape_context.h" #include "paddle/fluid/imperative/infer_var_type_context.h" namespace paddle { namespace imperative { const framework::Tensor* GetTensorFromVar(const framework::Variable& var) { if (var.IsType()) { return &(var.Get()); } else if (var.IsType()) { return &(var.Get().value()); } else { return nullptr; } } template static void HandleComplexGradToRealGrad(const NameVarMap& outs) { for (auto& pair : outs) { for (auto& var : pair.second) { if (var == nullptr) { continue; } if (var->ForwardDataType() == static_cast(-1)) { VLOG(6) << "Var (" << var->Name() << ")'s forward data type is not set."; continue; } if (!framework::IsComplexType(var->DataType()) || framework::IsComplexType(var->ForwardDataType())) { continue; } const auto* tensor = GetTensorFromVar(var->Var()); if (tensor && tensor->IsInitialized()) { VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType()) << " var `" << var->Name() << "` to " << framework::DataTypeToString(var->ForwardDataType()) << " real var in dynamic graph."; framework::Tensor out; framework::TransComplexToReal(var->ForwardDataType(), var->DataType(), *tensor, &out); SetTensorToVariable(var->Var(), out, var->MutableVar()); } } } } PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), kernel_type_(kernel_type), func_(func), dev_ctx_(dev_ctx) {} PreparedOp PreparedOp::Prepare( const framework::OperatorWithKernel& op, const framework::OpKernelType& expected_kernel_key) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(expected_kernel_key.place_); // check if op[type] has kernel registered. auto& all_op_kernels = op.AllOpKernels(); auto kernels_iter = all_op_kernels.find(op.Type()); PADDLE_ENFORCE_NE( kernels_iter, all_op_kernels.end(), platform::errors::NotFound( "There are no kernels which are registered in the %s operator.", op.Type())); auto& kernels = kernels_iter->second; framework::RuntimeContext ctx({}, {}); auto kernel_iter = kernels.find(expected_kernel_key); // TODO(jiabin): Add operator.cc's line 1000 part back when we need that case PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), platform::errors::NotFound( "Operator %s does not have kernel for %s.", op.Type(), KernelTypeToString(expected_kernel_key))); return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx); } template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { // TODO(zjl): remove scope in dygraph framework::Scope scope; DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, op.Type()); static_cast(op).InferShape( &infer_shape_ctx); func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs)); /** * [ Why need handle complex gradient to real gradient? ] * * After the introduction of complex number calculations, Ops that support * complex number calculations generally support type promotion, such as * x(float32) + y(complex64) = out(complex64), then the type of the grad * tensor should be dout(complex64), dx(float32), dy (complex64). * * But because the dout is complex64, the dx is also complex64 after * grad op kernel executed, we need to recognize this situation and * convert dx to float32 type. HandleComplexGradToRealGrad does this thing. */ if (framework::IsComplexType(kernel_type.data_type_)) { HandleComplexGradToRealGrad(outs); } } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs); } } // namespace imperative } // namespace paddle