Unverified commit 34cce62f, authored by: J Jiabin Yang, committed by: GitHub

Merge legacy to fluid (#39318)

Parent commit: 01f606b4
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder accumulation_node) set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy) set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps dygraph_function dygraph_node) set(generated_deps dygraph_function dygraph_node)
...@@ -9,14 +9,12 @@ endif() ...@@ -9,14 +9,12 @@ endif()
add_subdirectory(api) add_subdirectory(api)
add_subdirectory(accumulation) add_subdirectory(accumulation)
add_subdirectory(legacy)
cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api) cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api) cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils) cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
cc_library(legacy SRCS ${DYGRAPH_LEGACY} DEPS global_utils proto_desc operator pten pten_api op_registry variable_helper memcpy)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)
add_subdirectory(tests) add_subdirectory(tests)
fluid_generated/** fluid_generated/**
eager_generated/**
\ No newline at end of file
...@@ -1220,7 +1220,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1220,7 +1220,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// According to op_proto->attrs() // According to op_proto->attrs()
egr::legacy::RunOp("op_type", ins, outs, attr_map, Controller.Instance().GetCurrentTracer()->TraceOp("op_type", ins, outs,
attr_map,
Controller.Instance().GetExpectedPlace(), {}); Controller.Instance().GetExpectedPlace(), {});
// According to fwd_outputs_names // According to fwd_outputs_names
...@@ -1401,7 +1402,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1401,7 +1402,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_TRACE_OP_TEMPLATE = const char* FWD_TRACE_OP_TEMPLATE =
" paddle::framework::AttributeMap attrs = attr_map;\n" " paddle::framework::AttributeMap attrs = attr_map;\n"
" paddle::framework::AttributeMap default_attrs;\n" " paddle::framework::AttributeMap default_attrs;\n"
" egr::legacy::RunOp(\"%s\", ins, outs, attrs, \n" " egr::Controller::Instance().GetCurrentTracer()->TraceOp(\"%s\", ins, "
"outs, attrs, \n"
" egr::Controller::Instance().GetExpectedPlace(),\n" " egr::Controller::Instance().GetExpectedPlace(),\n"
" &default_attrs, true, {});\n"; " &default_attrs, true, {});\n";
std::string trace_op_str = std::string trace_op_str =
...@@ -1712,7 +1714,8 @@ static std::string GenerateSingleOpBase( ...@@ -1712,7 +1714,8 @@ static std::string GenerateSingleOpBase(
" // Pass the entire attribute map to TraceOp\n" " // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need " " // The underlying kernel will pickup whatever attribute they need "
"at runtime\n" "at runtime\n"
" egr::legacy::RunOp(\"%s\", %s, %s, %s,\n" " egr::Controller::Instance().GetCurrentTracer()->TraceOp(\"%s\", %s, "
"%s, %s,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n" " egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n"; " &this->default_attr_map_, false, {});\n";
std::string trace_opbase_str = paddle::string::Sprintf( std::string trace_opbase_str = paddle::string::Sprintf(
...@@ -1822,7 +1825,8 @@ static std::string GenerateGradNodeCCContents( ...@@ -1822,7 +1825,8 @@ static std::string GenerateGradNodeCCContents(
// Visit each OpBase // Visit each OpBase
for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) { for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) {
// Simply pass entire attribute map to kernels // Simply pass entire attribute map to kernels
egr::legacy::RunOp("iter->Type()", ins, outs, this->attr_map_, Controller.Instance().GetCurrentTracer()->TraceOp("iter->Type()", ins,
outs, this->attr_map_,
egr::Controller::Instance().ExpectedPlace(), false, {}); egr::Controller::Instance().ExpectedPlace(), false, {});
} }
...@@ -2054,6 +2058,7 @@ static std::string GenerateDygraphHFileIncludes() { ...@@ -2054,6 +2058,7 @@ static std::string GenerateDygraphHFileIncludes() {
"#include \"paddle/fluid/eager/autograd_meta.h\"\n" "#include \"paddle/fluid/eager/autograd_meta.h\"\n"
"#include \"paddle/pten/api/all.h\"\n" "#include \"paddle/pten/api/all.h\"\n"
"#include \"paddle/fluid/eager/utils.h\"\n" "#include \"paddle/fluid/eager/utils.h\"\n"
"#include \"paddle/fluid/imperative/tracer.h\"\n"
"#include \"paddle/fluid/framework/op_registry.h\"\n\n"; "#include \"paddle/fluid/framework/op_registry.h\"\n\n";
dygraph_forward_api_includes_str += dygraph_forward_api_includes_str +=
...@@ -2084,8 +2089,7 @@ static void GenerateForwardDygraphFile(const std::string& forward_cc_path, ...@@ -2084,8 +2089,7 @@ static void GenerateForwardDygraphFile(const std::string& forward_cc_path,
"dygraph_forward_api.h\"\n" "dygraph_forward_api.h\"\n"
"#include " "#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n\n" "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n\n"
"#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n";
"#include \"paddle/fluid/eager/legacy/op_runner.h\"\n";
std::string forward_cc_include_str = std::string forward_cc_include_str =
paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE); paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE);
std::ofstream forward_cc_stream(forward_cc_path, std::ios::out); std::ofstream forward_cc_stream(forward_cc_path, std::ios::out);
...@@ -2099,7 +2103,7 @@ static void GenerateNodeHFile(const std::string& node_h_path, ...@@ -2099,7 +2103,7 @@ static void GenerateNodeHFile(const std::string& node_h_path,
std::string node_h_include_str = std::string node_h_include_str =
"#pragma once\n" "#pragma once\n"
"#include \"paddle/fluid/eager/tensor_wrapper.h\"\n" "#include \"paddle/fluid/eager/tensor_wrapper.h\"\n"
"#include \"paddle/fluid/eager/legacy/op_runner.h\"\n" "#include \"paddle/fluid/imperative/tracer.h\"\n"
"#include \"paddle/fluid/eager/grad_node_info.h\"\n\n"; "#include \"paddle/fluid/eager/grad_node_info.h\"\n\n";
std::ofstream node_h_stream(node_h_path, std::ios::out); std::ofstream node_h_stream(node_h_path, std::ios::out);
node_h_stream << node_h_include_str; node_h_stream << node_h_include_str;
......
...@@ -845,7 +845,6 @@ def GenerateNodeHFile(filepath, node_declaration_str): ...@@ -845,7 +845,6 @@ def GenerateNodeHFile(filepath, node_declaration_str):
file_contents = """ file_contents = """
#pragma once #pragma once
#include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/eager/legacy/op_runner.h"
#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/grad_node_info.h"
""" """
...@@ -860,7 +859,6 @@ def GenerateForwardCCFile(filepath, forward_definition_str): ...@@ -860,7 +859,6 @@ def GenerateForwardCCFile(filepath, forward_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/legacy/op_runner.h"
""" """
file_contents += forward_definition_str file_contents += forward_definition_str
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
// pten deps // pten deps
#include "paddle/pten/api/all.h" #include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/api/lib/api_declare.h" #include "paddle/pten/api/lib/api_declare.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/compat/convert_utils.h" #include "paddle/pten/core/compat/convert_utils.h"
......
file(GLOB DYGRAPH_LEGACY "*.cpp" "*.cc")
set(DYGRAPH_LEGACY ${DYGRAPH_LEGACY} PARENT_SCOPE)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/legacy/amp_auto_cast.h"
#include <memory>
#include <string>
#include "paddle/fluid/eager/legacy/op_runner.h"
#include "paddle/fluid/eager/legacy/tensor_helper.h"
#include "paddle/fluid/framework/operator.h"
namespace egr {
namespace legacy {
// Builds the AMP op registries. The unsupported-fp16 set is derived from the
// global kernel registry: any op with no fp16 kernel registered on GPU or XPU
// is recorded as unsupported, so AMP later falls back to fp32 for it.
AmpOperators::AmpOperators()
    : allow_ops_(new std::unordered_set<std::string>()),
      block_ops_(new std::unordered_set<std::string>()),
      unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
  auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
  auto fp16_dtype = paddle::framework::proto::VarType::FP16;
  for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) {
    bool supported = false;
    for (auto& kernel_type : it->second) {
      if ((paddle::platform::is_gpu_place(kernel_type.first.place_) ||
           paddle::platform::is_xpu_place(kernel_type.first.place_)) &&
          kernel_type.first.data_type_ == fp16_dtype) {
        supported = true;
        // One fp16 kernel on GPU/XPU is enough; stop scanning this op's
        // remaining kernels (the original kept iterating needlessly).
        break;
      }
    }
    if (!supported) {
      unsupported_fp16_ops_->insert(it->first);
    }
  }
}
// Out-of-line empty destructor; the shared_ptr members release themselves.
AmpOperators::~AmpOperators() {}
// Meyers singleton: constructed (and the unsupported-fp16 scan performed)
// lazily on first use; initialization is thread-safe since C++11.
AmpOperators& AmpOperators::Instance() {
  static AmpOperators instance;
  return instance;
}
// The returned sets are shared and mutable: callers (e.g. Python-side AMP
// configuration) may modify them in place.
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableAllowOps() {
  return allow_ops_;
}
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableBlockOps() {
  return block_ops_;
}
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedFp16Ops() {
  return unsupported_fp16_ops_;
}
// Debug dump of the three AMP op sets, each element followed by a space.
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
  auto dump_set =
      [&os](const std::shared_ptr<std::unordered_set<std::string>>& names) {
        for (const auto& name : *names) {
          os << name << " ";
        }
      };
  os << "allow ops: ";
  dump_set(ops.GetMutableAllowOps());
  os << "\n";
  os << "block ops: ";
  dump_set(ops.GetMutableBlockOps());
  os << "\n";
  os << "unsupported fp16 ops: ";
  dump_set(ops.GetMutableUnsupportedFp16Ops());
  return os;
}
// Human-readable dtype name of the tensor's underlying variable; used only
// for the VLOG messages in the casting routines below.
inline std::string GetDtypeStr(
    const std::shared_ptr<egr::EagerTensor>& tensor) {
  return paddle::framework::DataTypeToString(
      egr::legacy::GetDtypeFromVar(tensor->Var()));
}
// A tensor participates in auto-cast only when it lives on a place with fp16
// support (GPU, CUDA-pinned, or XPU) and already holds fp32 or fp16 data.
// CUDA-pinned place is included for varbases created by the dataloader.
inline bool NeedCast(const std::shared_ptr<egr::EagerTensor>& tensor) {
  auto place = egr::legacy::GetPlaceFromVar(tensor->Var());
  auto data_type = egr::legacy::GetDtypeFromVar(tensor->Var());
  const bool place_supports_fp16 =
      paddle::platform::is_gpu_place(place) ||
      paddle::platform::is_cuda_pinned_place(place) ||
      paddle::platform::is_xpu_place(place);
  const bool dtype_is_float =
      data_type == paddle::framework::proto::VarType::FP32 ||
      data_type == paddle::framework::proto::VarType::FP16;
  return place_supports_fp16 && dtype_is_float;
}
// NOTE: Trace a cast op, so if a var is casted from fp32 to fp16, then the grad
// var will be cast back from fp16 to fp32 during backward phase.
static inline std::shared_ptr<egr::EagerTensor> CastToType(
    const std::shared_ptr<egr::EagerTensor>& tensor,
    const paddle::framework::proto::VarType::Type dst_type) {
  NameTensorMap ins = {{"X", {tensor}}};
  auto in_data_type = egr::legacy::GetDtypeFromVar(tensor->Var());
  paddle::framework::AttributeMap attrs = {{"in_dtype", in_data_type},
                                           {"out_dtype", dst_type}};
  // Fresh output tensor; the traced "cast" op fills its variable.
  auto out = std::shared_ptr<egr::EagerTensor>(new egr::EagerTensor());
  NameTensorMap outs = {{"Out", {out}}};
  {
    // Run the cast with AMP disabled (level O0), otherwise tracing the cast
    // op itself would recursively trigger another auto-cast. The guard's
    // scope is exactly this block: AMP level is restored before returning.
    AutoCastGuard guard(paddle::imperative::AmpLevel::O0);
    paddle::framework::AttributeMap default_attrs;
    RunOp("cast", ins, outs, std::move(attrs), {}, &default_attrs, true);
  }
  return out;
}
// Returns `tensor` cast to fp16 via a traced cast op, or the input unchanged
// when casting is unnecessary (unsupported place/dtype, or already fp16).
static inline std::shared_ptr<egr::EagerTensor> CastToFP16(
    const std::shared_ptr<egr::EagerTensor>& tensor) {
  const auto dst_type = paddle::framework::proto::VarType::FP16;
  const bool already_dst =
      egr::legacy::GetDtypeFromVar(tensor->Var()) == dst_type;
  if (!NeedCast(tensor) || already_dst) {
    return tensor;
  }
  return CastToType(tensor, dst_type);
}
// Returns `tensor` cast to fp32 via a traced cast op, or the input unchanged
// when casting is unnecessary (unsupported place/dtype, or already fp32).
static inline std::shared_ptr<egr::EagerTensor> CastToFP32(
    const std::shared_ptr<egr::EagerTensor>& tensor) {
  const auto dst_type = paddle::framework::proto::VarType::FP32;
  const bool already_dst =
      egr::legacy::GetDtypeFromVar(tensor->Var()) == dst_type;
  if (!NeedCast(tensor) || already_dst) {
    return tensor;
  }
  return CastToType(tensor, dst_type);
}
// Promotion rule for ops on neither the allow nor the block list: the result
// is FP32 if any input tensor holds fp32 data, otherwise FP16.
static inline paddle::framework::proto::VarType::Type GetPromoteType(
    const std::string& op_type, const NameTensorMap& ins) {
  auto dst_type = paddle::framework::proto::VarType::FP16;
  for (const auto& pair : ins) {
    for (const auto& tensor : pair.second) {
      if (egr::legacy::GetDtypeFromVar(tensor->Var()) ==
          paddle::framework::proto::VarType::FP32) {
        dst_type = egr::legacy::GetDtypeFromVar(tensor->Var());
        // NOTE(review): this break only exits the inner loop; the outer loop
        // keeps scanning, but dst_type can only be re-set to FP32 again, so
        // the result is unchanged.
        break;
      }
    }
  }
  // NOTE(juncai): moving_average_abs_max_scale only consider the
  // dtype of input(X)
  if (op_type == "moving_average_abs_max_scale") {
    for (const auto& pair : ins) {
      if (pair.first == "X" &&
          egr::legacy::GetDtypeFromVar(pair.second.front()->Var()) ==
              paddle::framework::proto::VarType::FP16) {
        dst_type = paddle::framework::proto::VarType::FP16;
      }
    }
  }
  return dst_type;
}
// AMP O1 input casting. Depending on which registry `op_type` appears in:
//  - allow list: inputs are cast to fp16 (norm ops: only their "X" input);
//  - block list: inputs are cast to fp32;
//  - neither:    inputs are promoted per GetPromoteType, falling back to fp32
//                when the op has no fp16 kernel.
// Returns a casted copy of `ins`; the argument itself is never modified.
NameTensorMap AutoCastInputs(const std::string& op_type,
                             const NameTensorMap& ins) {
  NameTensorMap new_ins(ins);
  if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
      if ((op_type == "batch_norm" || op_type == "layer_norm" ||
           op_type == "sync_batch_norm") &&
          pair.first != "X") {
        continue;
      }
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float16";
      for (auto& var : pair.second) {
        var = CastToFP16(var);
      }
    }
    return new_ins;
  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float";
      for (auto& var : pair.second) {
        var = CastToFP32(var);
      }
    }
    return new_ins;
  } else {
    auto dst_type = GetPromoteType(op_type, ins);
    // NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
    if (dst_type == paddle::framework::proto::VarType::FP16 &&
        AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(
            op_type)) {
      dst_type = paddle::framework::proto::VarType::FP32;
    }
    for (auto& pair : new_ins) {
      // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
      if ((op_type == "batch_norm" || op_type == "layer_norm" ||
           op_type == "sync_batch_norm") &&
          pair.first == "X" &&
          dst_type == paddle::framework::proto::VarType::FP32) {
        continue;
      }
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to "
              << paddle::framework::DataTypeToString(dst_type);
      for (auto& var : pair.second) {
        var = (dst_type == paddle::framework::proto::VarType::FP32
                   ? CastToFP32(var)
                   : CastToFP16(var));
      }
    }
    return new_ins;
  }
  // The trailing `return new_ins;` that previously followed this chain was
  // unreachable (every branch above returns) and has been removed.
}
// Pure-fp16 (O2) input casting: every input goes to fp16 unless the op is on
// the block list or lacks an fp16 kernel, in which case everything goes to
// fp32. Norm ops only allow their "X" input to change precision.
NameTensorMap CastPureFp16Inputs(const std::string& op_type,
                                 const NameTensorMap& ins) {
  NameTensorMap new_ins(ins);
  auto& amp_ops = AmpOperators::Instance();
  const bool fall_back_to_fp32 =
      amp_ops.GetMutableUnsupportedFp16Ops()->count(op_type) ||
      amp_ops.GetMutableBlockOps()->count(op_type);
  const auto dst_type = fall_back_to_fp32
                            ? paddle::framework::proto::VarType::FP32
                            : paddle::framework::proto::VarType::FP16;
  const bool is_norm_op = op_type == "batch_norm" || op_type == "layer_norm" ||
                          op_type == "sync_batch_norm";
  for (auto& pair : new_ins) {
    if (is_norm_op && pair.first != "X") {
      continue;
    }
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to "
            << paddle::framework::DataTypeToString(dst_type);
    for (auto& var : pair.second) {
      var = (dst_type == paddle::framework::proto::VarType::FP32
                 ? CastToFP32(var)
                 : CastToFP16(var));
    }
  }
  return new_ins;
}
} // namespace legacy
} // namespace egr
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_set>
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
namespace egr {
namespace legacy {
// Singleton registry of the three op-name sets that drive AMP (automatic
// mixed precision) casting decisions in eager legacy mode. The sets are
// exposed as shared, mutable pointers so callers can edit them in place.
class AmpOperators {
 public:
  ~AmpOperators();
  // Non-copyable singleton.
  AmpOperators(const AmpOperators& o) = delete;
  const AmpOperators& operator=(const AmpOperators& o) = delete;
  static AmpOperators& Instance();
  std::shared_ptr<std::unordered_set<std::string>> GetMutableAllowOps();
  std::shared_ptr<std::unordered_set<std::string>> GetMutableBlockOps();
  std::shared_ptr<std::unordered_set<std::string>>
  GetMutableUnsupportedFp16Ops();
 private:
  AmpOperators();  // forbid calling default constructor
  // The set of ops that support fp16 calculation and are considered numerically
  // safe and performance critical. These ops are always converted to fp16.
  std::shared_ptr<std::unordered_set<std::string>> allow_ops_;
  // The set of ops that support fp16 calculation and are considered numerically
  // dangerous and whose effects may also be observed in downstream ops.
  std::shared_ptr<std::unordered_set<std::string>> block_ops_;
  // The set of ops that has no fp16 CUDA kernel.
  std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;
};
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
// RAII guard: forces a given AMP level for the current scope and restores
// the previously active level on destruction.
class AutoCastGuard {
 public:
  explicit AutoCastGuard(paddle::imperative::AmpLevel guard_level)
      : pre_amp_level_(Controller::Instance().GetAMPLevel()) {
    if (guard_level != pre_amp_level_) {
      Controller::Instance().SetAMPLevel(guard_level);
    }
  }
  ~AutoCastGuard() { Controller::Instance().SetAMPLevel(pre_amp_level_); }
  // Non-copyable: a copy would restore the saved level a second time.
  AutoCastGuard(const AutoCastGuard& guard) = delete;
  AutoCastGuard& operator=(const AutoCastGuard& guard) = delete;
 private:
  paddle::imperative::AmpLevel pre_amp_level_;
};
// AMP O1 casting: allow-listed ops -> fp16, block-listed -> fp32, others
// promoted to the widest input dtype. Returns a casted copy of `ins`.
NameTensorMap AutoCastInputs(const std::string& op_type,
                             const NameTensorMap& ins);
// Pure-fp16 (O2) casting: fp16 everywhere unless the op is blocked or has no
// fp16 kernel. Returns a casted copy of `ins`.
NameTensorMap CastPureFp16Inputs(const std::string& op_type,
                                 const NameTensorMap& ins);
} // namespace legacy
} // namespace egr
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/variable.h"
namespace egr {
namespace legacy {
// Bridges eager-mode tensors into the static-graph kernel API: exposes a pair
// of NameTensorMaps (inputs/outputs) plus attribute maps through the
// paddle::framework::ExecutionContext interface, so existing fluid op kernels
// can run on eager tensors without Scope-based variable lookup.
// NOTE(review): every constructor argument is held by reference — the caller
// must keep them alive for the lifetime of this context.
class EagerExecutionContext : public paddle::framework::ExecutionContext {
  using Variable = paddle::framework::Variable;
 public:
  EagerExecutionContext(const paddle::framework::OperatorBase& op,
                        const paddle::framework::Scope& scope,
                        const paddle::platform::DeviceContext& device_context,
                        const paddle::framework::RuntimeContext& ctx,
                        const NameTensorMap& tensor_map_in,
                        const NameTensorMap& tensor_map_out,
                        const paddle::framework::AttributeMap& attrs,
                        const paddle::framework::AttributeMap& default_attrs)
      : ExecutionContext(op, scope, device_context, ctx),
        tensor_map_in_(tensor_map_in),
        tensor_map_out_(tensor_map_out),
        attrs_(attrs),
        default_attrs_(default_attrs) {}
  // Name of the first tensor bound to input slot `name`; throws if the slot
  // is absent, returns kEmptyVarName for a null binding.
  std::string InputName(const std::string& name) const override {
    auto it = tensor_map_in_.find(name);
    PADDLE_ENFORCE_NE(it, tensor_map_in_.end(),
                      paddle::platform::errors::PreconditionNotMet(
                          "Can not find [%s] in Input", name));
    // TODO(jiabin): This is used for egr::EagerTensor temporally,
    // once we have name, remove it.
    return it->second[0] ? it->second[0]->name()
                         : paddle::framework::kEmptyVarName;
  }
  // Names of all tensors in input slot `name`; null bindings map to
  // kEmptyVarName.
  std::vector<std::string> InputNames(const std::string& name) const override {
    auto it = tensor_map_in_.find(name);
    PADDLE_ENFORCE_NE(
        it, tensor_map_in_.end(),
        paddle::platform::errors::NotFound("Can not find [%s] in Input", name));
    std::vector<std::string> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        // TODO(jiabin): This is used for egr::EagerTensor
        // temporally, once we have name, remove it.
        vec_res.push_back(it->second[i]->name());
      } else {
        vec_res.push_back(paddle::framework::kEmptyVarName);
      }
    }
    return vec_res;
  }
  // Output-side counterpart of InputName.
  std::string OutputName(const std::string& name) const override {
    auto it = tensor_map_out_.find(name);
    PADDLE_ENFORCE_NE(it, tensor_map_out_.end(),
                      paddle::platform::errors::NotFound(
                          "Can not find [%s] in Output", name));
    return it->second[0] ? it->second[0]->name()
                         : paddle::framework::kEmptyVarName;
  }
  // Output-side counterpart of InputNames.
  std::vector<std::string> OutputNames(const std::string& name) const override {
    auto it = tensor_map_out_.find(name);
    PADDLE_ENFORCE_NE(it, tensor_map_out_.end(),
                      paddle::platform::errors::NotFound(
                          "Can not find [%s] in Output", name));
    std::vector<std::string> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      if (it->second[i]) {
        vec_res.push_back(it->second[i]->name());
      } else {
        vec_res.push_back(paddle::framework::kEmptyVarName);
      }
    }
    return vec_res;
  }
  // True when `name` exists as an explicit or a default attribute.
  bool HasAttr(const std::string& name) const override {
    return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
  }
  // Returns only the explicit attrs; defaults are consulted in GetAttr.
  const paddle::framework::AttributeMap& Attrs() const override {
    return attrs_;
  }
  // Explicit attrs take precedence over defaults; throws if the name is in
  // neither map.
  const paddle::framework::Attribute& GetAttr(
      const std::string& name) const override {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      it = default_attrs_.find(name);
      if (it == default_attrs_.end()) {
        PADDLE_THROW(paddle::platform::errors::NotFound(
            "Can not find [%s] in attributes of op %s.", name,
            this->GetOp().Type()));
      }
    }
    return it->second;
  }
  // All input slot names (not individual tensor names).
  std::vector<std::string> InNameList() const override {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(tensor_map_in_.size());
    for (auto& v : tensor_map_in_) {
      vec_temp.push_back(v.first);
    }
    return vec_temp;
  }
  // A slot is "present" only when it exists and is non-empty.
  bool HasInput(const std::string& name) const override {
    auto it = tensor_map_in_.find(name);
    return (it != tensor_map_in_.end() && it->second.size() > 0);
  }
  bool HasOutput(const std::string& name) const override {
    auto it = tensor_map_out_.find(name);
    return (it != tensor_map_out_.end() && it->second.size() > 0);
  }
  size_t InputSize(const std::string& name) const override {
    return InputNames(name).size();
  }
  size_t OutputSize(const std::string& name) const override {
    return OutputNames(name).size();
  }
  // First variable of input slot `name`; nullptr when the slot is missing,
  // empty, or holds a null tensor.
  const Variable* InputVar(const std::string& name) const override {
    auto it = tensor_map_in_.find(name);
    if (it == tensor_map_in_.end()) {
      return nullptr;
    }
    return it->second.empty() || it->second[0] == nullptr
               ? nullptr
               : it->second[0]->MutableVar();
  }
  Variable* OutputVar(const std::string& name) const override {
    auto it = tensor_map_out_.find(name);
    if (it == tensor_map_out_.end()) {
      return nullptr;
    }
    return it->second.empty() || it->second[0] == nullptr
               ? nullptr
               : it->second[0]->MutableVar();
  }
  // All variables of input slot `name`; null tensors become nullptr entries,
  // a missing slot yields an empty vector.
  const std::vector<Variable*> MultiInputVar(
      const std::string& name) const override {
    auto it = tensor_map_in_.find(name);
    if (it == tensor_map_in_.end()) {
      return {};
    }
    std::vector<Variable*> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
    }
    return vec_res;
  }
  std::vector<Variable*> MultiOutputVar(
      const std::string& name) const override {
    auto it = tensor_map_out_.find(name);
    if (it == tensor_map_out_.end()) {
      return {};
    }
    std::vector<Variable*> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
    }
    return vec_res;
  }
 private:
  // Borrowed references; see the class comment for lifetime requirements.
  const NameTensorMap& tensor_map_in_;
  const NameTensorMap& tensor_map_out_;
  const paddle::framework::AttributeMap& attrs_;
  const paddle::framework::AttributeMap& default_attrs_;
};
} // namespace legacy
} // namespace egr
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type.h"
namespace egr {
namespace legacy {
// InferShapeContext implementation over eager NameTensorMaps, so static-graph
// InferShape functions can run against eager tensors. Always reports
// IsRuntime() == true. All constructor pointers are borrowed and must outlive
// this object.
class EagerInferShapeContext : public paddle::framework::InferShapeContext {
  using DDim = paddle::framework::DDim;
 public:
  // `op_kernel_type` may be nullptr; it is only consulted by
  // IsRunMKLDNNKernel().
  EagerInferShapeContext(
      const NameTensorMap* in, const NameTensorMap* out,
      const paddle::framework::AttributeMap* attr,
      const paddle::framework::AttributeMap* default_attr,
      const std::string op_type,
      const paddle::framework::OpKernelType* op_kernel_type = nullptr)
      : tensor_in_(in),
        tensor_out_(out),
        attrs_(attr),
        default_attrs_(default_attr),
        op_type_(op_type),
        op_kernel_type_(op_kernel_type) {}
  // True iff slot `name` holds exactly one non-null tensor; enforces the
  // single-element invariant with a hard error when violated.
  bool HasInput(const std::string& name) const override {
    // has only one input
    auto it = tensor_in_->find(name);
    if (it == tensor_in_->end()) {
      return false;
    }
    const auto& in = it->second;
    if (in.size() == 0) return false;
    PADDLE_ENFORCE_EQ(
        in.size(), 1UL,
        paddle::platform::errors::PreconditionNotMet(
            "Input %s should not have more than one inputs", name));
    return in[0] != nullptr;
  }
  bool HasOutput(const std::string& name) const override {
    // has only one output
    auto it = tensor_out_->find(name);
    if (it == tensor_out_->end()) {
      return false;
    }
    const auto& out = it->second;
    if (out.size() == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(
        out.size(), 1UL,
        paddle::platform::errors::PreconditionNotMet(
            "Output %s should not have more than one outputs", name));
    return out[0] != nullptr;
  }
  // True iff slot `name` is non-empty and every tensor in it is non-null.
  bool HasInputs(const std::string& name) const override {
    auto it = tensor_in_->find(name);
    if (it == tensor_in_->end() || it->second.empty()) {
      return false;
    }
    for (auto& input : it->second) {
      if (input == nullptr) {
        return false;
      }
    }
    return true;
  }
  bool HasOutputs(const std::string& name) const override {
    auto it = tensor_out_->find(name);
    if (it == tensor_out_->end() || it->second.empty()) {
      return false;
    }
    for (auto& output : it->second) {
      if (output == nullptr) {
        return false;
      }
    }
    return true;
  }
  // Attribute reader over explicit attrs with fallback to defaults.
  paddle::framework::AttrReader Attrs() const override {
    return paddle::framework::AttrReader(*attrs_, *default_attrs_);
  }
  // Tensor names bound to input slot `name`; null bindings become
  // kEmptyVarName. Throws when the slot is missing.
  std::vector<std::string> Inputs(const std::string& name) const override {
    std::vector<std::string> vec_res;
    auto it = tensor_in_->find(name);
    PADDLE_ENFORCE_NE(
        it, tensor_in_->end(),
        paddle::platform::errors::NotFound("can not find [%s] in input", name));
    vec_res.reserve(it->second.size());
    for (auto& var : it->second) {
      if (var) {
        vec_res.push_back(var->name());
      } else {
        vec_res.push_back(paddle::framework::kEmptyVarName);
      }
    }
    return vec_res;
  }
  std::vector<std::string> Outputs(const std::string& name) const override {
    std::vector<std::string> vec_res;
    auto it = tensor_out_->find(name);
    PADDLE_ENFORCE_NE(it, tensor_out_->end(),
                      paddle::platform::errors::NotFound(
                          "can not find [%s] in output", name));
    vec_res.reserve(it->second.size());
    for (auto& var : it->second) {
      if (var) {
        vec_res.push_back(var->name());
      } else {
        vec_res.push_back(paddle::framework::kEmptyVarName);
      }
    }
    return vec_res;
  }
  // Positional index -> proto slot name, resolved through the op registry.
  std::string GetInputNameByIdx(size_t idx) const override {
    auto& op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_;
    PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
                      paddle::platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
                          op_type_, idx, op_proto->inputs().size()));
    return op_proto->inputs()[idx].name();
  }
  std::string GetOutputNameByIdx(size_t idx) const override {
    auto& op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_type_).proto_;
    PADDLE_ENFORCE_LT(
        idx, op_proto->outputs().size(),
        paddle::platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
            op_type_, idx, op_proto->outputs().size()));
    return op_proto->outputs()[idx].name();
  }
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
auto in_it = tensor_in_->find(in);
auto out_it = tensor_out_->find(out);
PADDLE_ENFORCE_NE(
in_it, tensor_in_->end(),
paddle::platform::errors::NotFound("can not found [%s] in input", in));
PADDLE_ENFORCE_GT(in_it->second.size(), i,
paddle::platform::errors::PreconditionNotMet(
"Inputs %s should have %llu argument", in, i));
PADDLE_ENFORCE_NE(
out_it, tensor_out_->end(),
paddle::platform::errors::NotFound("can not found [%s] in input", in));
PADDLE_ENFORCE_GT(out_it->second.size(), j,
paddle::platform::errors::PreconditionNotMet(
"Outputs %s should have %llu argument", out, j));
paddle::framework::Variable* in_var = in_it->second[i]->MutableVar();
paddle::framework::Variable* out_var = out_it->second[j]->MutableVar();
PADDLE_ENFORCE_EQ(in_var->Type(), out_var->Type(),
paddle::platform::errors::PreconditionNotMet(
"The type of %s and %s is not the same.", in, out));
if (in_var->IsType<paddle::framework::LoDTensor>()) {
auto& in_lod_tensor = in_var->Get<paddle::framework::LoDTensor>();
auto* out_lod_tensor =
out_var->GetMutable<paddle::framework::LoDTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
auto& in_sele_rows = in_var->Get<pten::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<pten::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
}
}
  // LoD is not tracked by the eager infer-shape path, so sharing all LoD
  // levels is deliberately a no-op.
  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override {
    // do nothing
  }
  // LoD is not tracked by the eager infer-shape path, so sharing a single
  // LoD level is deliberately a no-op.
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    // do nothing
  }
  // Eager-mode infer-shape always runs at runtime, never at compile time.
  bool IsRuntime() const override { return true; }
bool IsRunMKLDNNKernel() const override {
return (op_kernel_type_ && (op_kernel_type_->data_layout_ ==
paddle::framework::DataLayout::kMKLDNN));
}
std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(it, tensor_in_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in inputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}
std::vector<paddle::framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
std::vector<paddle::framework::InferShapeVarPtr> res;
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"Can not find [%s] in outputs.", name));
for (auto& tensor : it->second) {
res.emplace_back(tensor->MutableVar());
}
return res;
}
DDim GetInputDim(const std::string& name) const override {
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(
it, tensor_in_->end(),
paddle::platform::errors::NotFound("can not find [%s] in input", name));
PADDLE_ENFORCE_EQ(
it->second.size(), 1UL,
paddle::platform::errors::PreconditionNotMet(
"Input(%s) should hold one element, but now it holds %d", name,
it->second.size()));
return this->GetDim(it->second[0]->MutableVar());
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
// const std::vector<Variable*>& vars = InputVars(name);
std::vector<DDim> vec_res;
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(it, tensor_in_->end(),
paddle::platform::errors::NotFound(
"can not find [%s] in output", name));
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.emplace_back(GetDim(it->second[i]->MutableVar()));
} else {
vec_res.emplace_back();
}
}
return vec_res;
}
std::vector<paddle::framework::proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
std::vector<paddle::framework::proto::VarType::Type> vec_res;
auto it = tensor_in_->find(name);
PADDLE_ENFORCE_NE(
it, tensor_in_->end(),
paddle::platform::errors::NotFound("can not find [%s] in input", name));
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.emplace_back(
paddle::framework::ToVarType(it->second[i]->MutableVar()->Type()));
} else {
vec_res.emplace_back();
}
}
return vec_res;
}
std::vector<paddle::framework::proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
std::vector<paddle::framework::proto::VarType::Type> vec_res;
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"can not find [%s] in output", name));
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.emplace_back(
paddle::framework::ToVarType(it->second[i]->MutableVar()->Type()));
} else {
vec_res.emplace_back(
static_cast<paddle::framework::proto::VarType::Type>(-1));
}
}
return vec_res;
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"can not find [%s] in output", name));
if (it->second[0]) {
SetDim(it->second[0]->MutableVar(), dim);
}
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto it = tensor_out_->find(name);
PADDLE_ENFORCE_NE(it, tensor_out_->end(),
paddle::platform::errors::NotFound(
"can not find [%s] in output", name));
PADDLE_ENFORCE_EQ(dims.size(), it->second.size(),
paddle::platform::errors::InvalidArgument(
"The number of dims is expected to be equal to the "
"number of Outputs(%s). But receieved: the number of "
"dims = %d, the number of Outputs(%s) = %d.",
name, dims.size(), name, it->second.size()));
for (size_t i = 0; i < dims.size(); ++i) {
if (it->second[i]) {
SetDim(it->second[i]->MutableVar(), dims[i]);
}
}
}
  // LoD levels are not modeled in dygraph mode; always throws.
  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "GetLoDLevel function not support in dygraph mode"));
  }
  // LoD levels are not modeled in dygraph mode; always throws.
  void SetLoDLevel(const std::string& out, int32_t lod_level,
                   size_t j = 0) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "SetLoDLevel function not support in dygraph mode"));
  }
protected:
DDim GetDim(paddle::framework::Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(var, paddle::platform::errors::PreconditionNotMet(
"Input variable should not be null"));
if (var->IsType<paddle::framework::LoDTensor>()) {
return var->Get<paddle::framework::LoDTensor>().dims();
} else if (var->IsType<pten::SelectedRows>()) {
return var->Get<pten::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is xx."));
}
}
  // Repeated dims are a static-graph concept; always throws at runtime.
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "GetRepeatedDims not support in dygraph runtime"));
  }
void SetDim(paddle::framework::Variable* var, const DDim& dim) {
if (var->IsType<paddle::framework::LoDTensor>()) {
var->GetMutable<paddle::framework::LoDTensor>()->Resize(dim);
} else if (var->IsType<pten::SelectedRows>()) {
var->GetMutable<pten::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Variable type_id %s, expect LoDTensor/SelectedRows."));
}
}
void SetDims(const std::vector<paddle::framework::Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(
length, dims.size(),
paddle::platform::errors::PreconditionNotMet(
"Vars number [%d] should be equal with dims number [%d]", length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
  // Repeated dims are a static-graph concept; always throws at runtime.
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "SetRepeatedDims not support in dygraph runtime"));
  }
 private:
  // Non-owning pointers to the op's tensor maps and attribute maps; the
  // caller constructing this context is responsible for their lifetimes.
  const NameTensorMap* tensor_in_;
  const NameTensorMap* tensor_out_;
  const paddle::framework::AttributeMap* attrs_;
  const paddle::framework::AttributeMap* default_attrs_;
  const std::string op_type_;
  // Optional; consulted only by IsRunMKLDNNKernel (may be null).
  const paddle::framework::OpKernelType* op_kernel_type_;
};
} // namespace legacy
} // namespace egr
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/legacy/tensor_helper.h"
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/pten/api/all.h"
namespace egr {
namespace legacy {
// infer var type context for imperative mode
// Runtime InferVarType context for eager mode: lets an op's infer_var_type_
// function query/modify variable types on EagerTensor maps. Holds non-owning
// references that must outlive this object.
class TensorRuntimeInferVarTypeContext
    : public paddle::framework::InferVarTypeContext {
 public:
  TensorRuntimeInferVarTypeContext(
      const NameTensorMap& inputs, const NameTensorMap& outputs,
      const paddle::framework::AttributeMap& attrs_map,
      const paddle::framework::AttributeMap& default_attrs_map)
      : InferVarTypeContext(nullptr, nullptr),
        inputs_(inputs),
        outputs_(outputs),
        attrs_(attrs_map),
        default_attrs_(default_attrs_map) {}
  virtual ~TensorRuntimeInferVarTypeContext() {}
  // Look up `name` in the explicit attrs first, then the defaults; throws
  // NotFound when absent from both.
  paddle::framework::Attribute GetAttr(const std::string& name) const override {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) {
      it = default_attrs_.find(name);
      if (it == default_attrs_.end()) {
        PADDLE_THROW(paddle::platform::errors::NotFound(
            "Can not find [%s] in attributes.", name));
      }
    }
    return it->second;
  }
  // True when the slot exists and holds at least one tensor.
  bool HasInput(const std::string& name) const override {
    auto it = inputs_.find(name);
    return (it != inputs_.end() && it->second.size() > 0);
  }
  bool HasOutput(const std::string& name) const override {
    auto it = outputs_.find(name);
    return (it != outputs_.end() && it->second.size() > 0);
  }
  // Number of tensors under input `name`; throws (map::at) when absent.
  size_t InputSize(const std::string& name) const {
    return inputs_.at(name).size();
  }
  // Name of the index-th tensor under input `name`. NOTE(review): `index`
  // itself is not bounds-checked here.
  const std::string& InputVarName(const std::string& name,
                                  const int index = 0) const {
    // TODO(jiabin): Support this usage inputs_.at(name)[index]->Name()
    auto it = inputs_.find(name);
    PADDLE_ENFORCE_NE(it, inputs_.end(),
                      paddle::platform::errors::PreconditionNotMet(
                          "Can not find [%s] in Input", name));
    return inputs_.at(name)[index]->name();
  }
  // True when any tensor under input `name` currently holds type `type`.
  bool InputTypeAnyOf(
      const std::string& name,
      paddle::framework::proto::VarType::Type type) const override {
    auto& inputs = inputs_.at(name);
    return std::any_of(
        inputs.begin(), inputs.end(),
        [&type](const std::shared_ptr<egr::EagerTensor>& var) {
          return paddle::framework::ToVarType(var->Var().Type()) == type;
        });
  }
  // True when every tensor under input `name` currently holds type `type`.
  bool InputTypeAllOf(
      const std::string& name,
      paddle::framework::proto::VarType::Type type) const override {
    auto& inputs = inputs_.at(name);
    return std::all_of(
        inputs.begin(), inputs.end(),
        [&type](const std::shared_ptr<egr::EagerTensor>& var) {
          return paddle::framework::ToVarType(var->Var().Type()) == type;
        });
  }
  // Copy the input tensor's variable type onto the output tensor; no-op when
  // input and output are the same tensor object.
  void SyncTypeAndDataType(const std::string& input_name,
                           const std::string& output_name,
                           int index = 0) override {
    auto in_tensor = inputs_.at(input_name)[index];
    auto out_tensor = outputs_.at(output_name)[index];
    if (in_tensor != out_tensor) {
      this->SetTensorType(
          out_tensor, paddle::framework::ToVarType(in_tensor->Var().Type()));
    }
  }
  // Set the type of one output tensor, or of every tensor in the slot when
  // index == ALL_ELEMENTS.
  void SetOutputType(const std::string& name,
                     paddle::framework::proto::VarType::Type type,
                     int index = 0) override {
    if (index == paddle::framework::ALL_ELEMENTS) {
      for (auto& item : outputs_.at(name)) {
        this->SetTensorType(item, type);
      }
    } else {
      auto& var = outputs_.at(name)[index];
      this->SetTensorType(var, type);
    }
  }
  // Materialize the requested variable type inside the tensor's Variable.
  // Only LOD_TENSOR and SELECTED_ROWS are supported.
  void SetTensorType(std::shared_ptr<egr::EagerTensor> out,
                     paddle::framework::proto::VarType::Type type) {
    switch (type) {
      case paddle::framework::proto::VarType::LOD_TENSOR: {
        out->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
        break;
      }
      case paddle::framework::proto::VarType::SELECTED_ROWS: {
        out->MutableVar()->GetMutable<pten::SelectedRows>();
        break;
      }
      default: {
        PADDLE_THROW(paddle::platform::errors::NotFound(
            "Cannot found var type: %s while running runtime InferVarType",
            paddle::framework::ToTypeName(type)));
      }
    }
  }
  paddle::framework::proto::VarType::Type GetInputType(
      const std::string& name, const int& index = 0) const override {
    return paddle::framework::ToVarType(inputs_.at(name)[index]->Var().Type());
  }
  // Always reports LOD_TENSOR regardless of the actual output.
  paddle::framework::proto::VarType::Type GetOutputType(
      const std::string& name, const int& index = 0) const override {
    // TODO(jiabin): Support SelectedRows when we have it.
    return paddle::framework::proto::VarType::LOD_TENSOR;
  }
  // Data type of the LoDTensor under input `name`[index]; assumes the
  // variable actually holds a LoDTensor.
  paddle::framework::proto::VarType::Type GetInputDataType(
      const std::string& name, const int& index = 0) const override {
    return inputs_.at(name)[index]
        ->Var()
        .Get<paddle::framework::LoDTensor>()
        .type();
  }
  // Intentionally a no-op in eager mode.
  void SetOutputDataType(const std::string& name,
                         paddle::framework::proto::VarType::Type type,
                         int index = 0) override {
    // TODO(jiabin): It seems doesn't make sense to set data_type in EagerMode.
  }
  bool IsDygraph() const override { return true; }
 protected:
  // All of the remaining overrides are static-graph-only queries; they
  // unconditionally throw in this runtime context.
  bool HasVar(const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "HasVar is not supported in runtime InferVarType"));
  }
  const std::vector<std::string>& InputVars(
      const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "InputVars is not supported in runtime InferVarType"));
  }
  const std::vector<std::string>& OutputVars(
      const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "OutputVars is not supported in runtime InferVarType"));
  }
  paddle::framework::proto::VarType::Type GetVarType(
      const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not manipulate var in runtime InferVarType"));
  }
  void SetVarType(const std::string& name,
                  paddle::framework::proto::VarType::Type type) override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not manipulate var in runtime InferVarType"));
  }
  paddle::framework::proto::VarType::Type GetVarDataType(
      const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not manipulate var in runtime InferVarType"));
  }
  void SetVarDataType(const std::string& name,
                      paddle::framework::proto::VarType::Type type) override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not manipulate var in runtime InferVarType"));
  }
  std::vector<paddle::framework::proto::VarType::Type> GetVarDataTypes(
      const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "GetVarDataTypes is not supported in runtime InferVarType"));
  }
  void SetVarDataTypes(
      const std::string& name,
      const std::vector<paddle::framework::proto::VarType::Type>&
          multiple_data_type) override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "SetVarDataTypes is not supported in runtime InferVarType"));
  }
  std::vector<int64_t> GetVarShape(const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not handle Shape in runtime InferVarType"));
  }
  void SetVarShape(const std::string& name,
                   const std::vector<int64_t>& dims) override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not handle Shape in runtime InferVarType"));
  }
  int32_t GetVarLoDLevel(const std::string& name) const override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not handle LoDLevel in runtime InferVarType"));
  }
  void SetVarLoDLevel(const std::string& name, int32_t lod_level) override {
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "Do not handle LoDLevel in runtime InferVarType"));
  }
 private:
  // Non-owning references supplied at construction.
  const NameTensorMap& inputs_;
  const NameTensorMap& outputs_;
  const paddle::framework::AttributeMap& attrs_;
  const paddle::framework::AttributeMap& default_attrs_;
};
} // namespace legacy
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/legacy/op_runner.h"
#include <map>
#include <set>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/eager/legacy/amp_auto_cast.h"
#include "paddle/fluid/eager/legacy/infer_var_type_context.h"
#include "paddle/fluid/eager/legacy/prepared_operator.h"
#include "paddle/fluid/eager/legacy/tensor_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/denormal.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(use_mkldnn);
DECLARE_string(tracer_mkldnn_ops_on);
DECLARE_string(tracer_mkldnn_ops_off);
namespace egr {
namespace legacy {
// Execute one operator: run InferVarType, initialize uninitialized outputs as
// LoDTensors, prepare the kernel and (possibly transformed) inputs, then run.
// Only operators with kernels are supported in dygraph mode.
void OpRunImpl(const paddle::framework::OperatorBase& op,
               const NameTensorMap& ins, const NameTensorMap& outs,
               const paddle::framework::AttributeMap& attrs,
               const paddle::framework::AttributeMap& default_attrs,
               const paddle::platform::Place& place) {
  VLOG(6) << "Get Opertor With Kernel";
  auto* op_kernel =
      dynamic_cast<const paddle::framework::OperatorWithKernel*>(&op);
  PADDLE_ENFORCE_NOT_NULL(
      op_kernel, paddle::platform::errors::PermissionDenied(
                     "Only support operator with kernel in Dygraph mode."));
  auto& info = op.Info();
  if (info.infer_var_type_) {
    VLOG(6) << "Run InferVarType";
    egr::legacy::TensorRuntimeInferVarTypeContext infer_var_type_ctx(
        ins, outs, attrs, default_attrs);
    VLOG(9) << "Actual Run InferVarType";
    info.infer_var_type_(&infer_var_type_ctx);
  }
  VLOG(6) << "Initialize output tensor";
  // Initialize output tensor. NOTE: `tensor && tensor.get()` was redundant —
  // a non-null shared_ptr implies a non-null raw pointer.
  for (auto& tensor_pair : outs) {
    for (auto& tensor : tensor_pair.second) {
      if (tensor && (!tensor->Var().IsInitialized())) {
        InitializeVariable(tensor->MutableVar(),
                           paddle::framework::proto::VarType::LOD_TENSOR);
      }
    }
  }
  /**
   * [ Why need temporary inputs here? ]
   *
   * PrepareData should not change original input tensor inplace.
   * Suppose the user defines a tensor(int), enters an op to execute,
   * and then this op rewrites GetExpectedKernelForVar, and converts
   * this tensor to float type during execution. After the dynamic
   * graph is executed, the user-defined variable will be lost, and
   * the user cannot get the originally defined int tensor, because
   * it has been converted to float, this should be regarded as a bug
   * in certain usage scenarios
   *
   * In static graph mode, when op is executed, a temporary scope
   * `transfer_scope` is created before PrepareData, the data after
   * transform is stored in the temporary scope, and then discarded
   * after the execution of op, but the original input is directly
   * overwritten in the previous dynamic graph implementation.
   */
  VLOG(6) << "Prepare Op";
  auto prepared_op = egr::legacy::PreparedOp::Prepare(
      ins, outs, *op_kernel, place, attrs, default_attrs);
  VLOG(6) << "Prepare Data";
  auto tmp_ins_ptr =
      egr::legacy::PrepareData(*op_kernel, ins, prepared_op.kernel_type());
  VLOG(6) << "Run Prepared Op";
  // Run with the transformed copies when PrepareData made any.
  if (tmp_ins_ptr == nullptr) {
    prepared_op.Run(ins, outs, attrs, default_attrs);
  } else {
    prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs);
  }
  VLOG(6) << "Run Prepared Op end";
  // TODO(jiabin): Set the output var's grad Forward DataType
}
// Run operator `type` once in eager legacy mode: resolve mkldnn on/off lists,
// create the OperatorBase, check/fill default attributes, apply AMP input
// casting, select the device, then dispatch to OpRunImpl. `attrs` is taken by
// value because it may be amended (e.g. "use_mkldnn").
// NOTE(review): `inplace_map` is currently unused here — confirm intended.
void RunOp(const std::string& type, const NameTensorMap& ins,
           const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
           const paddle::platform::Place& place,
           paddle::framework::AttributeMap* default_attrs,
           bool override_default_attr_map,
           const std::map<std::string, std::string>& inplace_map) {
  VLOG(1) << "Run Op: " << type;
  if (FLAGS_use_mkldnn) {
    // if both lists are empty all ops are enabled (default for
    // FLAGS_use_mkldnn=1)
    // if ops_on list is not empty only ops from that list are enabled
    if (!FLAGS_tracer_mkldnn_ops_on.empty()) {
      auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos;
      attrs["use_mkldnn"] = is_on;
    } else {
      // if ops_on list is empty all ops are enabled except types from off_list
      auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos;
      attrs["use_mkldnn"] = !is_off;
    }
  }
  auto op = paddle::framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
  PADDLE_ENFORCE_NOT_NULL(default_attrs,
                          paddle::platform::errors::PermissionDenied(
                              "Detected default_attrs = nullptr."));
  if (override_default_attr_map) {
    // Validate user attrs and refresh *default_attrs from the op's checker.
    const auto& op_info = op->Info();
    auto* attr_checker = op_info.Checker();
    if (attr_checker) {
      attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
    }
    static paddle::framework::AttributeMap empty_attrs_map = {};
    *default_attrs = attr_checker == nullptr
                         ? empty_attrs_map
                         : attr_checker->GetDefaultAttrMap();
  }
  auto amp_level = egr::Controller::Instance().GetAMPLevel();
  VLOG(6) << "Check AMP status";
  // AMP may replace inputs with casted copies; the caller's `ins` stays
  // untouched.
  NameTensorMap new_ins = ins;
  if (amp_level == paddle::imperative::AmpLevel::O1) {
    VLOG(5) << "Auto mixed precision run operator: " << type;
    new_ins = AutoCastInputs(type, ins);
  } else if (amp_level == paddle::imperative::AmpLevel::O2) {
    VLOG(5) << "Pure fp16 run operator: " << type;
    new_ins = CastPureFp16Inputs(type, ins);
  }
  try {
    VLOG(6) << "Get Device id";
    // Bind the current device for the requested place; throw if this build
    // lacks the corresponding backend.
    if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::SetDeviceId(place.device);
#else
      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
          "PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
    } else if (paddle::platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
      paddle::platform::SetXPUDeviceId(place.device);
#else
      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
          "PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
    } else if (paddle::platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
      paddle::platform::SetNPUDeviceId(place.device);
#else
      PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
          "PaddlePaddle should compile with NPU if use NPUPlace."));
#endif
    }
    VLOG(6) << "Step in OpRunImpl";
    OpRunImpl(*op, new_ins, outs, attrs, *default_attrs, place);
  } catch (paddle::platform::EnforceNotMet& exception) {
    // Annotate the error with the failing op, then rethrow.
    paddle::framework::AppendErrorOpHint(type, &exception);
    throw std::move(exception);
  } catch (std::exception& ex) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Operator %s raises an %s exception.\n"
        "The exception content is\n:%s.",
        type, paddle::platform::demangle(typeid(ex).name()), ex.what()));
  } catch (...) {
    // NOTE: this branch represents a very serious bug with
    // low probability of occurrence, and we can't get its
    // exception content here.
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Operator %s raises an unknown exception.", type));
  }
  VLOG(6) << "Finish Run Op";
  // TODO(jiabin): Support this later
  // if (enable_program_desc_tracing_) {
  //   VLOG(5) << "Trace op " << type << " into ProgramDesc";
  //   program_desc_tracer_->InsertOp(type, new_ins, outs, attrs);
  // }
}
} // namespace legacy
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include "paddle/pten/core/tensor_meta.h"
namespace egr {
namespace legacy {
// Run one operator `type` eagerly on `place` with the given inputs, outputs
// and attributes. When `override_default_attr_map` is true, `default_attrs`
// (must be non-null) is refreshed from the op's attribute checker.
void RunOp(const std::string& type, const NameTensorMap& ins,
           const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
           const paddle::platform::Place& place,
           paddle::framework::AttributeMap* default_attrs,
           bool override_default_attr_map,
           const std::map<std::string, std::string>& inplace_map = {});
} // namespace legacy
} // namespace egr
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/legacy/prepared_operator.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/eager/legacy/infer_shape_context.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/utils/small_vector.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif
DECLARE_bool(check_nan_inf);
namespace egr {
namespace legacy {
// Borrow the dense tensor held by `var`: the LoDTensor itself, the value
// tensor of a SelectedRows, or nullptr for any other variable type.
const paddle::framework::Tensor* GetTensorFromVar(
    const paddle::framework::Variable& var) {
  if (var.IsType<paddle::framework::LoDTensor>()) {
    return &var.Get<paddle::framework::LoDTensor>();
  }
  if (var.IsType<pten::SelectedRows>()) {
    return &var.Get<pten::SelectedRows>().value();
  }
  return nullptr;
}
// Resolve attribute `name`, preferring the explicit attrs over the defaults;
// throws NotFound when absent from both maps.
static const paddle::framework::Attribute& GetAttr(
    const paddle::framework::AttributeMap& attrs,
    const paddle::framework::AttributeMap& default_attrs,
    const std::string& name) {
  auto iter = attrs.find(name);
  bool exists = (iter != attrs.end());
  if (!exists) {
    iter = default_attrs.find(name);
    exists = (iter != default_attrs.end());
  }
  PADDLE_ENFORCE_EQ(exists, true,
                    paddle::platform::errors::NotFound(
                        "(%s) is not found in AttributeMap.", name));
  return iter->second;
}
// Placeholder: complex-to-real gradient conversion is not implemented on the
// eager legacy path yet, so this is intentionally a no-op.
static void HandleComplexGradToRealGrad(const NameTensorMap& outs) {
  // TODO(jiabin): Support complex forward datatype later.
}
// Construct a PreparedOp that will run a classic fluid OpKernelFunc on the
// given device context.
PreparedOp::PreparedOp(
    const paddle::framework::OperatorBase& op,
    const paddle::framework::RuntimeContext& ctx,
    const paddle::framework::OpKernelType& kernel_type,
    const paddle::framework::OperatorWithKernel::OpKernelFunc& func,
    paddle::platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_type_(kernel_type),
      func_(func),
      dev_ctx_(dev_ctx) {}
// Construct a PreparedOp that will run a pten kernel (run_pten_kernel_ set,
// fluid func_ left null).
PreparedOp::PreparedOp(
    const paddle::framework::OperatorBase& op,
    const paddle::framework::RuntimeContext& ctx,
    const paddle::framework::OpKernelType& kernel_type,
    const paddle::framework::KernelSignature& kernel_signature,
    const pten::Kernel& pt_kernel, paddle::platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_type_(kernel_type),
      func_(nullptr),
      dev_ctx_(dev_ctx),
      run_pten_kernel_(true),
      pt_kernel_signature_(kernel_signature),
      pt_kernel_(pt_kernel) {}
// Kernel selection for one op run: prefer a compatible pten kernel for the
// expected kernel key; otherwise fall back to the registered fluid kernels,
// with CPU fallbacks for unsupported/blacklisted XPU and missing NPU kernels.
PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs,
                       const paddle::framework::OperatorWithKernel& op,
                       const paddle::platform::Place& place,
                       const paddle::framework::AttributeMap& attrs,
                       const paddle::framework::AttributeMap& default_attrs) {
  VLOG(6) << "Preparing an Op";
  paddle::platform::DeviceContextPool& pool =
      paddle::platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  paddle::framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN
  // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
  // GetKernelType functions, so we need to copy the attributes there.
  // Const qualifier of Attrs had to be discarded to overwrite it.
  if (FLAGS_use_mkldnn) {
    auto& mutable_op_attrs =
        const_cast<paddle::framework::AttributeMap&>(op.Attrs());
    mutable_op_attrs = default_attrs;
    for (auto& attr : attrs) {
      mutable_op_attrs[attr.first] = attr.second;
    }
  }
#endif
  // 1. get expected kernel key
  auto dygraph_exe_ctx = egr::legacy::EagerExecutionContext(
      op, paddle::framework::Scope(), *dev_ctx, ctx, ins, outs, attrs,
      default_attrs);
  auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
  // fit for pten
  pten::KernelSignature pt_kernel_signature;
  pten::KernelKey pt_kernel_key;
  std::string pt_kernel_name;
  if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
    pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
    VLOG(6) << pt_kernel_signature;
    pt_kernel_name = pt_kernel_signature.name;
    pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
    auto pt_kernel = pten::KernelFactory::Instance().SelectKernel(
        pt_kernel_name, pt_kernel_key);
    if (pt_kernel.IsValid()) {
      VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
              << " | kernel key: " << pt_kernel_key
              << " | kernel: " << pt_kernel;
      // TODO(chenweihang): using CPUKernel when miss device kernel case
      return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
                        pt_kernel, dev_ctx);
    } else {
      VLOG(6) << "Dynamic mode ChoosePtenKernel - kernel `" << pt_kernel_name
              << "` not found.";
    }
  }
  // 2. check if op[type] has kernel registered.
  auto& all_op_kernels = op.AllOpKernels();
  auto kernels_iter = all_op_kernels.find(op.Type());
  // If no fluid kernel matches (or XPU cannot run this op), try a pten
  // kernel on CPU before giving up.
  if (kernels_iter == all_op_kernels.end() ||
      kernels_iter->second.find(expected_kernel_key) ==
          kernels_iter->second.end()
#ifdef PADDLE_WITH_XPU
      ||
      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
          !paddle::platform::is_xpu_support_op(op.Type(),
                                               expected_kernel_key) ||
      paddle::platform::is_in_xpu_black_list(op.Type())
#endif
          ) {
    if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
      auto pt_cpu_kernel_key =
          FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
      auto pt_cpu_kernel = pten::KernelFactory::Instance().SelectKernel(
          pt_kernel_name, pt_cpu_kernel_key);
      if (pt_cpu_kernel.IsValid()) {
        VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
                << " | kernel key: " << pt_cpu_kernel_key
                << " | kernel: " << pt_cpu_kernel;
        return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
                          pt_cpu_kernel, dev_ctx);
      }
    }
  }
  PADDLE_ENFORCE_NE(
      kernels_iter, all_op_kernels.end(),
      paddle::platform::errors::NotFound(
          "There are no kernels which are registered in the %s operator.",
          op.Type()));
  auto& kernels = kernels_iter->second;
  auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU
  // Fall back to a CPU kernel when the XPU kernel is missing, unsupported,
  // or the op is blacklisted for XPU.
  if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
      (kernel_iter == kernels.end() ||
       !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key) ||
       paddle::platform::is_in_xpu_black_list(op.Type()))) {
    VLOG(3) << "missing XPU kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = paddle::platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_npu_place(expected_kernel_key.place_)) {
    VLOG(3) << "missing NPU kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = paddle::platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
  // TODO(jiabin): Add operator.cc's line 1000 part back when we need that
  // case
  PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
                    paddle::platform::errors::NotFound(
                        "Operator %s does not have kernel for %s.", op.Type(),
                        KernelTypeToString(expected_kernel_key)));
  // A fallback may have changed the place; fetch the matching device context.
  if (!(expected_kernel_key.place_ == place)) {
    dev_ctx = pool.Get(expected_kernel_key.place_);
  }
  VLOG(6) << "Construct Prepared Op";
  return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx);
}
// Public entry point: select the kernel and device context for `op`; see
// PrepareImpl for the selection/fallback logic.
PreparedOp PreparedOp::Prepare(
    const NameTensorMap& ins, const NameTensorMap& outs,
    const paddle::framework::OperatorWithKernel& op,
    const paddle::platform::Place& place,
    const paddle::framework::AttributeMap& attrs,
    const paddle::framework::AttributeMap& default_attrs) {
  return PrepareImpl(ins, outs, op, place, attrs, default_attrs);
}
// Run a classic fluid kernel: infer shape, invoke `func` with an eager
// execution context, optionally check outputs for nan/inf, then handle the
// complex-to-real gradient case.
static void PreparedOpRunImpl(
    const paddle::framework::OperatorBase& op,
    const paddle::framework::RuntimeContext& ctx,
    const paddle::framework::OpKernelType& kernel_type,
    const paddle::framework::OperatorWithKernel::OpKernelFunc& func,
    paddle::platform::DeviceContext* dev_ctx, const NameTensorMap& ins,
    const NameTensorMap& outs, const paddle::framework::AttributeMap& attrs,
    const paddle::framework::AttributeMap& default_attrs) {
  // TODO(zjl): remove scope in dygraph
  VLOG(6) << "Runing Prepared Op";
  paddle::framework::Scope scope;
  EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
                                         op.Type(), &kernel_type);
  op.Info().infer_shape_(&infer_shape_ctx);
  func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
                             default_attrs));
  if (FLAGS_check_nan_inf) {
    paddle::framework::details::CheckOpHasNanOrInfInEager<EagerTensor>(
        op.Type(), outs, dev_ctx->GetPlace());
  }
  /**
   * [ Why need handle complex gradient to real gradient? ]
   *
   * After the introduction of complex number calculations, Ops that support
   * complex number calculations generally support type promotion, such as
   * x(float32) + y(complex64) = out(complex64), then the type of the grad
   * tensor should be dout(complex64), dx(float32), dy (complex64).
   *
   * But because the dout is complex64, the dx is also complex64 after
   * grad op kernel executed, we need to recognize this situation and
   * convert dx to float32 type. HandleComplexGradToRealGrad does this thing.
   */
  if (paddle::framework::IsComplexType(kernel_type.data_type_)) {
    HandleComplexGradToRealGrad(outs);
  }
  VLOG(6) << "Finish Runing Prepared Op";
}
// Executes a prepared op through the new pten kernel path: infer shapes,
// prepare input data, build the pten KernelContext, and launch the kernel.
static void PreparedOpRunPtImpl(
    const paddle::framework::OperatorBase& op,
    const paddle::framework::OpKernelType& kernel_type,
    const paddle::framework::KernelSignature& pt_kernel_signature,
    const pten::Kernel& pt_kernel, paddle::platform::DeviceContext* dev_ctx,
    const NameTensorMap& ins, const NameTensorMap& outs,
    const paddle::framework::AttributeMap& attrs,
    const paddle::framework::AttributeMap& default_attrs) {
  EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
                                         op.Type());
  static_cast<const paddle::framework::OperatorWithKernel&>(op).InferShape(
      &infer_shape_ctx);

  // NOTE(review): the result of PreparePtenData() is discarded here, and the
  // static_cast below creates a temporary map, so any transform it performs
  // may not be visible when the kernel context is built. Confirm against the
  // imperative implementation whether the transformed inputs should be fed
  // into BuildDygraphPtenKernelContext instead.
  paddle::imperative::PreparePtenData<EagerTensor>(
      pt_kernel, pt_kernel_signature,
      static_cast<paddle::imperative::NameTensorMap>(ins));

  pten::KernelContext pt_kernel_context;
  paddle::imperative::BuildDygraphPtenKernelContext<EagerTensor>(
      pt_kernel_signature, pt_kernel,
      static_cast<paddle::imperative::NameTensorMap>(ins),
      static_cast<paddle::imperative::NameTensorMap>(outs), attrs,
      default_attrs, dev_ctx, &pt_kernel_context);

  pt_kernel(&pt_kernel_context);

  // TODO(chenweihang): add debug flags later
  // TODO(chenweihang): deal with complex cases later
}
// Executes the kernel selected in Prepare(), dispatching to either the new
// pten kernel path or the classic fluid OpKernelFunc path.
void PreparedOp::Run(const NameTensorMap& ins, const NameTensorMap& outs,
                     const paddle::framework::AttributeMap& attrs,
                     const paddle::framework::AttributeMap& default_attrs) {
  // Fall back to the classic fluid path when no pten kernel was matched.
  if (!run_pten_kernel_) {
    PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs,
                      attrs, default_attrs);
    return;
  }
  PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, pt_kernel_,
                      dev_ctx_, ins, outs, attrs, default_attrs);
}
// Applies the data transforms (layout / dtype / place) required so that every
// initialized input tensor in `ins` is acceptable to the kernel identified by
// `expected_kernel_key`.
//
// Copy-on-write contract: returns nullptr when no input needed a dtype
// change (callers keep using `ins` directly, in-place transforms included);
// otherwise returns a shallow copy of `ins` in which only the dtype-converted
// entries were replaced by fresh EagerTensors.
std::shared_ptr<NameTensorMap> PrepareData(
    const paddle::framework::OperatorWithKernel& op, const NameTensorMap& ins,
    const paddle::framework::OpKernelType& expected_kernel_key) {
  std::shared_ptr<NameTensorMap> tmp_ins_ptr = nullptr;
  for (const auto& name_pair : ins) {
    for (size_t i = 0; i < name_pair.second.size(); ++i) {
      auto& egr_tensor = name_pair.second[i];
      const auto* tensor = GetTensorFromVar(egr_tensor->Var());
      // Skip uninitialized inputs (e.g. optional ones): nothing to transform.
      if (tensor && tensor->IsInitialized()) {
        auto kernel_type_for_var = op.GetKernelTypeForVar(
            name_pair.first, *tensor, expected_kernel_key);
        if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
          continue;
        } else {
          // TODO(jiabin): Support Cache later
          VLOG(3) << "Transform Variable " << egr_tensor->name() << " from "
                  << kernel_type_for_var << " to " << expected_kernel_key;
          paddle::framework::Tensor out;
          TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
                        &out);
          if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) {
            // To avoid NameVarMap copy construction overhead in general
            // scenarios, if inplace transformed, return original input
            // directly
            if (tmp_ins_ptr == nullptr) {
              tmp_ins_ptr = std::make_shared<NameTensorMap>(ins);
            }
            // Dtype changed: leave the original tensor untouched and put the
            // converted result into a fresh EagerTensor in the copied map.
            auto tmp_egr_tensor =
                std::make_shared<EagerTensor>(egr_tensor->name());
            SetTensorToVariable(egr_tensor->Var(), out,
                                tmp_egr_tensor->MutableVar());
            (*tmp_ins_ptr)[name_pair.first][i] = tmp_egr_tensor;
          } else {
            // if dtype is same, transform inplace will not change the
            // original
            // value, transform inplace to avoid multiple copy
            SetTensorToVariable(egr_tensor->Var(), out,
                                egr_tensor->MutableVar());
          }
        }
      }
    }
  }
  return tmp_ins_ptr;
}
} // namespace legacy
} // namespace egr
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/eager/legacy/execution_context.h"
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace pten {
class DenseTensor;
} // namespace pten
namespace egr {
namespace legacy {
// Returns a pointer to the dense tensor stored in `var`; presumably resolves
// LoDTensor / SelectedRows storage — definition lives in the .cc file
// (TODO confirm behavior for other variable types).
const paddle::framework::Tensor* GetTensorFromVar(
    const paddle::framework::Variable& var);

// Transforms `ins` so they satisfy `expected_kernel_key`; returns nullptr
// when every transform was applied in place, otherwise a copy-on-write map
// holding the converted inputs.
std::shared_ptr<NameTensorMap> PrepareData(
    const paddle::framework::OperatorWithKernel& op, const NameTensorMap& ins,
    const paddle::framework::OpKernelType& expected_kernel_key);
// Binds an operator to a concrete kernel — either a classic fluid
// OpKernelFunc or a new pten kernel — together with the device context it
// will run on, so that kernel lookup done in Prepare() is not repeated at
// Run() time.
class PreparedOp {
 public:
  // Constructor for the classic fluid kernel path.
  PreparedOp(const paddle::framework::OperatorBase& op,
             const paddle::framework::RuntimeContext& ctx,
             const paddle::framework::OpKernelType& kernel_type,
             const paddle::framework::OperatorWithKernel::OpKernelFunc& func,
             paddle::platform::DeviceContext* dev_ctx);

  // Constructor for the new pten kernel path.
  PreparedOp(const paddle::framework::OperatorBase& op,
             const paddle::framework::RuntimeContext& ctx,
             const paddle::framework::OpKernelType& kernel_type,
             const paddle::framework::KernelSignature& kernel_signature,
             const pten::Kernel& pt_kernel,
             paddle::platform::DeviceContext* dev_ctx);

  // Looks up the kernel matching `op` on `place` and returns a PreparedOp
  // ready to Run().
  static PreparedOp Prepare(
      const NameTensorMap& ins, const NameTensorMap& outs,
      const paddle::framework::OperatorWithKernel& op,
      const paddle::platform::Place& place,
      const paddle::framework::AttributeMap& attrs,
      const paddle::framework::AttributeMap& default_attrs);

  // Executes the bound kernel with the given inputs/outputs and attributes.
  void Run(const NameTensorMap& in, const NameTensorMap& out,
           const paddle::framework::AttributeMap& attrs,
           const paddle::framework::AttributeMap& default_attrs);

  const paddle::framework::OpKernelType& kernel_type() const {
    return kernel_type_;
  }

 private:
  const paddle::framework::OperatorBase& op_;
  const paddle::framework::RuntimeContext& ctx_;
  paddle::framework::OpKernelType kernel_type_;
  // Classic fluid kernel entry point; unused when run_pten_kernel_ is true.
  paddle::framework::OperatorWithKernel::OpKernelFunc func_;
  paddle::platform::DeviceContext* dev_ctx_;

  // NOTE(chenweihang): Similar op members are used to adapt to
  // new pten kernel, if there is a better design in the future,
  // we may polish the implementation here
  bool run_pten_kernel_{false};
  paddle::framework::KernelSignature pt_kernel_signature_;
  pten::Kernel pt_kernel_;
};
} // namespace legacy
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/legacy/tensor_helper.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/place.h"
namespace egr {
namespace legacy {
// Default-constructs the payload of `var` according to `var_type`, covering
// every variable type the fluid framework supports in the eager legacy path.
// RAW variables are intentionally left untouched: the operator itself calls
// GetMutable with the concrete type later.
// Throws errors::Unavailable for any unsupported proto var type.
void InitializeVariable(paddle::framework::Variable *var,
                        paddle::framework::proto::VarType::Type var_type) {
  if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
    var->GetMutable<paddle::framework::LoDTensor>();
  } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
    var->GetMutable<pten::SelectedRows>();
  } else if (var_type == paddle::framework::proto::VarType::FEED_MINIBATCH) {
    var->GetMutable<paddle::framework::FeedList>();
  } else if (var_type == paddle::framework::proto::VarType::FETCH_LIST) {
    var->GetMutable<paddle::framework::FetchList>();
  } else if (var_type == paddle::framework::proto::VarType::STEP_SCOPES) {
    var->GetMutable<std::vector<paddle::framework::Scope *>>();
  } else if (var_type == paddle::framework::proto::VarType::LOD_RANK_TABLE) {
    var->GetMutable<paddle::framework::LoDRankTable>();
  } else if (var_type == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
    var->GetMutable<paddle::framework::LoDTensorArray>();
  } else if (var_type == paddle::framework::proto::VarType::STRINGS) {
    var->GetMutable<paddle::framework::Strings>();
  } else if (var_type == paddle::framework::proto::VarType::VOCAB) {
    var->GetMutable<paddle::framework::Vocab>();
  } else if (var_type == paddle::framework::proto::VarType::PLACE_LIST) {
    var->GetMutable<paddle::platform::PlaceList>();
  } else if (var_type == paddle::framework::proto::VarType::READER) {
    var->GetMutable<paddle::framework::ReaderHolder>();
  } else if (var_type == paddle::framework::proto::VarType::RAW) {
    // GetMutable will be called in operator
  } else {
    // BUGFIX: the previous message omitted STEP_SCOPES, LOD_TENSOR_ARRAY,
    // STRINGS and VOCAB even though they are handled above; list every
    // supported type so the diagnostic matches the implementation.
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "paddle::framework::Variable type %d is not in "
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
        "STEP_SCOPES, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, STRINGS, VOCAB, "
        "PLACE_LIST, READER, RAW].",
        var_type));
  }
}
// Deep-copies the payload of `src_var` into `dst_var`.
// Only LoDTensor and SelectedRows payloads are supported, and the copy is
// always materialized on CPU memory.
void CopyVariable(const paddle::framework::Variable &src_var,
                  paddle::framework::Variable *dst_var) {
  // only support cpu now
  auto cpu_place = paddle::platform::CPUPlace();
  if (src_var.IsType<paddle::framework::LoDTensor>()) {
    auto *dst_tensor = dst_var->GetMutable<paddle::framework::LoDTensor>();
    const auto &src_tensor = src_var.Get<paddle::framework::LoDTensor>();
    // Carry the LoD metadata over before copying the raw data.
    dst_tensor->set_lod(src_tensor.lod());
    paddle::framework::TensorCopy(src_tensor, cpu_place, dst_tensor);
  } else if (src_var.IsType<pten::SelectedRows>()) {
    const auto &src_slr = src_var.Get<pten::SelectedRows>();
    auto *dst_slr = dst_var->GetMutable<pten::SelectedRows>();
    // Rows and height describe the sparse layout; copy them first, then the
    // dense value tensor itself.
    dst_slr->set_rows(src_slr.rows());
    dst_slr->set_height(src_slr.height());
    const auto &src_value = src_slr.value();
    auto *dst_value = dst_slr->mutable_value();
    paddle::framework::TensorCopy(src_value, cpu_place, dst_value);
  } else {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Unknown variable type to copy."));
  }
}
// Returns the data type of the tensor held by `var`.
// Throws InvalidArgument for anything other than LoDTensor / SelectedRows.
paddle::framework::proto::VarType::Type GetDtypeFromVar(
    const paddle::framework::Variable &var) {
  // Dense tensors carry their dtype directly.
  if (var.IsType<paddle::framework::LoDTensor>()) {
    return var.Get<paddle::framework::LoDTensor>().type();
  }
  // For sparse variables the dtype lives on the underlying value tensor.
  if (var.IsType<pten::SelectedRows>()) {
    return var.Get<pten::SelectedRows>().value().type();
  }
  PADDLE_THROW(paddle::platform::errors::InvalidArgument(
      "Variable type is %s, expect LoDTensor or SelectedRows.",
      paddle::framework::ToTypeName(var.Type())));
}
// Returns the device placement of the tensor held by `var`.
// Throws InvalidArgument for anything other than LoDTensor / SelectedRows.
const paddle::platform::Place &GetPlaceFromVar(
    const paddle::framework::Variable &var) {
  if (var.IsType<paddle::framework::LoDTensor>()) {
    return var.Get<paddle::framework::LoDTensor>().place();
  }
  if (var.IsType<pten::SelectedRows>()) {
    return var.Get<pten::SelectedRows>().place();
  }
  PADDLE_THROW(paddle::platform::errors::InvalidArgument(
      "Variable type is %s, expect LoDTensor or SelectedRows.",
      paddle::framework::ToTypeName(var.Type())));
}
} // namespace legacy
} // namespace egr
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace egr {
class EagerTensor;
namespace legacy {
namespace details {
// Trait mapping a tensor type to the name -> tensor-list container used by
// the eager legacy layer; specialize it for every tensor type that needs
// a NameMap.
template <typename T>
struct NameVarMapTrait {};

template <>
struct NameVarMapTrait<EagerTensor> {
  using Type =
      std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>>;
};
}  // namespace details

// Convenience alias: NameMap<T> is the container chosen by NameVarMapTrait.
template <typename T>
using NameMap = typename details::NameVarMapTrait<T>::Type;

// Canonical mapping from a parameter name (e.g. "X", "Out") to the list of
// eager tensors bound to it.
using NameTensorMap = NameMap<EagerTensor>;
}  // namespace legacy
}  // namespace egr
...@@ -214,7 +214,7 @@ void benchmark_fluid_scale(const std::shared_ptr<imperative::VarBase>& X, ...@@ -214,7 +214,7 @@ void benchmark_fluid_scale(const std::shared_ptr<imperative::VarBase>& X,
{std::shared_ptr<imperative::VarBase>( {std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(true, "Out"))}}}; new imperative::VarBase(true, "Out"))}}};
tracer.TraceOp("scale", ins, outs, attrs, place, true); tracer.TraceOp<VarBase>("scale", ins, outs, attrs, place, true);
tmp_out = outs["Out"][0]; tmp_out = outs["Out"][0];
} }
...@@ -250,7 +250,7 @@ void benchmark_fluid_matmul(const std::shared_ptr<imperative::VarBase>& X, ...@@ -250,7 +250,7 @@ void benchmark_fluid_matmul(const std::shared_ptr<imperative::VarBase>& X,
{std::shared_ptr<imperative::VarBase>( {std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(true, "Out"))}}}; new imperative::VarBase(true, "Out"))}}};
tracer.TraceOp("matmul_v2", ins, outs, attrs, place, true); tracer.TraceOp<VarBase>("matmul_v2", ins, outs, attrs, place, true);
tmp_out = outs["Out"][0]; tmp_out = outs["Out"][0];
} }
...@@ -288,7 +288,7 @@ void benchmark_fluid_mlp( ...@@ -288,7 +288,7 @@ void benchmark_fluid_mlp(
{std::shared_ptr<imperative::VarBase>( {std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(true, "Out"))}}}; new imperative::VarBase(true, "Out"))}}};
tracer.TraceOp("matmul_v2", ins, outs, attrs, place, true); tracer.TraceOp<VarBase>("matmul_v2", ins, outs, attrs, place, true);
// EW-Add0 // EW-Add0
ins = {{"X", outs["Out"]}, {"Y", {Bs[i]}}}; ins = {{"X", outs["Out"]}, {"Y", {Bs[i]}}};
...@@ -296,7 +296,7 @@ void benchmark_fluid_mlp( ...@@ -296,7 +296,7 @@ void benchmark_fluid_mlp(
{std::shared_ptr<imperative::VarBase>( {std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(true, "Out"))}}}; new imperative::VarBase(true, "Out"))}}};
tracer.TraceOp("elementwise_add", ins, outs, attrs, place, true); tracer.TraceOp<VarBase>("elementwise_add", ins, outs, attrs, place, true);
input0 = outs["Out"][0]; input0 = outs["Out"][0];
} }
...@@ -307,7 +307,7 @@ void benchmark_fluid_mlp( ...@@ -307,7 +307,7 @@ void benchmark_fluid_mlp(
new imperative::VarBase(true, "Out"))}}}; new imperative::VarBase(true, "Out"))}}};
attrs = {{"reduce_all", true}}; attrs = {{"reduce_all", true}};
tracer.TraceOp("reduce_sum", ins, outs, attrs, place, true); tracer.TraceOp<VarBase>("reduce_sum", ins, outs, attrs, place, true);
auto* engine = tracer.GetEngine(); auto* engine = tracer.GetEngine();
std::vector<std::shared_ptr<imperative::VarBase>> grad_tensors{nullptr}; std::vector<std::shared_ptr<imperative::VarBase>> grad_tensors{nullptr};
......
...@@ -17,12 +17,11 @@ ...@@ -17,12 +17,11 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -49,20 +48,8 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type, ...@@ -49,20 +48,8 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
for (const auto& ivar : pair.second) { for (const auto& ivar : pair.second) {
auto* var = ivar->MutableVar(); auto* var = ivar->MutableVar();
if (var == nullptr) continue; if (var == nullptr) continue;
CheckVarHasNanOrInf(op_type, ivar->Name(), var, place); CheckVarHasNanOrInf(op_type, paddle::imperative::GetNameFromVar(ivar),
} var, place);
}
}
template <typename TensorType>
static void CheckOpHasNanOrInfInEager(
const std::string& op_type, const egr::legacy::NameMap<TensorType>& op_outs,
platform::Place place) {
for (const auto& pair : op_outs) {
for (const auto& tensor : pair.second) {
auto* var = tensor->MutableVar();
if (var == nullptr) continue;
CheckVarHasNanOrInf(op_type, tensor->name(), var, place);
} }
} }
} }
......
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags) cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
cc_library(var_helper SRCS var_helper.cc DEPS tensor pten_api)
IF(WITH_XPU) IF(WITH_XPU)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api) cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten_api pten pten_utils var_helper)
ELSE() ELSE()
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api) cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten_api pten pten_utils var_helper)
ENDIF() ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry) cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry var_helper pten_api)
add_subdirectory(jit) add_subdirectory(jit)
cc_library(amp SRCS amp_auto_cast.cc DEPS layer ) cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector) cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator) cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator) cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags) cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
......
...@@ -13,11 +13,12 @@ ...@@ -13,11 +13,12 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/amp_auto_cast.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -96,18 +97,20 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) { ...@@ -96,18 +97,20 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
return os; return os;
} }
inline std::string GetDtypeStr( template <typename VarType>
const std::shared_ptr<imperative::VarBase>& var) { inline std::string GetDtypeStr(const std::shared_ptr<VarType>& var) {
return framework::DataTypeToString(var->DataType()); return framework::DataTypeToString(GetDataType<VarType>(var));
} }
template <typename VarType>
inline bool NeedCast(const std::shared_ptr<VarBase>& var) { inline bool NeedCast(const std::shared_ptr<VarType>& var) {
if (platform::is_gpu_place(var->Place()) || auto place = GetPlace(var);
platform::is_cuda_pinned_place(var->Place()) || auto data_type = GetDataType<VarType>(var);
platform::is_xpu_place(var->Place())) { if (paddle::platform::is_gpu_place(place) ||
paddle::platform::is_cuda_pinned_place(place) ||
paddle::platform::is_xpu_place(place)) {
// CudaPinndePlace is added for varbase created by dataloader // CudaPinndePlace is added for varbase created by dataloader
if (var->DataType() == framework::proto::VarType::FP32 || if (data_type == paddle::framework::proto::VarType::FP32 ||
var->DataType() == framework::proto::VarType::FP16) { data_type == paddle::framework::proto::VarType::FP16) {
return true; return true;
} }
} }
...@@ -116,16 +119,17 @@ inline bool NeedCast(const std::shared_ptr<VarBase>& var) { ...@@ -116,16 +119,17 @@ inline bool NeedCast(const std::shared_ptr<VarBase>& var) {
// NOTE: Trace a cast op, so if a var is casted from fp32 to fp16, then the grad // NOTE: Trace a cast op, so if a var is casted from fp32 to fp16, then the grad
// var will be cast back from fp16 to fp32 during backward phase. // var will be cast back from fp16 to fp32 during backward phase.
static inline std::shared_ptr<imperative::VarBase> CastToType( template <typename VarType>
const std::shared_ptr<VarBase>& var, static inline std::shared_ptr<VarType> CastToType(
const std::shared_ptr<VarType>& var,
const framework::proto::VarType::Type dst_type) { const framework::proto::VarType::Type dst_type) {
const auto& tracer = imperative::GetCurrentTracer(); const auto& tracer = imperative::GetCurrentTracer();
imperative::NameVarBaseMap ins = {{"X", {var}}}; imperative::NameVarMap<VarType> ins = {{"X", {var}}};
framework::AttributeMap attrs = {{"in_dtype", var->DataType()}, framework::AttributeMap attrs = {{"in_dtype", GetDataType<VarType>(var)},
{"out_dtype", dst_type}}; {"out_dtype", dst_type}};
auto out = std::shared_ptr<imperative::VarBase>( auto out =
new imperative::VarBase(tracer->GenerateUniqueName())); std::shared_ptr<VarType>(new VarType(tracer->GenerateUniqueName()));
imperative::NameVarBaseMap outs = {{"Out", {out}}}; imperative::NameVarMap<VarType> outs = {{"Out", {out}}};
{ {
AutoCastGuard guard(tracer, AmpLevel::O0); AutoCastGuard guard(tracer, AmpLevel::O0);
...@@ -134,32 +138,34 @@ static inline std::shared_ptr<imperative::VarBase> CastToType( ...@@ -134,32 +138,34 @@ static inline std::shared_ptr<imperative::VarBase> CastToType(
return out; return out;
} }
template <typename VarType>
static inline std::shared_ptr<imperative::VarBase> CastToFP16( static inline std::shared_ptr<VarType> CastToFP16(
const std::shared_ptr<VarBase>& var) { const std::shared_ptr<VarType>& var) {
auto dst_type = framework::proto::VarType::FP16; auto dst_type = framework::proto::VarType::FP16;
if (NeedCast(var) && (var->DataType() != dst_type)) { if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
return CastToType(var, dst_type); return CastToType(var, dst_type);
} }
return var; return var;
} }
static inline std::shared_ptr<imperative::VarBase> CastToFP32( template <typename VarType>
const std::shared_ptr<VarBase>& var) { static inline std::shared_ptr<VarType> CastToFP32(
const std::shared_ptr<VarType>& var) {
auto dst_type = framework::proto::VarType::FP32; auto dst_type = framework::proto::VarType::FP32;
if (NeedCast(var) && (var->DataType() != dst_type)) { if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
return CastToType(var, dst_type); return CastToType(var, dst_type);
} }
return var; return var;
} }
template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType( static inline framework::proto::VarType::Type GetPromoteType(
const std::string& op_type, const NameVarBaseMap& ins) { const std::string& op_type, const NameVarMap<VarType>& ins) {
auto dst_type = framework::proto::VarType::FP16; auto dst_type = framework::proto::VarType::FP16;
for (const auto& pair : ins) { for (const auto& pair : ins) {
for (const auto& var : pair.second) { for (const auto& var : pair.second) {
if (var->DataType() == framework::proto::VarType::FP32) { if (GetDataType<VarType>(var) == framework::proto::VarType::FP32) {
dst_type = var->DataType(); dst_type = GetDataType<VarType>(var);
break; break;
} }
} }
...@@ -170,7 +176,8 @@ static inline framework::proto::VarType::Type GetPromoteType( ...@@ -170,7 +176,8 @@ static inline framework::proto::VarType::Type GetPromoteType(
if (op_type == "moving_average_abs_max_scale") { if (op_type == "moving_average_abs_max_scale") {
for (const auto& pair : ins) { for (const auto& pair : ins) {
if (pair.first == "X" && if (pair.first == "X" &&
pair.second.front()->DataType() == framework::proto::VarType::FP16) { GetDataType<VarType>(pair.second.front()) ==
framework::proto::VarType::FP16) {
dst_type = framework::proto::VarType::FP16; dst_type = framework::proto::VarType::FP16;
} }
} }
...@@ -179,9 +186,10 @@ static inline framework::proto::VarType::Type GetPromoteType( ...@@ -179,9 +186,10 @@ static inline framework::proto::VarType::Type GetPromoteType(
return dst_type; return dst_type;
} }
NameVarBaseMap AutoCastInputs(const std::string& op_type, template <typename VarType>
const NameVarBaseMap& ins) { NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
NameVarBaseMap new_ins(ins); const NameVarMap<VarType>& ins) {
NameVarMap<VarType> new_ins(ins);
if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) { if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (auto& pair : new_ins) { for (auto& pair : new_ins) {
// NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16.
...@@ -202,7 +210,7 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, ...@@ -202,7 +210,7 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float16"; << GetDtypeStr(*pair.second.cbegin()) << " to float16";
for (auto& var : pair.second) { for (auto& var : pair.second) {
var = CastToFP16(var); var = CastToFP16<VarType>(var);
} }
} }
return new_ins; return new_ins;
...@@ -211,12 +219,12 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, ...@@ -211,12 +219,12 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float"; << GetDtypeStr(*pair.second.cbegin()) << " to float";
for (auto& var : pair.second) { for (auto& var : pair.second) {
var = CastToFP32(var); var = CastToFP32<VarType>(var);
} }
} }
return new_ins; return new_ins;
} else { } else {
auto dst_type = GetPromoteType(op_type, ins); auto dst_type = GetPromoteType<VarType>(op_type, ins);
// NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32. // NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
if (dst_type == framework::proto::VarType::FP16 && if (dst_type == framework::proto::VarType::FP16 &&
...@@ -243,18 +251,23 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, ...@@ -243,18 +251,23 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
<< GetDtypeStr(*pair.second.cbegin()) << " to " << GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type); << framework::DataTypeToString(dst_type);
for (auto& var : pair.second) { for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var) var = (dst_type == framework::proto::VarType::FP32
: CastToFP16(var)); ? CastToFP32<VarType>(var)
: CastToFP16<VarType>(var));
} }
} }
return new_ins; return new_ins;
} }
return new_ins; return new_ins;
} }
template NameVarMap<VarBase> AutoCastInputs<VarBase>(
NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, const std::string& op_type, const NameVarMap<VarBase>& ins);
const NameVarBaseMap& ins) { template NameVarMap<egr::EagerTensor> AutoCastInputs<egr::EagerTensor>(
NameVarBaseMap new_ins(ins); const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
template <typename VarType>
NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins) {
NameVarMap<VarType> new_ins(ins);
auto dst_type = framework::proto::VarType::FP16; auto dst_type = framework::proto::VarType::FP16;
if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) || if (AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count(op_type) ||
AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) { AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
...@@ -284,12 +297,16 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, ...@@ -284,12 +297,16 @@ NameVarBaseMap CastPureFp16Inputs(const std::string& op_type,
<< GetDtypeStr(*pair.second.cbegin()) << " to " << GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type); << framework::DataTypeToString(dst_type);
for (auto& var : pair.second) { for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32 ? CastToFP32(var) var = (dst_type == framework::proto::VarType::FP32
: CastToFP16(var)); ? CastToFP32<VarType>(var)
: CastToFP16<VarType>(var));
} }
} }
return new_ins; return new_ins;
} }
template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerTensor> CastPureFp16Inputs<egr::EagerTensor>(
const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -83,11 +83,12 @@ class AutoCastGuard { ...@@ -83,11 +83,12 @@ class AutoCastGuard {
AmpLevel pre_amp_level_; AmpLevel pre_amp_level_;
}; };
NameVarBaseMap AutoCastInputs(const std::string& op_type, template <typename VarType>
const NameVarBaseMap& ins); NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
NameVarBaseMap CastPureFp16Inputs(const std::string& op_type, template <typename VarType>
const NameVarBaseMap& ins); NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -33,34 +34,35 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -33,34 +34,35 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::Scope& scope, const framework::Scope& scope,
const platform::DeviceContext& device_context, const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in, const NameVarMap<VarType>& var_map_in,
const NameVarMap<VarType>& var_base_map_out, const NameVarMap<VarType>& var_map_out,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) const framework::AttributeMap& default_attrs)
: ExecutionContext(op, scope, device_context, ctx), : ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in), var_map_in_(var_map_in),
var_base_map_out_(var_base_map_out), var_map_out_(var_map_out),
attrs_(attrs), attrs_(attrs),
default_attrs_(default_attrs) {} default_attrs_(default_attrs) {}
std::string InputName(const std::string& name) const override { std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name); auto it = var_map_in_.find(name);
PADDLE_ENFORCE_NE(it, var_base_map_in_.end(), PADDLE_ENFORCE_NE(it, var_map_in_.end(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Can not find [%s] in Input", name)); "Can not find [%s] in Input", name));
return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName; return it->second[0] ? GetNameFromVar(it->second[0])
: framework::kEmptyVarName;
} }
std::vector<std::string> InputNames(const std::string& name) const override { std::vector<std::string> InputNames(const std::string& name) const override {
auto it = var_base_map_in_.find(name); auto it = var_map_in_.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_in_.end(), it, var_map_in_.end(),
platform::errors::NotFound("Can not find [%s] in Input", name)); platform::errors::NotFound("Can not find [%s] in Input", name));
std::vector<std::string> vec_res; std::vector<std::string> vec_res;
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) { for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) { if (it->second[i]) {
vec_res.push_back(it->second[i]->Name()); vec_res.push_back(GetNameFromVar(it->second[i]));
} else { } else {
vec_res.push_back(framework::kEmptyVarName); vec_res.push_back(framework::kEmptyVarName);
} }
...@@ -69,23 +71,24 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -69,23 +71,24 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
std::string OutputName(const std::string& name) const override { std::string OutputName(const std::string& name) const override {
auto it = var_base_map_out_.find(name); auto it = var_map_out_.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_.end(), it, var_map_out_.end(),
platform::errors::NotFound("Can not find [%s] in Output", name)); platform::errors::NotFound("Can not find [%s] in Output", name));
return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName; return it->second[0] ? GetNameFromVar(it->second[0])
: framework::kEmptyVarName;
} }
std::vector<std::string> OutputNames(const std::string& name) const override { std::vector<std::string> OutputNames(const std::string& name) const override {
auto it = var_base_map_out_.find(name); auto it = var_map_out_.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_.end(), it, var_map_out_.end(),
platform::errors::NotFound("Can not find [%s] in Output", name)); platform::errors::NotFound("Can not find [%s] in Output", name));
std::vector<std::string> vec_res; std::vector<std::string> vec_res;
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) { for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) { if (it->second[i]) {
vec_res.push_back(it->second[i]->Name()); vec_res.push_back(GetNameFromVar(it->second[i]));
} else { } else {
vec_res.push_back(framework::kEmptyVarName); vec_res.push_back(framework::kEmptyVarName);
} }
...@@ -116,9 +119,9 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -116,9 +119,9 @@ class DygraphExecutionContext : public framework::ExecutionContext {
std::vector<std::string> InNameList() const override { std::vector<std::string> InNameList() const override {
std::vector<std::string> vec_temp; std::vector<std::string> vec_temp;
vec_temp.reserve(var_base_map_in_.size()); vec_temp.reserve(var_map_in_.size());
for (auto& v : var_base_map_in_) { for (auto& v : var_map_in_) {
vec_temp.push_back(v.first); vec_temp.push_back(v.first);
} }
...@@ -126,13 +129,13 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -126,13 +129,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
auto it = var_base_map_in_.find(name); auto it = var_map_in_.find(name);
return (it != var_base_map_in_.end() && it->second.size() > 0); return (it != var_map_in_.end() && it->second.size() > 0);
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
auto it = var_base_map_out_.find(name); auto it = var_map_out_.find(name);
return (it != var_base_map_out_.end() && it->second.size() > 0); return (it != var_map_out_.end() && it->second.size() > 0);
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
...@@ -144,8 +147,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -144,8 +147,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
const Variable* InputVar(const std::string& name) const override { const Variable* InputVar(const std::string& name) const override {
auto it = var_base_map_in_.find(name); auto it = var_map_in_.find(name);
if (it == var_base_map_in_.end()) { if (it == var_map_in_.end()) {
return nullptr; return nullptr;
} }
...@@ -155,8 +158,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -155,8 +158,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
Variable* OutputVar(const std::string& name) const override { Variable* OutputVar(const std::string& name) const override {
auto it = var_base_map_out_.find(name); auto it = var_map_out_.find(name);
if (it == var_base_map_out_.end()) { if (it == var_map_out_.end()) {
return nullptr; return nullptr;
} }
...@@ -167,8 +170,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -167,8 +170,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const std::vector<Variable*> MultiInputVar( const std::vector<Variable*> MultiInputVar(
const std::string& name) const override { const std::string& name) const override {
auto it = var_base_map_in_.find(name); auto it = var_map_in_.find(name);
if (it == var_base_map_in_.end()) { if (it == var_map_in_.end()) {
return {}; return {};
} }
std::vector<Variable*> vec_res; std::vector<Variable*> vec_res;
...@@ -182,8 +185,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -182,8 +185,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
std::vector<Variable*> MultiOutputVar( std::vector<Variable*> MultiOutputVar(
const std::string& name) const override { const std::string& name) const override {
auto it = var_base_map_out_.find(name); auto it = var_map_out_.find(name);
if (it == var_base_map_out_.end()) { if (it == var_map_out_.end()) {
return {}; return {};
} }
std::vector<Variable*> vec_res; std::vector<Variable*> vec_res;
...@@ -196,8 +199,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -196,8 +199,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
private: private:
const NameVarMap<VarType>& var_base_map_in_; const NameVarMap<VarType>& var_map_in_;
const NameVarMap<VarType>& var_base_map_out_; const NameVarMap<VarType>& var_map_out_;
const framework::AttributeMap& attrs_; const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_; const framework::AttributeMap& default_attrs_;
}; };
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/imperative/variable_wrapper.h"
namespace paddle { namespace paddle {
...@@ -37,8 +38,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -37,8 +38,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* attr, const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr, const std::string op_type, const framework::AttributeMap* default_attr, const std::string op_type,
const framework::OpKernelType* op_kernel_type = nullptr) const framework::OpKernelType* op_kernel_type = nullptr)
: var_base_map_in_(in), : var_map_in_(in),
var_base_map_out_(out), var_map_out_(out),
attrs_(attr), attrs_(attr),
default_attrs_(default_attr), default_attrs_(default_attr),
op_type_(op_type), op_type_(op_type),
...@@ -46,9 +47,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -46,9 +47,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
// has only one input // has only one input
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
if (it == var_base_map_in_->end()) { if (it == var_map_in_->end()) {
return false; return false;
} }
const auto& in = it->second; const auto& in = it->second;
...@@ -62,8 +63,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -62,8 +63,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
// has only one output // has only one output
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
if (it == var_base_map_out_->end()) { if (it == var_map_out_->end()) {
return false; return false;
} }
const auto& out = it->second; const auto& out = it->second;
...@@ -78,8 +79,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -78,8 +79,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
if (it == var_base_map_in_->end() || it->second.empty()) { if (it == var_map_in_->end() || it->second.empty()) {
return false; return false;
} }
for (auto& input : it->second) { for (auto& input : it->second) {
...@@ -91,8 +92,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -91,8 +92,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
if (it == var_base_map_out_->end() || it->second.empty()) { if (it == var_map_out_->end() || it->second.empty()) {
return false; return false;
} }
for (auto& output : it->second) { for (auto& output : it->second) {
...@@ -109,15 +110,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -109,15 +110,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<std::string> Inputs(const std::string& name) const override { std::vector<std::string> Inputs(const std::string& name) const override {
std::vector<std::string> vec_res; std::vector<std::string> vec_res;
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(), it, var_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name)); platform::errors::NotFound("can not find [%s] in input", name));
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (auto& var : it->second) { for (auto& var : it->second) {
if (var) { if (var) {
vec_res.push_back(var->Name()); vec_res.push_back(GetNameFromVar(var));
} else { } else {
vec_res.push_back(framework::kEmptyVarName); vec_res.push_back(framework::kEmptyVarName);
} }
...@@ -128,15 +129,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -128,15 +129,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<std::string> Outputs(const std::string& name) const override { std::vector<std::string> Outputs(const std::string& name) const override {
std::vector<std::string> vec_res; std::vector<std::string> vec_res;
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(), it, var_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name)); platform::errors::NotFound("can not find [%s] in output", name));
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (auto& var : it->second) { for (auto& var : it->second) {
if (var) { if (var) {
vec_res.push_back(var->Name()); vec_res.push_back(GetNameFromVar(var));
} else { } else {
vec_res.push_back(framework::kEmptyVarName); vec_res.push_back(framework::kEmptyVarName);
} }
...@@ -169,16 +170,16 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -169,16 +170,16 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
void ShareDim(const std::string& in, const std::string& out, size_t i = 0, void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override { size_t j = 0) override {
auto in_it = var_base_map_in_->find(in); auto in_it = var_map_in_->find(in);
auto out_it = var_base_map_out_->find(out); auto out_it = var_map_out_->find(out);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
in_it, var_base_map_in_->end(), in_it, var_map_in_->end(),
platform::errors::NotFound("can not found [%s] in input", in)); platform::errors::NotFound("can not found [%s] in input", in));
PADDLE_ENFORCE_GT(in_it->second.size(), i, PADDLE_ENFORCE_GT(in_it->second.size(), i,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Inputs %s should have %llu argument", in, i)); "Inputs %s should have %llu argument", in, i));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
out_it, var_base_map_out_->end(), out_it, var_map_out_->end(),
platform::errors::NotFound("can not found [%s] in input", in)); platform::errors::NotFound("can not found [%s] in input", in));
PADDLE_ENFORCE_GT(out_it->second.size(), j, PADDLE_ENFORCE_GT(out_it->second.size(), j,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -223,9 +224,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -223,9 +224,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs( std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override { const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res; std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(), it, var_map_in_->end(),
platform::errors::NotFound("Can not find [%s] in inputs.", name)); platform::errors::NotFound("Can not find [%s] in inputs.", name));
for (auto& var : it->second) { for (auto& var : it->second) {
res.emplace_back(var->MutableVar()); res.emplace_back(var->MutableVar());
...@@ -236,9 +237,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -236,9 +237,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs( std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override { const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res; std::vector<framework::InferShapeVarPtr> res;
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(), it, var_map_out_->end(),
platform::errors::NotFound("Can not find [%s] in outputs.", name)); platform::errors::NotFound("Can not find [%s] in outputs.", name));
for (auto& var : it->second) { for (auto& var : it->second) {
res.emplace_back(var->MutableVar()); res.emplace_back(var->MutableVar());
...@@ -247,9 +248,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -247,9 +248,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
DDim GetInputDim(const std::string& name) const override { DDim GetInputDim(const std::string& name) const override {
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(), it, var_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name)); platform::errors::NotFound("can not find [%s] in input", name));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
it->second.size(), 1UL, it->second.size(), 1UL,
...@@ -262,9 +263,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -262,9 +263,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<DDim> GetInputsDim(const std::string& name) const override { std::vector<DDim> GetInputsDim(const std::string& name) const override {
// const std::vector<Variable*>& vars = InputVars(name); // const std::vector<Variable*>& vars = InputVars(name);
std::vector<DDim> vec_res; std::vector<DDim> vec_res;
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(), it, var_map_in_->end(),
platform::errors::NotFound("can not find [%s] in output", name)); platform::errors::NotFound("can not find [%s] in output", name));
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) { for (size_t i = 0; i < it->second.size(); ++i) {
...@@ -281,9 +282,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -281,9 +282,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<framework::proto::VarType::Type> GetInputsVarType( std::vector<framework::proto::VarType::Type> GetInputsVarType(
const std::string& name) const override { const std::string& name) const override {
std::vector<framework::proto::VarType::Type> vec_res; std::vector<framework::proto::VarType::Type> vec_res;
auto it = var_base_map_in_->find(name); auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(), it, var_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name)); platform::errors::NotFound("can not find [%s] in input", name));
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) { for (size_t i = 0; i < it->second.size(); ++i) {
...@@ -300,9 +301,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -300,9 +301,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
std::vector<framework::proto::VarType::Type> GetOutputsVarType( std::vector<framework::proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override { const std::string& name) const override {
std::vector<framework::proto::VarType::Type> vec_res; std::vector<framework::proto::VarType::Type> vec_res;
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(), it, var_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name)); platform::errors::NotFound("can not find [%s] in output", name));
vec_res.reserve(it->second.size()); vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) { for (size_t i = 0; i < it->second.size(); ++i) {
...@@ -317,9 +318,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -317,9 +318,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
void SetOutputDim(const std::string& name, const DDim& dim) override { void SetOutputDim(const std::string& name, const DDim& dim) override {
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(), it, var_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name)); platform::errors::NotFound("can not find [%s] in output", name));
if (it->second[0]) { if (it->second[0]) {
...@@ -329,9 +330,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -329,9 +330,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
void SetOutputsDim(const std::string& name, void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override { const std::vector<DDim>& dims) override {
auto it = var_base_map_out_->find(name); auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(), it, var_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name)); platform::errors::NotFound("can not find [%s] in output", name));
PADDLE_ENFORCE_EQ(dims.size(), it->second.size(), PADDLE_ENFORCE_EQ(dims.size(), it->second.size(),
...@@ -413,8 +414,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -413,8 +414,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
private: private:
const NameVarMap<VarType>* var_base_map_in_; const NameVarMap<VarType>* var_map_in_;
const NameVarMap<VarType>* var_base_map_out_; const NameVarMap<VarType>* var_map_out_;
const framework::AttributeMap* attrs_; const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_; const framework::AttributeMap* default_attrs_;
const std::string op_type_; const std::string op_type_;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/imperative/variable_wrapper.h"
namespace paddle { namespace paddle {
...@@ -72,7 +73,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -72,7 +73,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const std::string& InputVarName(const std::string& name, const std::string& InputVarName(const std::string& name,
const int index = 0) const { const int index = 0) const {
return inputs_.at(name)[index]->Name(); return GetNameFromVar(inputs_.at(name)[index]);
} }
bool InputTypeAnyOf(const std::string& name, bool InputTypeAnyOf(const std::string& name,
...@@ -80,7 +81,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -80,7 +81,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
auto& inputs = inputs_.at(name); auto& inputs = inputs_.at(name);
return std::any_of(inputs.begin(), inputs.end(), return std::any_of(inputs.begin(), inputs.end(),
[&type](const std::shared_ptr<VarType>& var) { [&type](const std::shared_ptr<VarType>& var) {
return var->Type() == type; return GetType(var) == type;
}); });
} }
...@@ -89,7 +90,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -89,7 +90,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
auto& inputs = inputs_.at(name); auto& inputs = inputs_.at(name);
return std::all_of(inputs.begin(), inputs.end(), return std::all_of(inputs.begin(), inputs.end(),
[&type](const std::shared_ptr<VarType>& var) { [&type](const std::shared_ptr<VarType>& var) {
return var->Type() == type; return GetType(var) == type;
}); });
} }
...@@ -99,8 +100,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -99,8 +100,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
auto in_var = inputs_.at(input_name)[index]; auto in_var = inputs_.at(input_name)[index];
auto out_var = outputs_.at(output_name)[index]; auto out_var = outputs_.at(output_name)[index];
if (in_var != out_var) { if (in_var != out_var) {
this->SetVarBaseType(out_var, in_var->Type()); this->SetVarType(out_var, GetType(in_var));
this->SetVarBaseDataType(out_var, in_var->DataType());
} }
} }
...@@ -109,54 +109,44 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { ...@@ -109,54 +109,44 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
int index = 0) override { int index = 0) override {
if (index == framework::ALL_ELEMENTS) { if (index == framework::ALL_ELEMENTS) {
for (auto& item : outputs_.at(name)) { for (auto& item : outputs_.at(name)) {
this->SetVarBaseType(item, type); this->SetVarType(item, type);
} }
} else { } else {
auto& var = outputs_.at(name)[index]; auto& var = outputs_.at(name)[index];
this->SetVarBaseType(var, type); this->SetVarType(var, type);
} }
} }
void SetVarBaseType(std::shared_ptr<VarType> out, void SetVarType(std::shared_ptr<VarType> out,
framework::proto::VarType::Type type) { framework::proto::VarType::Type type) {
out->SetType(type); SetType(out, type);
if ((out->MutableVar()->IsInitialized() == true) && if ((out->MutableVar()->IsInitialized() == true) &&
(out->MutableVar()->Type() != type)) { (out->MutableVar()->Type() != type)) {
out->MutableVar()->Clear(); out->MutableVar()->Clear();
} }
} }
void SetVarBaseDataType(std::shared_ptr<VarType> out,
framework::proto::VarType::Type type) {
out->SetDataType(type);
}
framework::proto::VarType::Type GetInputType( framework::proto::VarType::Type GetInputType(
const std::string& name, const int& index = 0) const override { const std::string& name, const int& index = 0) const override {
return inputs_.at(name)[index]->Type(); return GetType(inputs_.at(name)[index]);
} }
framework::proto::VarType::Type GetOutputType( framework::proto::VarType::Type GetOutputType(
const std::string& name, const int& index = 0) const override { const std::string& name, const int& index = 0) const override {
return outputs_.at(name)[index]->Type(); return GetType(outputs_.at(name)[index]);
} }
framework::proto::VarType::Type GetInputDataType( framework::proto::VarType::Type GetInputDataType(
const std::string& name, const int& index = 0) const override { const std::string& name, const int& index = 0) const override {
return inputs_.at(name)[index]->DataType(); return GetDataType(inputs_.at(name)[index]);
} }
void SetOutputDataType(const std::string& name, void SetOutputDataType(const std::string& name,
framework::proto::VarType::Type type, framework::proto::VarType::Type type,
int index = 0) override { int index = 0) override {
if (framework::ALL_ELEMENTS == index) { VLOG(10) << "Set data type in infer var type of Eager mode is meaning less "
for (auto& item : outputs_.at(name)) { "for var: "
this->SetVarBaseDataType(item, type); << name;
}
} else {
auto& var = outputs_.at(name)[index];
this->SetVarBaseDataType(var, type);
}
} }
bool IsDygraph() const override { return true; } bool IsDygraph() const override { return true; }
......
...@@ -140,6 +140,13 @@ void ProgramDescTracer::InsertOp(const std::string &type, ...@@ -140,6 +140,13 @@ void ProgramDescTracer::InsertOp(const std::string &type,
} }
} }
void ProgramDescTracer::InsertOp(const std::string &type,
const NameTensorMap &inputs,
const NameTensorMap &outputs,
const framework::AttributeMap &attrs) {
// TODO(jiabin): Support this later.
}
TracedProgramTuple ProgramDescTracer::CreateProgramDesc( TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
const std::vector<std::shared_ptr<VarBase>> &feed_vars, const std::vector<std::shared_ptr<VarBase>> &feed_vars,
const std::string &feed_prefix, const std::string &feed_prefix,
......
...@@ -61,6 +61,10 @@ class ProgramDescTracer { ...@@ -61,6 +61,10 @@ class ProgramDescTracer {
const NameVarBaseMap &outputs, const NameVarBaseMap &outputs,
const framework::AttributeMap &attrs); const framework::AttributeMap &attrs);
void InsertOp(const std::string &type, const NameTensorMap &inputs,
const NameTensorMap &outputs,
const framework::AttributeMap &attrs);
TracedProgramTuple CreateProgramDesc( TracedProgramTuple CreateProgramDesc(
const std::vector<std::shared_ptr<VarBase>> &feed_vars, const std::vector<std::shared_ptr<VarBase>> &feed_vars,
const std::string &feed_prefix, const std::string &feed_prefix,
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/imperative/infer_var_type_context.h" #include "paddle/fluid/imperative/infer_var_type_context.h"
#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/prepared_operator.h" #include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -90,7 +91,7 @@ static std::string DebugString( ...@@ -90,7 +91,7 @@ static std::string DebugString(
ss << "NULL"; ss << "NULL";
continue; continue;
} }
ss << vars[i]->Name() << "["; ss << GetNameFromVar(vars[i]) << "[";
const framework::Variable& var = vars[i]->Var(); const framework::Variable& var = vars[i]->Var();
if (!var.IsInitialized()) { if (!var.IsInitialized()) {
ss << "NOT_INITED_VAR"; ss << "NOT_INITED_VAR";
...@@ -169,6 +170,29 @@ std::string LayerDebugString(const std::string& op_type, ...@@ -169,6 +170,29 @@ std::string LayerDebugString(const std::string& op_type,
return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs); return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
} }
std::string LayerDebugString(const std::string& op_type,
const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs) {
return LayerDebugStringImpl<egr::EagerTensor>(op_type, ins, outs);
}
template <typename VarType>
static void SetForwardDataTypeOfGradVars(const NameVarMap<VarType>& outs) {
for (auto& var_pair : outs) {
for (auto& var : var_pair.second) {
// NOTE(zhiqu): The ouput may be NULL because of pruning.
if (var) {
SetForwardDataTypeOfGradVar(var);
}
}
}
}
template <>
void SetForwardDataTypeOfGradVars<egr::EagerTensor>(
const NameVarMap<egr::EagerTensor>& outs) {
// In eager mode we don't need this.
}
VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var) VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var)
: var_(var), grad_node_(var->GetGradNode()) { : var_(var), grad_node_(var->GetGradNode()) {
if (auto grad_var = var_->GetGradVar()) { if (auto grad_var = var_->GetGradVar()) {
...@@ -407,8 +431,6 @@ void VarBase::_CopyGradientFrom(const VarBase& src) { ...@@ -407,8 +431,6 @@ void VarBase::_CopyGradientFrom(const VarBase& src) {
} }
} }
pten::KernelContext OpBase::pt_kernel_context_;
void OpBase::SetType(const std::string& type) { void OpBase::SetType(const std::string& type) {
op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false); op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
} }
...@@ -440,7 +462,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -440,7 +462,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
for (auto& var_pair : outs) { for (auto& var_pair : outs) {
for (auto& var : var_pair.second) { for (auto& var : var_pair.second) {
if (var) { if (var) {
InitializeVariable(var->MutableVar(), var->Type()); InitializeVariable(var->MutableVar(), GetType(var));
} }
} }
} }
...@@ -478,14 +500,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, ...@@ -478,14 +500,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
VLOG(4) << LayerDebugString(op.Type(), ins, outs); VLOG(4) << LayerDebugString(op.Type(), ins, outs);
// set the output var // set the output var
for (auto& var_pair : outs) { SetForwardDataTypeOfGradVars<VarType>(outs);
for (auto& var : var_pair.second) {
// NOTE(zhiqu): The ouput may be NULL because of pruning.
if (var) {
SetForwardDataTypeOfGradVar(var);
}
}
}
} }
void OpBase::Run(const framework::OperatorBase& op, void OpBase::Run(const framework::OperatorBase& op,
...@@ -506,6 +521,15 @@ void OpBase::Run(const framework::OperatorBase& op, ...@@ -506,6 +521,15 @@ void OpBase::Run(const framework::OperatorBase& op,
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place); OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
} }
void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<egr::EagerTensor>(op, ins, outs, attrs, default_attrs, place);
}
void ClearNoNeedBufferInputs(OpBase* op) { void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer(); auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return; if (!inferer) return;
...@@ -566,5 +590,14 @@ std::shared_ptr<GradOpNode> CreateGradOpNode( ...@@ -566,5 +590,14 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
} }
} }
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameTensorMap& ins,
const NameTensorMap& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map) {
// Do Nothing in Eager Mode.
return nullptr;
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -288,6 +288,12 @@ std::shared_ptr<GradOpNode> CreateGradOpNode( ...@@ -288,6 +288,12 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::AttributeMap& default_attrs, const platform::Place& place, const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map); const std::map<std::string, std::string>& inplace_map);
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameTensorMap& ins,
const NameTensorMap& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map);
void ClearNoNeedBufferInputs(OpBase* op); void ClearNoNeedBufferInputs(OpBase* op);
} // namespace imperative } // namespace imperative
......
...@@ -121,6 +121,8 @@ class OpBase { ...@@ -121,6 +121,8 @@ class OpBase {
const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; } const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; }
bool HasAttr(const std::string& name) const { bool HasAttr(const std::string& name) const {
VLOG(6) << "Default attrs: " << default_attrs_;
VLOG(6) << "attrs: " << &attrs_;
return attrs_.count(name) > 0 || default_attrs_->count(name) > 0; return attrs_.count(name) > 0 || default_attrs_->count(name) > 0;
} }
...@@ -182,6 +184,12 @@ class OpBase { ...@@ -182,6 +184,12 @@ class OpBase {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place); const platform::Place& place);
static void Run(const framework::OperatorBase& op,
const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place);
bool HasVoidFunctionPostHook() const { bool HasVoidFunctionPostHook() const {
return !void_function_post_hooks_.empty(); return !void_function_post_hooks_.empty();
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/imperative/prepared_operator.h" #include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/imperative/infer_shape_context.h" #include "paddle/fluid/imperative/infer_shape_context.h"
...@@ -56,7 +57,7 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) { ...@@ -56,7 +57,7 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
} }
template <typename VarType> template <typename VarType>
static void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) { void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
for (auto& pair : outs) { for (auto& pair : outs) {
for (auto& var : pair.second) { for (auto& var : pair.second) {
if (var == nullptr) { if (var == nullptr) {
...@@ -87,6 +88,12 @@ static void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) { ...@@ -87,6 +88,12 @@ static void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
} }
} }
template <>
void HandleComplexGradToRealGrad<egr::EagerTensor>(
const NameVarMap<egr::EagerTensor>& outs) {
// TODO(jiabin): Support Complex here.
}
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
...@@ -305,6 +312,15 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, ...@@ -305,6 +312,15 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
default_attrs); default_attrs);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<egr::EagerTensor>(ins, outs, op, place, attrs,
default_attrs);
}
template <typename VarType> template <typename VarType>
static void PreparedOpRunImpl( static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
...@@ -435,5 +451,20 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -435,5 +451,20 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
} }
} }
void PreparedOp::Run(const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
PreparedOpRunPtImpl<egr::EagerTensor>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<egr::EagerTensor>(op_, ctx_, kernel_type_, func_,
dev_ctx_, ins, outs, attrs,
default_attrs);
}
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -26,19 +27,10 @@ ...@@ -26,19 +27,10 @@
#include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/var_helper.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
namespace pten {
class DenseTensor;
} // namespace pten
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -66,10 +58,14 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) { ...@@ -66,10 +58,14 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
} }
} }
extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper( template <>
const std::shared_ptr<paddle::imperative::VarBase>& var); void SetForwardDataTypeOfGradVar<egr::EagerTensor>(
extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<egr::EagerTensor>& var) {
const std::shared_ptr<VariableWrapper>& var); VLOG(10) << "Var in Eager dose not support SetForwardDataTypeOfGradVar: "
<< var->name();
// TODO(jiabin): SetForwardDataType of Grad var is not supported yet in
// EagerMode.
}
template <typename VarType> template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData( std::shared_ptr<NameVarMap<VarType>> PrepareData(
...@@ -78,31 +74,32 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData( ...@@ -78,31 +74,32 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr; std::shared_ptr<NameVarMap<VarType>> tmp_ins_ptr = nullptr;
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (size_t i = 0; i < name_pair.second.size(); ++i) { for (size_t i = 0; i < name_pair.second.size(); ++i) {
auto& var_base = name_pair.second[i]; auto& template_var = name_pair.second[i];
SetForwardDataTypeOfGradVar(var_base); SetForwardDataTypeOfGradVar(template_var);
const auto* tensor = GetTensorFromVar(var_base->Var()); const auto* tensor = GetTensorFromVar(template_var->Var());
if (tensor && tensor->IsInitialized()) { if (tensor && tensor->IsInitialized()) {
auto kernel_type_for_var = op.GetKernelTypeForVar( auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key); name_pair.first, *tensor, expected_kernel_key);
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
continue; continue;
} else { } else {
VLOG(3) << "Transform Variable " << var_base->Name() << " from " VLOG(3) << "Transform Variable " << GetNameFromVar(template_var)
<< kernel_type_for_var << " to " << expected_kernel_key; << " from " << kernel_type_for_var << " to "
<< expected_kernel_key;
if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) { if (CheckCachedKey(template_var, expected_kernel_key)) {
VLOG(3) << "Hit variable_wrapper cache: key=" VLOG(3) << "Hit variable_wrapper cache: key="
<< expected_kernel_key; << expected_kernel_key;
std::shared_ptr<VariableWrapper> cache_var = std::shared_ptr<VariableWrapper> cache_var =
GetVariableWrapper(var_base)->getCacheValue( GetCachedValue(template_var, expected_kernel_key);
expected_kernel_key);
if (tmp_ins_ptr == nullptr) { if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins); tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
} }
const auto* tensor = GetTensorFromVar(cache_var->Var()); const auto* tensor = GetTensorFromVar(cache_var->Var());
auto tmp_var = std::make_shared<VarType>(var_base->Name()); auto tmp_var =
tmp_var->SetType(var_base->Type()); std::make_shared<VarType>(GetNameFromVar(template_var));
SetType(tmp_var, GetType(template_var));
SetTensorToVariable(cache_var->Var(), *tensor, SetTensorToVariable(cache_var->Var(), *tensor,
tmp_var->MutableVar()); tmp_var->MutableVar());
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var; (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
...@@ -118,20 +115,21 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData( ...@@ -118,20 +115,21 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
if (tmp_ins_ptr == nullptr) { if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins); tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
} }
auto tmp_var = std::make_shared<VarType>(var_base->Name()); auto tmp_var =
tmp_var->SetType(var_base->Type()); std::make_shared<VarType>(GetNameFromVar(template_var));
SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar()); SetType(tmp_var, GetType(template_var));
SetTensorToVariable(template_var->Var(), out,
tmp_var->MutableVar());
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var; (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
SetCachedValue(template_var, expected_kernel_key, tmp_var);
GetVariableWrapper(var_base)->setCacheValue(
expected_kernel_key, GetVariableWrapper(tmp_var));
VLOG(3) << "Set cache to variable_wrapper: key=" VLOG(3) << "Set cache to variable_wrapper: key="
<< expected_kernel_key; << expected_kernel_key;
} else { } else {
// if dtype is same, transform inplace will not change the // if dtype is same, transform inplace will not change the
// original // original
// value, transform inplace to avoid multiple copy // value, transform inplace to avoid multiple copy
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); SetTensorToVariable(template_var->Var(), out,
template_var->MutableVar());
} }
} }
} }
...@@ -169,6 +167,13 @@ class PreparedOp { ...@@ -169,6 +167,13 @@ class PreparedOp {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
static PreparedOp Prepare(const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out, void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
...@@ -178,6 +183,11 @@ class PreparedOp { ...@@ -178,6 +183,11 @@ class PreparedOp {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs);
const framework::OpKernelType& kernel_type() const { return kernel_type_; } const framework::OpKernelType& kernel_type() const { return kernel_type_; }
private: private:
...@@ -416,8 +426,8 @@ void PreparePtenData(const pten::Kernel& pt_kernel, ...@@ -416,8 +426,8 @@ void PreparePtenData(const pten::Kernel& pt_kernel,
auto& ins_vector = ins.at(input_names[i]); auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto var_base = ins_vector[offset]; auto var = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var_base->Var()); const auto* tensor_in = GetTensorFromVar(var->Var());
if (tensor_in && tensor_in->IsInitialized()) { if (tensor_in && tensor_in->IsInitialized()) {
auto expected_place = pten::TransToFluidPlace(in_def.backend); auto expected_place = pten::TransToFluidPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) { if (platform::is_same_place(tensor_in->place(), expected_place)) {
...@@ -430,8 +440,7 @@ void PreparePtenData(const pten::Kernel& pt_kernel, ...@@ -430,8 +440,7 @@ void PreparePtenData(const pten::Kernel& pt_kernel,
framework::Tensor tmp_tensor; framework::Tensor tmp_tensor;
framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor); framework::TensorCopySync(*tensor_in, expected_place, &tmp_tensor);
SetTensorToVariable(var_base->Var(), tmp_tensor, SetTensorToVariable(var->Var(), tmp_tensor, var->MutableVar());
var_base->MutableVar());
} }
} }
} }
......
...@@ -17,7 +17,7 @@ cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry ...@@ -17,7 +17,7 @@ cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place) cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy) cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
cc_test(test_eager SRCS test_eager.cc DEPS tracer layer prepared_operator mul_op)
if (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL) if (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL)
cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy) cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
endif() endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/core/compat/type_defs.h"
namespace paddle {
namespace imperative {
extern std::string LayerDebugString(const std::string& op_type,
const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs);
extern std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameTensorMap& ins,
const NameTensorMap& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const platform::Place& place,
const std::map<std::string, std::string>& inplace_map);
TEST(test_eager, eager_debug) {
std::shared_ptr<egr::EagerTensor> x_in(new egr::EagerTensor("x_in"));
std::shared_ptr<egr::EagerTensor> y_in(new egr::EagerTensor("y_in"));
std::shared_ptr<egr::EagerTensor> vout(new egr::EagerTensor("vout"));
imperative::NameVarMap<egr::EagerTensor> ins = {{"X", {x_in}}, {"Y", {y_in}}};
imperative::NameVarMap<egr::EagerTensor> outs = {{"Out", {vout}}};
LayerDebugString("mul", ins, outs);
}
TEST(test_create_node, eager_node) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
framework::Scope scope;
auto ctx = framework::RuntimeContext({}, {});
imperative::NameVarMap<egr::EagerTensor> ins = {{"X", {nullptr}},
{"Y", {nullptr}}};
imperative::NameVarMap<egr::EagerTensor> outs = {{"Out", {nullptr}}};
CreateGradOpNode((*op.get()), ins, outs, framework::AttributeMap{},
framework::AttributeMap{}, platform::CPUPlace(), {});
}
TEST(test_var_helper, eager_var_helper) {
framework::Variable var0, var1, var2, var3, var4, var5, var6, var7, var8;
InitializeVariable(&var0, paddle::framework::proto::VarType::FEED_MINIBATCH);
InitializeVariable(&var1, paddle::framework::proto::VarType::STEP_SCOPES);
InitializeVariable(&var2, paddle::framework::proto::VarType::LOD_RANK_TABLE);
InitializeVariable(&var3,
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY);
InitializeVariable(&var4, paddle::framework::proto::VarType::STRINGS);
InitializeVariable(&var5, paddle::framework::proto::VarType::VOCAB);
InitializeVariable(&var6, paddle::framework::proto::VarType::READER);
InitializeVariable(&var7, paddle::framework::proto::VarType::RAW);
ASSERT_ANY_THROW(
InitializeVariable(&var8, paddle::framework::proto::VarType::FP64));
auto egr_tensor = std::make_shared<egr::EagerTensor>();
auto egr_tensor2 = std::make_shared<egr::EagerTensor>();
egr_tensor->MutableVar()
->GetMutable<pten::SelectedRows>()
->mutable_value()
->mutable_data<float>(platform::CPUPlace());
egr_tensor2->MutableVar()->GetMutable<framework::LoDRankTable>();
VLOG(6) << "egr_tensor create with ";
ASSERT_TRUE(platform::is_cpu_place(GetPlace<egr::EagerTensor>(egr_tensor)));
ASSERT_TRUE(GetDataType<egr::EagerTensor>(egr_tensor) ==
framework::proto::VarType::FP32);
GetCachedValue<egr::EagerTensor>(
egr_tensor, framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace()));
SetCachedValue<egr::EagerTensor>(
egr_tensor, framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace()),
egr_tensor2);
ASSERT_ANY_THROW(GetPlace<egr::EagerTensor>(egr_tensor2));
ASSERT_ANY_THROW(SetType<egr::EagerTensor>(
egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY));
}
} // namespace imperative
} // namespace paddle
USE_OP(mul);
...@@ -107,7 +107,7 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) { ...@@ -107,7 +107,7 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
std::make_shared<std::function<void()>>([&]() { hook_value = 10; })); std::make_shared<std::function<void()>>([&]() { hook_value = 10; }));
// 2. forward // 2. forward
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
...@@ -194,13 +194,13 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() { ...@@ -194,13 +194,13 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
NameVarBaseMap outs = {out_xy_pair}; NameVarBaseMap outs = {out_xy_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
var_pair z_pair = var_pair("Y", vb_vector(1, z)); var_pair z_pair = var_pair("Y", vb_vector(1, z));
var_pair out_xz_pair = var_pair("Out", vb_vector(1, out_xz)); var_pair out_xz_pair = var_pair("Out", vb_vector(1, out_xz));
ins = {x_pair, z_pair}; ins = {x_pair, z_pair};
outs = {out_xz_pair}; outs = {out_xz_pair};
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
var_pair xy_pair = var_pair("X", vb_vector(1, out_xy)); var_pair xy_pair = var_pair("X", vb_vector(1, out_xy));
var_pair xz_pair = var_pair("Y", vb_vector(1, out_xz)); var_pair xz_pair = var_pair("Y", vb_vector(1, out_xz));
...@@ -208,7 +208,8 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() { ...@@ -208,7 +208,8 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
ins = {xy_pair, xz_pair}; ins = {xy_pair, xz_pair};
outs = {out_pair}; outs = {out_pair};
framework::AttributeMap add_attr_map; framework::AttributeMap add_attr_map;
tracer.TraceOp("elementwise_add", ins, outs, add_attr_map, place, true); tracer.TraceOp<VarBase>("elementwise_add", ins, outs, add_attr_map, place,
true);
ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
......
...@@ -143,7 +143,8 @@ TEST(test_layer, test_runtime_context) { ...@@ -143,7 +143,8 @@ TEST(test_layer, test_runtime_context) {
ctx->SyncTypeAndDataType("X", "Out"); ctx->SyncTypeAndDataType("X", "Out");
ASSERT_EQ(framework::proto::VarType::FP32, vout->DataType()); // Remove DataType check, because it doesn't make sense of set dtype in
// dygraph
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetOutputType("Out")); ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetOutputType("Out"));
...@@ -157,8 +158,8 @@ TEST(test_layer, test_runtime_context) { ...@@ -157,8 +158,8 @@ TEST(test_layer, test_runtime_context) {
framework::ALL_ELEMENTS); framework::ALL_ELEMENTS);
ctx->SetOutputDataType("Out", framework::proto::VarType::INT8); ctx->SetOutputDataType("Out", framework::proto::VarType::INT8);
ASSERT_EQ(framework::proto::VarType::INT8, vout->DataType()); // Remove DataType check, because it doesn't make sense of set dtype in
ASSERT_EQ(framework::proto::VarType::FP64, vout_b->DataType()); // dygraph
// no throw, but do nothing // no throw, but do nothing
ASSERT_NO_THROW( ASSERT_NO_THROW(
......
...@@ -16,17 +16,18 @@ ...@@ -16,17 +16,18 @@
// Created by Jiabin on 2019-08-16. // Created by Jiabin on 2019-08-16.
// //
#include <paddle/fluid/framework/op_registry.h>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
namespace imperative = paddle::imperative; namespace imperative = paddle::imperative;
namespace platform = paddle::platform; namespace platform = paddle::platform;
...@@ -71,11 +72,11 @@ TEST(test_tracer, test_trace_op) { ...@@ -71,11 +72,11 @@ TEST(test_tracer, test_trace_op) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
#ifndef PADDLE_WITH_XPU #ifndef PADDLE_WITH_XPU
ASSERT_THROW(tracer.TraceOp("mul", ins, outs, mul_attr_map, ASSERT_THROW(tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map,
platform::XPUPlace(0), true); platform::XPUPlace(0), true);
, platform::EnforceNotMet); , platform::EnforceNotMet);
#endif #endif
...@@ -117,7 +118,7 @@ TEST(test_tracer, test_trace_op_with_backward) { ...@@ -117,7 +118,7 @@ TEST(test_tracer, test_trace_op_with_backward) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
const auto& out_tensor = vout->Var().Get<framework::LoDTensor>(); const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) { for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0); ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
...@@ -157,7 +158,7 @@ TEST(test_tracer, test_track_backward_output) { ...@@ -157,7 +158,7 @@ TEST(test_tracer, test_track_backward_output) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL); ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
...@@ -196,7 +197,7 @@ TEST(test_tracer, test_track_backward_input) { ...@@ -196,7 +197,7 @@ TEST(test_tracer, test_track_backward_input) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
...@@ -237,7 +238,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) { ...@@ -237,7 +238,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("elementwise_add", ins, outs, mul_attr_map, gpu_place, true); tracer.TraceOp<VarBase>("elementwise_add", ins, outs, mul_attr_map, gpu_place,
true);
// run reduce sum // run reduce sum
std::shared_ptr<imperative::VarBase> reduce_sum_out( std::shared_ptr<imperative::VarBase> reduce_sum_out(
...@@ -247,8 +249,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) { ...@@ -247,8 +249,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
imperative::NameVarBaseMap reduce_in = {reduce_sum_in_pair}; imperative::NameVarBaseMap reduce_in = {reduce_sum_in_pair};
imperative::NameVarBaseMap reduce_out = {reduce_sum_out_pair}; imperative::NameVarBaseMap reduce_out = {reduce_sum_out_pair};
framework::AttributeMap reduce_attr_map; framework::AttributeMap reduce_attr_map;
tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map, tracer.TraceOp<VarBase>("reduce_sum", reduce_in, reduce_out, reduce_attr_map,
gpu_place, true); gpu_place, true);
imperative::BasicEngine engine; imperative::BasicEngine engine;
std::vector<std::shared_ptr<imperative::VarBase>> tensors{reduce_sum_out}; std::vector<std::shared_ptr<imperative::VarBase>> tensors{reduce_sum_out};
...@@ -368,7 +370,7 @@ TEST(test_tracer, test_var_without_grad_var) { ...@@ -368,7 +370,7 @@ TEST(test_tracer, test_var_without_grad_var) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp<VarBase>("mul", ins, outs, mul_attr_map, place, true);
const auto& out_tensor = vout->Var().Get<framework::LoDTensor>(); const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) { for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) {
...@@ -439,9 +441,9 @@ static void TestVarOpDestructionMain(const platform::Place& place, ...@@ -439,9 +441,9 @@ static void TestVarOpDestructionMain(const platform::Place& place,
size_t op_base_num = op_bases.size(); size_t op_base_num = op_bases.size();
auto z = std::make_shared<VarBase>("z_" + std::to_string(i)); auto z = std::make_shared<VarBase>("z_" + std::to_string(i));
tracer.TraceOp("mul", NameVarBaseMap{{"X", {x}}, {"Y", {y}}}, tracer.TraceOp<VarBase>("mul", NameVarBaseMap{{"X", {x}}, {"Y", {y}}},
NameVarBaseMap{{"Out", {z}}}, framework::AttributeMap{}, NameVarBaseMap{{"Out", {z}}},
place, true); framework::AttributeMap{}, place, true);
ASSERT_EQ(z->GradOpNum(), 0UL); ASSERT_EQ(z->GradOpNum(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOpNum(), 1UL); ASSERT_EQ(z->GradVarBase()->GradOpNum(), 1UL);
...@@ -530,6 +532,20 @@ TEST(test_tracer, test_var_op_destruction) { ...@@ -530,6 +532,20 @@ TEST(test_tracer, test_var_op_destruction) {
#endif #endif
} }
// Verifies that DygraphExecutionContext resolves output variable names: an
// output slot that holds only a null VarBase must report the framework's
// "empty variable" sentinel name instead of dereferencing the null handle.
TEST(test_tracer, test_execution_context) {
  // Create a bare "mul" op with no bound inputs/outputs; only its op-info
  // metadata is needed to build the execution context.
  auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
  framework::Scope scope;
  auto ctx = framework::RuntimeContext({}, {});
  // The slots exist but every variable pointer in them is null.
  NameVarBaseMap ins = {{"X", {nullptr}}, {"Y", {nullptr}}};
  NameVarBaseMap outs = {{"Out", {nullptr}}};
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(platform::CPUPlace());
  auto dy_ctx = DygraphExecutionContext<VarBase>(
      (*op.get()), scope, *dev_ctx, ctx, ins, outs, framework::AttributeMap{},
      framework::AttributeMap{});
  // A null output variable should map to kEmptyVarName, not crash.
  ASSERT_EQ(dy_ctx.OutputName("Out"), framework::kEmptyVarName);
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
......
...@@ -149,10 +149,14 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists( ...@@ -149,10 +149,14 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
return gcs_.at(place).get(); return gcs_.at(place).get();
} }
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, template <typename VarType>
const NameVarBaseMap& outs, framework::AttributeMap attrs, void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward, const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map,
paddle::framework::AttributeMap* passed_default_attrs_,
bool override_default_attr_map) {
platform::RecordEvent op_type_record_event(type); platform::RecordEvent op_type_record_event(type);
platform::ScopedFlushDenormal flush; platform::ScopedFlushDenormal flush;
VLOG(1) << "Trace Op: " << type; VLOG(1) << "Trace Op: " << type;
...@@ -181,13 +185,13 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -181,13 +185,13 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
attr_checker == nullptr ? empty_attrs_map attr_checker == nullptr ? empty_attrs_map
: attr_checker->GetDefaultAttrMap(); : attr_checker->GetDefaultAttrMap();
NameVarBaseMap new_ins = ins; NameVarMap<VarType> new_ins = ins;
if (amp_level_ == AmpLevel::O1) { if (amp_level_ == AmpLevel::O1) {
VLOG(5) << "Auto mixed precision run operator: " << type; VLOG(5) << "Auto mixed precision run operator: " << type;
new_ins = AutoCastInputs(type, ins); new_ins = AutoCastInputs<VarType>(type, ins);
} else if (amp_level_ == AmpLevel::O2) { } else if (amp_level_ == AmpLevel::O2) {
VLOG(5) << "Pure fp16 run operator: " << type; VLOG(5) << "Pure fp16 run operator: " << type;
new_ins = CastPureFp16Inputs(type, ins); new_ins = CastPureFp16Inputs<VarType>(type, ins);
} }
try { try {
...@@ -220,8 +224,20 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -220,8 +224,20 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
"PaddlePaddle should compile with MLU if use MLUPlace.")); "PaddlePaddle should compile with MLU if use MLUPlace."));
#endif #endif
} }
if (!override_default_attr_map) {
OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place); PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_,
paddle::platform::errors::PermissionDenied(
"Detected default_attrs = nullptr."));
VLOG(6) << "Use passed in default attrs";
OpBase::Run(*op, new_ins, outs, attrs, (*passed_default_attrs_), place);
} else {
VLOG(6) << "Use Checker's default attrs";
if (passed_default_attrs_) {
// TODO(jiabin): Update this without copy
*passed_default_attrs_ = default_attrs;
}
OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
}
} catch (platform::EnforceNotMet& exception) { } catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(type, &exception); framework::AppendErrorOpHint(type, &exception);
throw std::move(exception); throw std::move(exception);
...@@ -249,13 +265,53 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -249,13 +265,53 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
} else { } else {
VLOG(3) << "No Grad to track for Op: " << type; VLOG(3) << "No Grad to track for Op: " << type;
} }
VLOG(6) << "Finish Trace Op: " << type;
} }
template void Tracer::TraceOp<VarBase>(
const std::string& type, const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map,
paddle::framework::AttributeMap* default_attrs,
bool override_default_attr_map);
template void Tracer::TraceOp<egr::EagerTensor>(
const std::string& type, const NameVarMap<egr::EagerTensor>& ins,
const NameVarMap<egr::EagerTensor>& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map_,
paddle::framework::AttributeMap* default_attrs,
bool override_default_attr_map);
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs, const NameVarBaseMap& outs, framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_, TraceOp<VarBase>(type, ins, outs, std::move(attrs), expected_place_,
inplace_map); has_grad_, inplace_map);
}
void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs,
paddle::framework::AttributeMap attrs,
const paddle::platform::Place& place,
paddle::framework::AttributeMap* default_attrs,
bool override_default_attr_map,
const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp with override_default_attr_map: "
<< override_default_attr_map;
TraceOp<egr::EagerTensor>(type, ins, outs, std::move(attrs), place, false,
inplace_map, default_attrs,
override_default_attr_map);
}
void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs,
paddle::framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp(less): ";
TraceOp<egr::EagerTensor>(type, ins, outs, std::move(attrs), expected_place_,
false, inplace_map, nullptr, true);
} }
void Tracer::SetExpectedPlace(platform::Place place) { void Tracer::SetExpectedPlace(platform::Place place) {
...@@ -280,5 +336,11 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, ...@@ -280,5 +336,11 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
return false; return false;
} }
bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,
const NameTensorMap& outs,
bool trace_backward) {
return false;
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -63,17 +63,33 @@ class Tracer { ...@@ -63,17 +63,33 @@ class Tracer {
~Tracer() = default; ~Tracer() = default;
template <typename VarType>
void TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map = {},
paddle::framework::AttributeMap* passed_default_attrs_ = nullptr,
bool override_default_attr_map = true);
void TraceOp(const std::string& type, const NameVarBaseMap& ins, void TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs, const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_bacward,
const std::map<std::string, std::string>& inplace_map = {}); const std::map<std::string, std::string>& inplace_map = {});
void TraceOp(const std::string& type, const NameVarBaseMap& ins, void TraceOp(const std::string& type, const NameTensorMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs, const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map = {});
void TraceOp(const std::string& type, const NameTensorMap& ins,
const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
const paddle::platform::Place& place,
paddle::framework::AttributeMap* default_attrs,
bool override_default_attr_map,
const std::map<std::string, std::string>& inplace_map = {}); const std::map<std::string, std::string>& inplace_map = {});
bool ComputeRequiredGrad(const NameVarBaseMap& ins, bool ComputeRequiredGrad(const NameVarBaseMap& ins,
const NameVarBaseMap& outs, bool trace_backward); const NameVarBaseMap& outs, bool trace_backward);
bool ComputeRequiredGrad(const NameTensorMap& ins, const NameTensorMap& outs,
bool trace_backward);
void SetEnableProgramDescTracing(bool enabled) { void SetEnableProgramDescTracing(bool enabled) {
enable_program_desc_tracing_ = enabled; enable_program_desc_tracing_ = enabled;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/core/selected_rows.h"
namespace paddle {
namespace imperative {
/* GetVariableWrapper */
// Returns the VariableWrapper that backs a variable handle.  For VarBase the
// wrapper is its shared internal state object; for VariableWrapper the handle
// already *is* the wrapper, so it is returned unchanged.
template <>
const std::shared_ptr<VariableWrapper> &GetVariableWrapper<VarBase>(
    const std::shared_ptr<VarBase> &var) {
  return var->SharedVar();
}

template <>
const std::shared_ptr<VariableWrapper> &GetVariableWrapper<VariableWrapper>(
    const std::shared_ptr<VariableWrapper> &var) {
  // Identity mapping: a VariableWrapper wraps itself.
  return var;
}
// Default-constructs the payload of |var| so that the variable holds an empty
// instance of the framework type selected by |var_type|.
//
// Supported types cover tensors (LOD_TENSOR, SELECTED_ROWS, LOD_TENSOR_ARRAY),
// feed/fetch plumbing (FEED_MINIBATCH, FETCH_LIST), control-flow scopes
// (STEP_SCOPES), lookup tables (LOD_RANK_TABLE), string containers (STRINGS,
// VOCAB), device lists (PLACE_LIST) and readers (READER).  RAW is a
// deliberate no-op: the owning operator calls GetMutable<T>() itself with the
// concrete type it needs.
//
// Throws errors::Unavailable for any other proto var type.
void InitializeVariable(paddle::framework::Variable *var,
                        paddle::framework::proto::VarType::Type var_type) {
  if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
    var->GetMutable<paddle::framework::LoDTensor>();
  } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
    var->GetMutable<pten::SelectedRows>();
  } else if (var_type == paddle::framework::proto::VarType::FEED_MINIBATCH) {
    var->GetMutable<paddle::framework::FeedList>();
  } else if (var_type == paddle::framework::proto::VarType::FETCH_LIST) {
    var->GetMutable<paddle::framework::FetchList>();
  } else if (var_type == paddle::framework::proto::VarType::STEP_SCOPES) {
    var->GetMutable<std::vector<paddle::framework::Scope *>>();
  } else if (var_type == paddle::framework::proto::VarType::LOD_RANK_TABLE) {
    var->GetMutable<paddle::framework::LoDRankTable>();
  } else if (var_type == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) {
    var->GetMutable<paddle::framework::LoDTensorArray>();
  } else if (var_type == paddle::framework::proto::VarType::STRINGS) {
    var->GetMutable<paddle::framework::Strings>();
  } else if (var_type == paddle::framework::proto::VarType::VOCAB) {
    var->GetMutable<paddle::framework::Vocab>();
  } else if (var_type == paddle::framework::proto::VarType::PLACE_LIST) {
    var->GetMutable<paddle::platform::PlaceList>();
  } else if (var_type == paddle::framework::proto::VarType::READER) {
    var->GetMutable<paddle::framework::ReaderHolder>();
  } else if (var_type == paddle::framework::proto::VarType::RAW) {
    // GetMutable will be called in operator
  } else {
    // Fix: the original message omitted STEP_SCOPES, LOD_TENSOR_ARRAY,
    // STRINGS and VOCAB even though they are handled above, which made the
    // diagnostic misleading.  List every supported type.
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "paddle::framework::Variable type %d is not in "
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
        "STEP_SCOPES, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, STRINGS, VOCAB, "
        "PLACE_LIST, READER, RAW].",
        var_type));
  }
}
/* GetPlace */
// Returns the device placement (CPU/GPU/...) of the tensor held by |var|.
// Only dense LoDTensor and SelectedRows payloads carry a place; any other
// holder type raises errors::InvalidArgument.
template <typename VarType>
const paddle::platform::Place &GetPlace(const std::shared_ptr<VarType> &var) {
  // Fix: the original bound `paddle::framework::Variable variable = var->Var();`
  // by VALUE, copying the whole Variable on every call and then returning a
  // const reference that reached into the local copy.  Bind a const reference
  // instead: no copy, and the returned Place& is anchored in the variable
  // owned by |var|.
  const paddle::framework::Variable &variable = var->Var();
  if (variable.IsType<paddle::framework::LoDTensor>()) {
    return variable.Get<paddle::framework::LoDTensor>().place();
  } else if (variable.IsType<pten::SelectedRows>()) {
    return variable.Get<pten::SelectedRows>().place();
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Variable type is %s, expect LoDTensor or SelectedRows.",
        paddle::framework::ToTypeName(var->Var().Type())));
  }
}
// Explicit instantiations for the three handle types used by the tracer.
template const paddle::platform::Place &GetPlace<VarBase>(
    const std::shared_ptr<VarBase> &var);
template const paddle::platform::Place &GetPlace<VariableWrapper>(
    const std::shared_ptr<VariableWrapper> &var);
template const paddle::platform::Place &GetPlace<egr::EagerTensor>(
    const std::shared_ptr<egr::EagerTensor> &var);
/* GetNameFromVar */
// Returns the debug/graph name of a variable handle.  The generic version
// uses the imperative Name() accessor shared by VarBase/VariableWrapper.
template <typename VarType>
const std::string &GetNameFromVar(std::shared_ptr<VarType> var) {
  return var->Name();
}

template <>
const std::string &GetNameFromVar<egr::EagerTensor>(
    std::shared_ptr<egr::EagerTensor> tensor) {
  // EagerTensor exposes its name via lower-case name() instead of Name().
  return tensor->name();
}
template const std::string &GetNameFromVar<VariableWrapper>(
    std::shared_ptr<VariableWrapper> var);
template const std::string &GetNameFromVar<VarBase>(
    std::shared_ptr<VarBase> var);
/* SetType */
// Records the framework variable type on a handle.  VarBase/VariableWrapper
// store the tag directly via SetType().
template <typename VarType>
void SetType(std::shared_ptr<VarType> var,
             framework::proto::VarType::Type type) {
  var->SetType(type);
}

template <>
void SetType<egr::EagerTensor>(std::shared_ptr<egr::EagerTensor> var,
                               framework::proto::VarType::Type type) {
  // EagerTensor carries no separate type tag; instead, materialize the
  // matching payload inside its underlying Variable.  Only dense tensors and
  // selected rows are representable in eager mode; anything else throws.
  switch (type) {
    case paddle::framework::proto::VarType::LOD_TENSOR: {
      var->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
      break;
    }
    case paddle::framework::proto::VarType::SELECTED_ROWS: {
      var->MutableVar()->GetMutable<pten::SelectedRows>();
      break;
    }
    default: {
      PADDLE_THROW(paddle::platform::errors::NotFound(
          "Cannot found var type: %s while running runtime InferVarType",
          paddle::framework::ToTypeName(type)));
    }
  }
}
template void SetType<VarBase>(std::shared_ptr<VarBase> var,
                               framework::proto::VarType::Type type);
template void SetType<VariableWrapper>(std::shared_ptr<VariableWrapper> var,
                                       framework::proto::VarType::Type type);
/* GetType */
// Returns the framework variable type of a handle.  VarBase/VariableWrapper
// report their stored type tag.
template <typename VarType>
framework::proto::VarType::Type GetType(std::shared_ptr<VarType> var) {
  return var->Type();
}

template <>
framework::proto::VarType::Type GetType<egr::EagerTensor>(
    std::shared_ptr<egr::EagerTensor> var) {
  if (var->Var().IsInitialized()) {
    // Derive the proto type from whatever payload the Variable holds.
    return paddle::framework::ToVarType(var->Var().Type());
  } else {
    // Uninitialized eager tensors default to dense LOD_TENSOR.
    return paddle::framework::proto::VarType::LOD_TENSOR;
  }
}
template framework::proto::VarType::Type GetType<VarBase>(
    std::shared_ptr<VarBase> var);
template framework::proto::VarType::Type GetType<VariableWrapper>(
    std::shared_ptr<VariableWrapper> var);
/* GetDataType */
// Returns the element dtype of the data held by a handle.  VarBase and
// VariableWrapper track their dtype directly via DataType().
template <typename VarType>
framework::proto::VarType::Type GetDataType(std::shared_ptr<VarType> var) {
  return var->DataType();
}

template <>
framework::proto::VarType::Type GetDataType<egr::EagerTensor>(
    std::shared_ptr<egr::EagerTensor> var) {
  // EagerTensor stores no dtype tag; inspect the held payload instead.
  if (var->Var().IsType<pten::SelectedRows>()) {
    return var->Var().Get<pten::SelectedRows>().value().type();
  } else if (var->Var().IsType<framework::LoDTensor>()) {
    return var->Var().Get<framework::LoDTensor>().type();
  } else {
    // Eager mode only supports the two tensor payloads above.
    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
        "We only support pten::SelectedRows and framework::LoDTensor in "
        "eager mode, but we got %s here, please checkout your var type of "
        "tensor: %s",
        paddle::framework::ToTypeName(framework::ToVarType(var->Var().Type())),
        var->name()));
  }
}
template framework::proto::VarType::Type GetDataType<VarBase>(
    std::shared_ptr<VarBase> var);
template framework::proto::VarType::Type GetDataType<VariableWrapper>(
    std::shared_ptr<VariableWrapper> var);
/* CheckCachedKey */
// Returns whether the handle's backing VariableWrapper already has a cached
// copy for the given kernel key (used to skip redundant data transforms).
template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> var,
                    const paddle::framework::OpKernelType &key) {
  return GetVariableWrapper(var)->hasCacheKey(key);
}
template <>
bool CheckCachedKey<egr::EagerTensor>(
    std::shared_ptr<egr::EagerTensor> tensor,
    const paddle::framework::OpKernelType &key) {
  // Kernel-key caching is not implemented for eager tensors yet, so always
  // report a cache miss.
  // TODO(jiabin): Support this later
  // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is
  // equal to self: " << key == key.
  return false;
}
template bool CheckCachedKey<VarBase>(
    std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key);
template bool CheckCachedKey<VariableWrapper>(
    std::shared_ptr<VariableWrapper> var,
    const paddle::framework::OpKernelType &key);
/* GetCachedValue */
// Fetches the cached (already-transformed) wrapper stored for the given
// kernel key on the handle's backing VariableWrapper.
template <typename VarType>
std::shared_ptr<VariableWrapper> GetCachedValue(
    std::shared_ptr<VarType> var, const paddle::framework::OpKernelType &key) {
  return GetVariableWrapper(var)->getCacheValue(key);
}
template <>
std::shared_ptr<VariableWrapper> GetCachedValue(
    std::shared_ptr<egr::EagerTensor> var,
    const paddle::framework::OpKernelType &key) {
  // Caching is unsupported in eager mode; return a fresh, empty placeholder
  // wrapper so callers have something valid to hold.
  // TODO(jiabin): Support this later
  // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
  // reach this, support cache and remove this error check later, or this
  // should not be supported."));
  // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key
  // is equal to self: " << key == key.
  return std::make_shared<VariableWrapper>("");
}
template std::shared_ptr<VariableWrapper> GetCachedValue<VarBase>(
    std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key);
template std::shared_ptr<VariableWrapper> GetCachedValue<VariableWrapper>(
    std::shared_ptr<VariableWrapper> var,
    const paddle::framework::OpKernelType &key);
/* SetCachedValue */
// Stores |res| as the cached (transformed) value for kernel key |key| on the
// backing VariableWrapper of |var|.
template <typename VarType>
void SetCachedValue(std::shared_ptr<VarType> var,
                    const paddle::framework::OpKernelType &key,
                    std::shared_ptr<VarType> res) {
  GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res));
}
template <>
void SetCachedValue<egr::EagerTensor>(
    std::shared_ptr<egr::EagerTensor> tensor,
    const paddle::framework::OpKernelType &key,
    std::shared_ptr<egr::EagerTensor> res) {
  // Intentional no-op: eager tensors do not participate in kernel-key
  // caching (see CheckCachedKey above, which always reports a miss).
  // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
  // reach this, support cache and remove this error check later, or this
  // should not be supported."));
  // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key
  // is equal to self: " << key == key << " and res name is:" << res->Name().
}
template void SetCachedValue<VarBase>(
    std::shared_ptr<VarBase> var, const paddle::framework::OpKernelType &key,
    std::shared_ptr<VarBase> res);
template void SetCachedValue<VariableWrapper>(
    std::shared_ptr<VariableWrapper> var,
    const paddle::framework::OpKernelType &key,
    std::shared_ptr<VariableWrapper> res);
} // namespace imperative
} // namespace paddle
...@@ -15,19 +15,56 @@ ...@@ -15,19 +15,56 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/pten/api/all.h"
namespace egr { namespace egr {
namespace legacy { class EagerTensor;
} // namespace egr
namespace pten {
class DenseTensor;
}
namespace paddle {
namespace framework {
class Variable;
class OpKernelType;
} // namespace framework
namespace imperative {
class VarBase;
class VariableWrapper;
void InitializeVariable(paddle::framework::Variable* var, void InitializeVariable(paddle::framework::Variable* var,
paddle::framework::proto::VarType::Type var_type); paddle::framework::proto::VarType::Type var_type);
paddle::framework::proto::VarType::Type GetDtypeFromVar( template <typename VarType>
const paddle::framework::Variable& var); const paddle::platform::Place& GetPlace(const std::shared_ptr<VarType>& var);
const paddle::platform::Place& GetPlaceFromVar( template <typename VarType>
const paddle::framework::Variable& var); const std::string& GetNameFromVar(std::shared_ptr<VarType> var);
void CopyVariable(const paddle::framework::Variable& src_var,
paddle::framework::Variable* dst_var); template <typename VarType>
bool CheckCachedKey(std::shared_ptr<VarType> tensor,
} // namespace legacy const paddle::framework::OpKernelType& key);
} // namespace egr template <typename VarType>
void SetCachedValue(std::shared_ptr<VarType> tensor,
const paddle::framework::OpKernelType& key,
std::shared_ptr<VarType> res);
template <typename VarType>
std::shared_ptr<VariableWrapper> GetCachedValue(
std::shared_ptr<VarType> tensor,
const paddle::framework::OpKernelType& key);
template <typename VarType>
void SetType(std::shared_ptr<VarType> var,
framework::proto::VarType::Type type);
template <typename VarType>
framework::proto::VarType::Type GetType(std::shared_ptr<VarType> var);
template <typename VarType>
framework::proto::VarType::Type GetDataType(std::shared_ptr<VarType> var);
template <typename VarType>
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<VarType>& var);
} // namespace imperative
} // namespace paddle
...@@ -128,11 +128,19 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -128,11 +128,19 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetInput("ScaleTensor", this->Input("ScaleTensor")); grad_op->SetInput("ScaleTensor", this->Input("ScaleTensor"));
} }
grad_op->SetOutput("Out", this->InputGrad("X")); grad_op->SetOutput("Out", this->InputGrad("X"));
VLOG(6) << "Finish SetOutput";
grad_op->SetAttr("scale", this->GetAttr("scale")); grad_op->SetAttr("scale", this->GetAttr("scale"));
VLOG(6) << "Finish Set Attr scale";
grad_op->SetAttr("bias", 0.0f); grad_op->SetAttr("bias", 0.0f);
VLOG(6) << "Finish Set Attr bias";
grad_op->SetAttr("bias_after_scale", true); grad_op->SetAttr("bias_after_scale", true);
if (grad_op->HasAttr("use_mkldnn")) VLOG(6) << "Finish Set Attr bias_after_scale";
if (grad_op->HasAttr("use_mkldnn")) {
VLOG(6) << "Finish Check Attr use_mkldnn";
grad_op->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn")); grad_op->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
VLOG(6) << "Finish Set Attr use_mkldnn";
}
VLOG(6) << "Finish Apply";
} }
}; };
......
...@@ -2305,9 +2305,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2305,9 +2305,9 @@ void BindImperative(py::module *m_ptr) {
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp<imperative::VarBase>(
std::move(attrs), place, trace_backward, type, std::move(ins_map), std::move(outs_map),
inplace_map); std::move(attrs), place, trace_backward, inplace_map);
} }
}) })
.def("trace", .def("trace",
...@@ -2320,9 +2320,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2320,9 +2320,9 @@ void BindImperative(py::module *m_ptr) {
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp<imperative::VarBase>(
std::move(attrs), place, trace_backward, type, std::move(ins_map), std::move(outs_map),
inplace_map); std::move(attrs), place, trace_backward, inplace_map);
} }
}) })
.def("trace", .def("trace",
...@@ -2335,9 +2335,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2335,9 +2335,9 @@ void BindImperative(py::module *m_ptr) {
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp<imperative::VarBase>(
std::move(attrs), place, trace_backward, type, std::move(ins_map), std::move(outs_map),
inplace_map); std::move(attrs), place, trace_backward, inplace_map);
} }
}) })
.def("trace", .def("trace",
...@@ -2350,9 +2350,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2350,9 +2350,9 @@ void BindImperative(py::module *m_ptr) {
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp<imperative::VarBase>(
std::move(attrs), place, trace_backward, type, std::move(ins_map), std::move(outs_map),
inplace_map); std::move(attrs), place, trace_backward, inplace_map);
} }
}) })
.def("trace", .def("trace",
...@@ -2365,9 +2365,9 @@ void BindImperative(py::module *m_ptr) { ...@@ -2365,9 +2365,9 @@ void BindImperative(py::module *m_ptr) {
auto outs_map = ConvertToNameVarBaseMap(outs); auto outs_map = ConvertToNameVarBaseMap(outs);
{ {
py::gil_scoped_release release; py::gil_scoped_release release;
self.TraceOp(type, std::move(ins_map), std::move(outs_map), self.TraceOp<imperative::VarBase>(
std::move(attrs), place, trace_backward, type, std::move(ins_map), std::move(outs_map),
inplace_map); std::move(attrs), place, trace_backward, inplace_map);
} }
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册