From d48f7c899a0c4527826cc920239fa9cb27ef496a Mon Sep 17 00:00:00 2001
From: Jiabin Yang
Date: Fri, 24 Dec 2021 09:49:31 +0800
Subject: [PATCH] Support test imperative basic in eager (#38313)

* Rearranged Eager AutoCodeGen directory structure
* Removed USE_OP in Eager AutoCodeGen
* Enabled generation for Operators without Grad/Inputs/Outputs
* Resolved operators without input
* Fixed merge conflicts
* Enabled Eager AutoCodeGen for 10+ more operators
* Refactored Eager AutoCodeGen with more organized helper objects
* Enabled Eager AutoCodeGen for operators with multiple OpBases
* Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument
* Handled Dispensable Inputs/Outputs in Eager AutoCodeGen
* Adjusted function generation/call between Python-C API & Dygraph API
* Synchronized auto-generated Python-C API with Dygraph Forward Functions
* support more eager tensor api
* fix merge compile error
* fix compile error and fit develop code
* support pure CPU
* fix some logic error in eager_mode
* support _varbase_creator in eager mode
* Added safe_initialized interface to EagerTensor for use in processing dispensable inputs
* for eager mode
* refine
* support multiple constructor for eager tensor
* add place related code
* polish code
* specific randint with dtype of int64
* Support pure cpu test
* eager logic
* refine test in pure cpu
* eager logic
* eager logic
* eager logic, test=develop
* skip core.eager when in inference, test=develop
* refine, test=develop
* refine, test=develop
* call RetainGrad after run forward kernel, test=develop
* refine, test=develop
* support dygraph util, meta, guard test
* support inference test
* refine test and fix initializer failed

Co-authored-by: jim19930609
Co-authored-by: Wang Huan
---
 paddle/fluid/eager/CMakeLists.txt | 2 +-
 .../eager/accumulation/accumulation_node.h | 2 +
 .../eager_generated/forwards/scale.cc | 2 +-
 paddle/fluid/eager/api/utils/global_utils.h | 2 +-
 .../auto_code_generator/eager_generator.cc | 23 +-
 paddle/fluid/eager/autograd_meta.h | 6 +
 paddle/fluid/eager/grad_node_info.cc | 48 +-
 paddle/fluid/eager/grad_node_info.h | 3 +-
 .../eager/legacy/infer_var_type_context.h | 3 +-
 paddle/fluid/eager/legacy/op_runner.cc | 14 +-
 .../fluid/eager/legacy/prepared_operator.cc | 5 +-
 .../grad_node_info_test.cc | 2 +-
 paddle/fluid/eager/utils.cc | 48 +-
 paddle/fluid/eager/utils.h | 7 +
 paddle/fluid/pybind/eager.cc | 12 +
 paddle/fluid/pybind/eager_functions.cc | 107 +---
 paddle/fluid/pybind/eager_method.cc | 22 +
 .../pybind/eager_op_function_generator.cc | 12 +-
 paddle/fluid/pybind/eager_properties.cc | 20 +-
 paddle/fluid/pybind/eager_utils.cc | 91 ++-
 paddle/fluid/pybind/eager_utils.h | 16 +-
 paddle/fluid/pybind/imperative.cc | 5 +-
 python/paddle/fluid/data_feeder.py | 12 +-
 python/paddle/fluid/dygraph/base.py | 22 +-
 .../fluid/dygraph/layer_object_helper.py | 21 +-
 python/paddle/fluid/dygraph/math_op_patch.py | 2 +
 .../fluid/dygraph/varbase_patch_methods.py | 2 +-
 .../fluid/eager/eager_tensor_patch_methods.py | 110 ++++
 python/paddle/fluid/framework.py | 165 +++++-
 python/paddle/fluid/initializer.py | 524 +++++++++++-------
 python/paddle/fluid/layer_helper.py | 25 +-
 python/paddle/fluid/layer_helper_base.py | 22 +-
 .../tests/unittests/test_egr_python_api.py | 138 ++++-
 .../tests/unittests/test_imperative_basic.py | 162 ++++--
 python/paddle/tensor/creation.py | 23 +-
 35 files changed, 1254 insertions(+), 426 deletions(-)

diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt
index d5abf639c83..df000011e65
100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -15,7 +15,7 @@ cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation) cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api) -cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta) +cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils) cc_library(legacy SRCS ${DYGRAPH_LEGACY} DEPS global_utils proto_desc operator pten pten_api op_registry variable_helper memcpy) cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) diff --git a/paddle/fluid/eager/accumulation/accumulation_node.h b/paddle/fluid/eager/accumulation/accumulation_node.h index 2582cd3c9df..a2683db75e9 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.h +++ b/paddle/fluid/eager/accumulation/accumulation_node.h @@ -32,6 +32,8 @@ class GradNodeAccumulation : public GradNodeBase { void RetainGrad( const std::function& hook); + egr::EagerTensor Grad() { return accumulated_grad; } + private: egr::EagerTensor accumulated_grad; diff --git a/paddle/fluid/eager/api/generated/eager_generated/forwards/scale.cc b/paddle/fluid/eager/api/generated/eager_generated/forwards/scale.cc index a8b3421baac..7b20ff144a7 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/forwards/scale.cc +++ b/paddle/fluid/eager/api/generated/eager_generated/forwards/scale.cc @@ -80,7 +80,7 @@ egr::EagerTensor scale(const egr::EagerTensor& x, float scale, float bias, scale_node->SetAttributes_scale(scale); // Set Next Edges - scale_node->AddEdges(*p_autograd_in, /*slot id*/ 0); + scale_node->AddEdges(p_autograd_in, /*slot id*/ 0); // Set TensorWrappers scale_node->SetTensorWrappers_X({x}); diff --git a/paddle/fluid/eager/api/utils/global_utils.h b/paddle/fluid/eager/api/utils/global_utils.h index f58631e26a8..00578d9a359 100644 --- a/paddle/fluid/eager/api/utils/global_utils.h +++ b/paddle/fluid/eager/api/utils/global_utils.h @@ -63,7 +63,7 @@ class Controller { void SetCurrentTracer( const std::shared_ptr& tracer) { tracer_ = tracer; - VLOG(6) << "Set current tracer: " << tracer_; + VLOG(6) << "Set current tracer for Controller: " << tracer_; } bool InEagerMode() const { return in_eager_mode_; } diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index c87cda34cee..07644bfa195 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -888,7 +888,7 @@ static std::string GenerateGradNodeCreationContent( if (input.duplicable()) { const char* GET_MULTI_AUTOGRAD_META_TEMPLATE = " std::vector %s = " - "egr::EagerUtils::unsafe_autograd_meta(%s);\n"; + "egr::EagerUtils::nullable_autograd_meta(%s);\n"; get_autograd_meta_str += paddle::string::Sprintf( GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); @@ -902,7 +902,7 @@ static std::string GenerateGradNodeCreationContent( } else { const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = " egr::AutogradMeta& %s = " - "*egr::EagerUtils::unsafe_autograd_meta(%s);\n"; + "*egr::EagerUtils::nullable_autograd_meta(%s);\n"; get_autograd_meta_str += paddle::string::Sprintf( 
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); } @@ -999,11 +999,17 @@ static std::string GenerateGradNodeCreationContent( input_position); const char* ADD_EDGES_TEMPLATE = - " if(%s) grad_node->AddEdges(*%s, %d);\n"; + " if(%s) grad_node->AddEdges(%s, %d);\n"; grad_node_creation_str += paddle::string::Sprintf(ADD_EDGES_TEMPLATE, input_autograd_name, input_autograd_name, input_position); + VLOG(6) << "Generated Call RetainGradForTensor"; + const char* RETAIN_GRAD_TEMPLATE = + " egr::EagerUtils::CheckAndRetainGrad(%s);\n"; + grad_node_creation_str += + paddle::string::Sprintf(RETAIN_GRAD_TEMPLATE, input_name); + } else { compute_require_grad_args += ", &" + input_autograd_name; size_t input_position = fwd_inputs_name_pos_map.at(input_name); @@ -1013,7 +1019,7 @@ static std::string GenerateGradNodeCreationContent( grad_node_creation_str += paddle::string::Sprintf( SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_position); - const char* ADD_EDGES_TEMPLATE = " grad_node->AddEdges(%s, %d);\n"; + const char* ADD_EDGES_TEMPLATE = " grad_node->AddEdges(&%s, %d);\n"; grad_node_creation_str += paddle::string::Sprintf( ADD_EDGES_TEMPLATE, input_autograd_name, input_position); } @@ -1197,23 +1203,23 @@ static std::pair GenerateForwardFunctionContents( if (op_passing_outs_map[op_type].count(output_name)) { const std::string output_var_name = output_name + "Var"; - // Pass Output from function argument, + // Pass Output from function argument(EagerTensor*/vector&), // in form of shared_ptr/vector> if (output.duplicable()) { const char* FWD_NUM_ARG_TEMPLATE = - ", std::vector& %s"; + ", std::vector& %s"; std::string arg_str = paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); dygraph_function_args_str += arg_str; } else { - const char* FWD_NUM_ARG_TEMPLATE = ", egr::EagerTensor& %s"; + const char* FWD_NUM_ARG_TEMPLATE = ", egr::EagerTensor* %s"; std::string arg_str = paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); dygraph_function_args_str += arg_str; } const char* FWD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", egr::EagerUtils::TrySyncToVars(&%s) },"; + "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; outs_contents_str += paddle::string::Sprintf( FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name); @@ -1315,6 +1321,7 @@ static std::pair GenerateForwardFunctionContents( GenerateGradNodeCreationContent(fwd_info, bwd_info); generated_function_body += grad_node_creation_body_str; generated_function_body += "\n"; + // [Generation] Call RetainGradForTensor VLOG(6) << "Generated GradNode Creation codes"; } diff --git a/paddle/fluid/eager/autograd_meta.h b/paddle/fluid/eager/autograd_meta.h index 7f461364167..51937fc4815 100644 --- a/paddle/fluid/eager/autograd_meta.h +++ b/paddle/fluid/eager/autograd_meta.h @@ -120,6 +120,10 @@ class AutogradMeta : public AbstractAutogradMeta { void SetPersistable(bool persistable) { persistable_ = persistable; } + bool RetainGrads() { return retain_grads_; } + + void SetRetainGrads(bool value) { retain_grads_ = value; } + private: // TODO(jiabin) :Should we use pointer instead of object? 
egr::EagerTensor grad_; @@ -149,6 +153,8 @@ class AutogradMeta : public AbstractAutogradMeta { bool persistable_{false}; + bool retain_grads_{false}; + // TODO(jiabin) :Support Quantum here and add cache mechanism as // VarCache defined in VarBase }; diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index a1c25f6766a..6760499fdc7 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/pten/common/data_type.h" #include "paddle/pten/core/dense_tensor.h" @@ -35,6 +36,29 @@ GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) { adj_edges_.resize(bwd_out_slot_num); } +void GradNodeBase::AddEdges(std::vector* metas, size_t slot_id) { + PADDLE_ENFORCE_LT( + slot_id, adj_edges_.size(), + paddle::platform::errors::InvalidArgument( + "Given slot id is out of range of adj_edges outter size, " + "adj_edges is designed to has the same size of grad " + "inputs's slot num.")); + for (const auto& meta : *metas) { + // adj_edges has as same rank as fwd inputs, and record it's output rank + // from + // its pre-ops + auto node = meta->GetMutableGradNode(); + if (node) { + adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), + meta->OutRankInfo()); + } else { + meta->SetGradNode(std::make_shared()); + adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), + meta->OutRankInfo()); + } + } +} + void GradNodeBase::AddEdges(const std::vector& metas, size_t slot_id) { PADDLE_ENFORCE_LT( @@ -47,20 +71,34 @@ void GradNodeBase::AddEdges(const std::vector& metas, // adj_edges has as same rank as fwd inputs, and record it's output rank // from // its pre-ops - adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), - meta->OutRankInfo()); + auto node = meta->GetMutableGradNode(); + if (node) { + adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), + meta->OutRankInfo()); + } else { + meta->SetGradNode(std::make_shared()); + adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), + meta->OutRankInfo()); + } } } -void GradNodeBase::AddEdges(const AutogradMeta& meta, size_t slot_id) { +void GradNodeBase::AddEdges(AutogradMeta* meta, size_t slot_id) { PADDLE_ENFORCE_LT( slot_id, adj_edges_.size(), paddle::platform::errors::InvalidArgument( "Given slot id is out of range of adj_edges outter size, " "adj_edges is designed to has the same size of grad " "inputs's slot num.")); - adj_edges_[slot_id].emplace_back(meta.GetMutableGradNode(), - meta.OutRankInfo()); + auto node = meta->GetMutableGradNode(); + if (node) { + adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), + meta->OutRankInfo()); + } else { + meta->SetGradNode(std::make_shared()); + adj_edges_[slot_id].emplace_back(meta->GetMutableGradNode(), + meta->OutRankInfo()); + } } const std::vector& GradNodeBase::InputMeta() const { diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 6a4053e8378..545b577f4bd 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -105,8 +105,9 @@ class GradNodeBase { * * This one is called slot by slot * **/ + void AddEdges(std::vector* metas, size_t slot_id); void AddEdges(const std::vector& metas, size_t slot_id); - void AddEdges(const AutogradMeta& meta, size_t slot_id); + void AddEdges(AutogradMeta* 
meta, size_t slot_id); /** * GetEdges is designed to get all edges of current node**/ diff --git a/paddle/fluid/eager/legacy/infer_var_type_context.h b/paddle/fluid/eager/legacy/infer_var_type_context.h index 8e7bbef37d8..2d5a8d806fe 100644 --- a/paddle/fluid/eager/legacy/infer_var_type_context.h +++ b/paddle/fluid/eager/legacy/infer_var_type_context.h @@ -153,7 +153,8 @@ class TensorRuntimeInferVarTypeContext paddle::framework::proto::VarType::Type GetOutputType( const std::string& name, const int& index = 0) const override { - return paddle::framework::ToVarType(outputs_.at(name)[index]->Var().Type()); + // TODO(jiabin): Support SelectedRows when we have it. + return paddle::framework::proto::VarType::LOD_TENSOR; } paddle::framework::proto::VarType::Type GetInputDataType( diff --git a/paddle/fluid/eager/legacy/op_runner.cc b/paddle/fluid/eager/legacy/op_runner.cc index 027dc6ee1cb..4dab96c53ec 100644 --- a/paddle/fluid/eager/legacy/op_runner.cc +++ b/paddle/fluid/eager/legacy/op_runner.cc @@ -37,6 +37,7 @@ void OpRunImpl(const paddle::framework::OperatorBase& op, const paddle::framework::AttributeMap& attrs, const paddle::framework::AttributeMap& default_attrs, const paddle::platform::Place& place) { + VLOG(6) << "Get Opertor With Kernel"; auto* op_kernel = dynamic_cast(&op); PADDLE_ENFORCE_NOT_NULL( @@ -44,11 +45,13 @@ void OpRunImpl(const paddle::framework::OperatorBase& op, "Only support operator with kernel in Dygraph mode.")); auto& info = op.Info(); if (info.infer_var_type_) { + VLOG(6) << "Run InferVarType"; egr::legacy::TensorRuntimeInferVarTypeContext infer_var_type_ctx( ins, outs, attrs, default_attrs); + VLOG(9) << "Actual Run InferVarType"; info.infer_var_type_(&infer_var_type_ctx); } - + VLOG(6) << "Initialize output tensor"; // Initialize output tensor for (auto& tensor_pair : outs) { for (auto& tensor : tensor_pair.second) { @@ -77,10 +80,13 @@ void OpRunImpl(const paddle::framework::OperatorBase& op, * after the execution of op, but the original input is directly * overwritten in the previous dynamic graph implemention. 
*/ + VLOG(6) << "Prepare Op"; auto prepared_op = egr::legacy::PreparedOp::Prepare( ins, outs, *op_kernel, place, attrs, default_attrs); + VLOG(6) << "Prepare Data"; auto tmp_ins_ptr = egr::legacy::PrepareData(*op_kernel, ins, prepared_op.kernel_type()); + VLOG(6) << "Run Prepared Op"; if (tmp_ins_ptr == nullptr) { prepared_op.Run(ins, outs, attrs, default_attrs); } else { @@ -130,6 +136,7 @@ void RunOp(const std::string& type, const NameTensorMap& ins, } auto amp_level = egr::Controller::Instance().GetAMPLevel(); + VLOG(6) << "Check AMP status"; NameTensorMap new_ins = ins; if (amp_level == paddle::imperative::AmpLevel::O1) { VLOG(5) << "Auto mixed precision run operator: " << type; @@ -140,6 +147,7 @@ void RunOp(const std::string& type, const NameTensorMap& ins, } try { + VLOG(6) << "Get Device id"; if (paddle::platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) paddle::platform::SetDeviceId( @@ -165,7 +173,7 @@ void RunOp(const std::string& type, const NameTensorMap& ins, "PaddlePaddle should compile with NPU if use NPUPlace.")); #endif } - + VLOG(6) << "Step in OpRunImpl"; OpRunImpl(*op, new_ins, outs, attrs, *default_attrs, place); } catch (paddle::platform::EnforceNotMet& exception) { paddle::framework::AppendErrorOpHint(type, &exception); @@ -182,7 +190,7 @@ void RunOp(const std::string& type, const NameTensorMap& ins, PADDLE_THROW(paddle::platform::errors::Fatal( "Operator %s raises an unknown exception.", type)); } - + VLOG(6) << "Finish Run Op"; // TODO(jiabin): Support this later // if (enable_program_desc_tracing_) { // VLOG(5) << "Trace op " << type << " into ProgramDesc"; diff --git a/paddle/fluid/eager/legacy/prepared_operator.cc b/paddle/fluid/eager/legacy/prepared_operator.cc index 547ee869674..1c3429207f8 100644 --- a/paddle/fluid/eager/legacy/prepared_operator.cc +++ b/paddle/fluid/eager/legacy/prepared_operator.cc @@ -76,6 +76,7 @@ PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs, const paddle::platform::Place& place, const paddle::framework::AttributeMap& attrs, const paddle::framework::AttributeMap& default_attrs) { + VLOG(6) << "Preparing an Op"; paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -146,7 +147,7 @@ PreparedOp PrepareImpl(const NameTensorMap& ins, const NameTensorMap& outs, if (!(expected_kernel_key.place_ == place)) { dev_ctx = pool.Get(expected_kernel_key.place_); } - + VLOG(6) << "Construct Prepared Op"; return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx); } @@ -168,6 +169,7 @@ static void PreparedOpRunImpl( const NameTensorMap& outs, const paddle::framework::AttributeMap& attrs, const paddle::framework::AttributeMap& default_attrs) { // TODO(zjl): remove scope in dygraph + VLOG(6) << "Runing Prepared Op"; paddle::framework::Scope scope; EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs, @@ -198,6 +200,7 @@ static void PreparedOpRunImpl( if (paddle::framework::IsComplexType(kernel_type.data_type_)) { HandleComplexGradToRealGrad(outs); } + VLOG(6) << "Finish Runing Prepared Op"; } void PreparedOp::Run(const NameTensorMap& ins, const NameTensorMap& outs, diff --git a/paddle/fluid/eager/tests/data_structure_tests/grad_node_info_test.cc b/paddle/fluid/eager/tests/data_structure_tests/grad_node_info_test.cc index abc200f7130..aebb0553e28 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/grad_node_info_test.cc +++ 
b/paddle/fluid/eager/tests/data_structure_tests/grad_node_info_test.cc @@ -58,7 +58,7 @@ TEST(GradNodeInfo, GradNodeBase) { auto auto_grad0 = std::make_shared(edge0); egr::Edge edge1(grad_test_node1, 3, 4); auto auto_grad1 = std::make_shared(edge1); - grad_test_node0->AddEdges((*auto_grad0.get()), 0); + grad_test_node0->AddEdges(auto_grad0.get(), 0); CHECK_EQ(grad_test_node0->GetEdges()[0][0].GetEdgeRankInfo().first, size_t(1)); CHECK_EQ(grad_test_node0->GetEdges()[0][0].GetEdgeRankInfo().second, diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 2e52753bcc2..6459614330a 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/api/utils/hook_utils.h" #include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/pten/api/all.h" @@ -24,6 +25,9 @@ #include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/variable.h" +PADDLE_DEFINE_EXPORTED_bool(retain_grad_for_all_tensor, true, + "retain grad for all tensor"); + namespace egr { /** * Implementation of Eager Utils. @@ -50,8 +54,9 @@ AutogradMeta* EagerUtils::unsafe_autograd_meta(const egr::EagerTensor& target) { std::vector EagerUtils::unsafe_autograd_meta( const std::vector& targets) { std::vector metas; + metas.reserve(targets.size()); for (const egr::EagerTensor& t : targets) { - metas.push_back(unsafe_autograd_meta(t)); + metas.emplace_back(unsafe_autograd_meta(t)); } return metas; } @@ -64,6 +69,16 @@ AutogradMeta* EagerUtils::nullable_autograd_meta( return static_cast(p_autograd_meta); } +std::vector EagerUtils::nullable_autograd_meta( + const std::vector& targets) { + std::vector metas; + metas.reserve(targets.size()); + for (const egr::EagerTensor& t : targets) { + metas.emplace_back(nullable_autograd_meta(t)); + } + return metas; +} + std::vector EagerUtils::multi_autograd_meta( std::vector* targets) { std::vector ret; @@ -140,7 +155,8 @@ static std::shared_ptr TrySyncToVar( if (tensor->initialized() || tensor->Var().IsInitialized()) { tensor->SyncToVar(paddle::framework::proto::VarType_Type_LOD_TENSOR); } - return std::make_shared(*tensor); + return std::shared_ptr(tensor, + [&](egr::EagerTensor* ptr) {}); } std::vector> EagerUtils::TrySyncToVars( @@ -159,6 +175,17 @@ std::vector> EagerUtils::TrySyncToVars( return res; } +std::vector> EagerUtils::TrySyncToVars( + const std::vector& tensors) { + std::vector> res; + size_t num = tensors.size(); + res.reserve(num); + for (size_t i = 0; i < num; i++) { + res.emplace_back(TrySyncToVar(tensors[i])); + } + return res; +} + /* ---- VarBase -> Tensor ---- */ std::vector> EagerUtils::SyncToTensors( const egr::EagerTensor& tensor) { @@ -236,4 +263,21 @@ std::vector EagerUtils::RecoverTensorWrapper( return ret; } +void EagerUtils::CheckAndRetainGrad(const egr::EagerTensor& tensor) { + VLOG(6) << "Check RetainGradForTensor: " << tensor.name(); + if (FLAGS_retain_grad_for_all_tensor) { + egr::egr_utils_api::RetainGradForTensor(tensor); + } +} + +void EagerUtils::CheckAndRetainGrad( + const std::vector& tensors) { + if (FLAGS_retain_grad_for_all_tensor) { + for (auto& tensor : tensors) { + VLOG(6) << "Check RetainGradForTensor: " << tensor.name(); + egr::egr_utils_api::RetainGradForTensor(tensor); + } + } +} + } // namespace egr diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 843b6404af5..bc1acbd69d0 100644 --- a/paddle/fluid/eager/utils.h +++ 
b/paddle/fluid/eager/utils.h @@ -116,6 +116,8 @@ class EagerUtils { // This method will return an AutogradMeta pointer unsafely. static AutogradMeta* nullable_autograd_meta(const egr::EagerTensor& target); + static std::vector nullable_autograd_meta( + const std::vector& targets); static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target); static std::vector unsafe_autograd_meta( const std::vector& targets); @@ -149,6 +151,8 @@ class EagerUtils { egr::EagerTensor* tensor); static std::vector> TrySyncToVars( std::vector* tensors); + static std::vector> TrySyncToVars( + const std::vector& tensors); static std::vector> SyncToVars( const egr::EagerTensor& tensor); @@ -163,6 +167,9 @@ class EagerUtils { static std::vector GetOutputs( const std::vector>& outs); static egr::EagerTensor GetOutput(const std::shared_ptr& outs); + + static void CheckAndRetainGrad(const egr::EagerTensor& tensor); + static void CheckAndRetainGrad(const std::vector& tensors); }; } // namespace egr diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 94ff2eb4c1e..9b69ccca5a2 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/utils.h" @@ -72,6 +73,17 @@ void EmptyEagerTensorInitializer( pten::DenseTensorMeta(pten::TransToPtenDataType(dtype), paddle::framework::make_ddim(dims))); self->eager_tensor.set_impl(dense_tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "We only support LoDTensor to be constructed by this initializer, " + "please check your var type first and make sure you are going to " + "construct LoDTensor.")); + } + + if (!autograd_meta->GetMutableGradNode()) { + VLOG(3) << "Tensor(" << name + << ") have not GradNode, add GradNodeAccumulation for it."; + autograd_meta->SetGradNode(std::make_shared()); } } diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index b980692d455..3f8923440be 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -112,84 +112,6 @@ static PyObject* eager_api_scale(PyObject* self, PyObject* args, EAGER_CATCH_AND_THROW_RETURN_NULL } -static PyObject* eager_api_numpy_to_tensor(PyObject* numpy_data, - pten::DataType dtype, - const paddle::platform::Place& place, - bool stop_gradient) { - std::vector vec_dims; - auto numpy_shape = pybind11::detail::array_proxy(numpy_data)->dimensions; - int rank = pybind11::detail::array_proxy(numpy_data)->nd; - for (int i = 0; i < rank; i++) { - vec_dims.push_back(static_cast(numpy_shape[i])); - } - paddle::framework::DDim dims = paddle::framework::make_ddim(vec_dims); - - // TODO(jiabin): Support GPU later - auto meta = pten::DenseTensorMeta(dtype, dims); - auto holder = std::make_shared(numpy_data, dtype); - auto shared_storage = - pten::make_intrusive(holder, 0); - std::shared_ptr densetensor( - new pten::DenseTensor(std::move(shared_storage), std::move(meta))); - - PyObject* obj = p_eager_tensor_type->tp_alloc(p_eager_tensor_type, 0); - if (obj) { - auto v = reinterpret_cast(obj); - new (&(v->eager_tensor)) egr::EagerTensor(); - v->eager_tensor.set_impl(densetensor); - v->eager_tensor.set_name(egr::Controller::Instance().GenerateUniqueName()); - auto meta = egr::EagerUtils::autograd_meta(&(v->eager_tensor)); - 
meta->SetStopGradient(stop_gradient); - - // Created tensor will be leaf tensor - // So we append AccumulationNode to it. - auto accumulation_node = std::make_shared(); - meta->SetGradNode(accumulation_node); - - // TODO(jiabin): Shall we increase ref cnt here to make python ref cnt num - // correctly? - } else { - PADDLE_THROW(platform::errors::Fatal( - "tp_alloc return null, can not new a PyObject.")); - } - - return obj; -} - -static PyObject* eager_api_to_tensor(PyObject* self, PyObject* args, - PyObject* kwargs) { - EAGER_TRY - // TODO(jiabin): Support Kwargs here - PyObject* data = PyTuple_GET_ITEM(args, 0); - auto str_dtype = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 1), 1); - pten::DataType dtype = pten::String2DataType(str_dtype); - auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2); - bool stop_gradient = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3); - // TODO(jiabin): Support this when python given name - // auto str_name = CastPyArg2AttrString(PyTuple_GET_ITEM(args, 4), 4); - - if (pybind11::detail::npy_api::get().PyArray_Check_(data)) { - return eager_api_numpy_to_tensor(data, dtype, place, stop_gradient); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Eater to_tensor only support numpy to tensor.")); - Py_INCREF(Py_None); - return Py_None; - } - EAGER_CATCH_AND_THROW_RETURN_NULL -} - -static PyObject* eager_api_retain_grad_for_tensor(PyObject* self, - PyObject* args, - PyObject* kwargs) { - EAGER_TRY - egr::egr_utils_api::RetainGradForTensor( - CastPyArg2EagerTensor(PyTuple_GET_ITEM(args, 0), 0)); - Py_INCREF(Py_None); - return Py_None; - EAGER_CATCH_AND_THROW_RETURN_NULL -} - static PyObject* eager_api_run_backward(PyObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY @@ -203,9 +125,29 @@ static PyObject* eager_api_run_backward(PyObject* self, PyObject* args, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* eager_api_tensor_copy(PyObject* self, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + egr::EagerTensor& src = + reinterpret_cast(PyTuple_GET_ITEM(args, 0)) + ->eager_tensor; + egr::EagerTensor& dst = + reinterpret_cast(PyTuple_GET_ITEM(args, 1)) + ->eager_tensor; + auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 2), 2); + bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 3), 3); + + dst = src.copy_to(pten::TransToPtenBackend(place), blocking); + egr::EagerUtils::autograd_meta(&dst)->SetStopGradient( + egr::EagerUtils::autograd_meta(&(src))->StopGradient()); + egr::EagerUtils::autograd_meta(&dst)->SetPersistable( + egr::EagerUtils::autograd_meta(&(src))->Persistable()); + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef variable_functions[] = { - {"to_tensor", (PyCFunction)(void (*)(void))eager_api_to_tensor, - METH_VARARGS | METH_KEYWORDS, NULL}, {"scale", (PyCFunction)(void (*)(void))eager_api_scale, METH_VARARGS | METH_KEYWORDS, NULL}, {"_set_expected_place", @@ -214,11 +156,10 @@ PyMethodDef variable_functions[] = { {"_get_expected_place", (PyCFunction)(void (*)(void))eager_api_get_expected_place, METH_VARARGS | METH_KEYWORDS, NULL}, - {"retain_grad_for_tensor", - (PyCFunction)(void (*)(void))eager_api_retain_grad_for_tensor, - METH_VARARGS | METH_KEYWORDS, NULL}, {"run_backward", (PyCFunction)(void (*)(void))eager_api_run_backward, METH_VARARGS | METH_KEYWORDS, NULL}, + {"tensor_copy", (PyCFunction)(void (*)(void))eager_api_tensor_copy, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL}}; void BindFunctions(PyObject* module) { diff --git 
a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 790969a4b60..e0e23b5a49f 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "pybind11/numpy.h" #include "pybind11/pybind11.h" +#include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/utils.h" @@ -120,6 +121,8 @@ static PyObject* eager_tensor_method_copy_(EagerTensorObject* self, egr::EagerTensor src_tensor = CastPyArg2EagerTensor(PyTuple_GET_ITEM(args, 0), 0); bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1); + VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to " + << self->eager_tensor.name(); self->eager_tensor.copy_(src_tensor, blocking); egr::EagerUtils::autograd_meta(&(self->eager_tensor)) ->SetStopGradient( @@ -127,6 +130,23 @@ static PyObject* eager_tensor_method_copy_(EagerTensorObject* self, egr::EagerUtils::autograd_meta(&(self->eager_tensor)) ->SetPersistable( egr::EagerUtils::autograd_meta(&(src_tensor))->Persistable()); + VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to " + << self->eager_tensor.name(); + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +static PyObject* eager_tensor_retain_grads(EagerTensorObject* self, + PyObject* args, PyObject* kwargs) { + EAGER_TRY + auto meta = egr::EagerUtils::autograd_meta(&(self->eager_tensor)); + if (!meta->GetMutableGradNode()) { + VLOG(6) << "Make grad node of tensor: " << self->eager_tensor.name() + << "become accumulation node"; + meta->SetGradNode(std::make_shared()); + } + egr::egr_utils_api::RetainGradForTensor(self->eager_tensor); Py_INCREF(Py_None); return Py_None; EAGER_CATCH_AND_THROW_RETURN_NULL @@ -142,6 +162,8 @@ PyMethodDef variable_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"copy_", (PyCFunction)(void (*)(void))eager_tensor_method_copy_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"retain_grads", (PyCFunction)(void (*)(void))eager_tensor_retain_grads, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL}}; } // namespace pybind diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index 4e880f78c6d..3d0a4d0de75 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -70,11 +70,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector>)"; const char* CAST_VAR_TEMPLATE = R"( - auto %s = GetEagerTensorFromArgs("%s", "%s", args, %d, %s);)"; + auto& %s = GetEagerTensorFromArgs("%s", "%s", args, %d, %s);)"; const char* CAST_VAR_LIST_TEMPLATE = R"( auto %s = GetEagerTensorListFromArgs("%s", "%s", args, %d, %s);)"; +const char* CAST_VAR_PTR_TEMPLATE = R"( + auto %s = GetEagerTensorPtrFromArgs("%s", "%s", args, %d, %s);)"; + +const char* CAST_VAR_PTR_LIST_TEMPLATE = R"( + auto %s = GetEagerTensorPtrListFromArgs("%s", "%s", args, %d, %s);)"; + const char* CAST_SIZE_T_TEMPLATE = R"( auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)"; @@ -221,8 +227,8 @@ std::string GenerateOpFunctionsBody( outs_initializer += ","; } - const auto in_cast_type = - output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; + const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE + : CAST_VAR_PTR_TEMPLATE; auto dispensable = output.dispensable() ? 
"true" : "false"; ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type, out_name, arg_idx++, dispensable); diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc index b8b7adea506..4025a33b561 100644 --- a/paddle/fluid/pybind/eager_properties.cc +++ b/paddle/fluid/pybind/eager_properties.cc @@ -14,7 +14,9 @@ limitations under the License. */ #include #include +#include "paddle/fluid/eager/accumulation/accumulation_node.h" #include "paddle/fluid/eager/api/all.h" +#include "paddle/fluid/eager/api/utils/tensor_utils.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/memory/allocation/allocator.h" @@ -60,8 +62,22 @@ PyObject* eager_tensor_properties_get_stop_gradient(EagerTensorObject* self, PyObject* eager_tensor_properties_get_grad(EagerTensorObject* self, void* closure) { EAGER_SYNC_TRY - auto meta = egr::EagerUtils::unsafe_autograd_meta(self->eager_tensor); - return ToPyObject(meta->Grad()); + if (egr::egr_utils_api::IsLeafTensor(self->eager_tensor)) { + // Add RetainGrad as PostHook to AccumulationNode + std::shared_ptr grad_node = + egr::EagerUtils::grad_node(self->eager_tensor); + PADDLE_ENFORCE( + grad_node.get() != nullptr, + paddle::platform::errors::Fatal("Detected NULL grad_node" + "Leaf tensor should have had grad_node " + "with type: GradNodeAccumulation")); + auto accumulation_grad_node = + std::dynamic_pointer_cast(grad_node); + return ToPyObject(accumulation_grad_node->Grad()); + } else { + auto meta = egr::EagerUtils::unsafe_autograd_meta(self->eager_tensor); + return ToPyObject(meta->Grad()); + } EAGER_CATCH_AND_THROW_RETURN_NULL } diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index d9da5102262..ba328692dd2 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -199,7 +199,7 @@ std::vector CastPyArg2VectorOfEagerTensor(PyObject* obj, } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " - "list of bool, but got %s at pos %d", + "list of Tensor, but got %s at pos %d", arg_pos + 1, reinterpret_cast(item->ob_type)->tp_name, i)); } @@ -216,7 +216,7 @@ std::vector CastPyArg2VectorOfEagerTensor(PyObject* obj, } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be " - "list of EagerTensor, but got %s at pos %d", + "list of Tensor, but got %s at pos %d", arg_pos + 1, reinterpret_cast(item->ob_type)->tp_name, i)); } @@ -478,10 +478,10 @@ PyObject* ToPyObject( return dict; } -egr::EagerTensor GetEagerTensorFromArgs(const std::string& op_type, - const std::string& arg_name, - PyObject* args, ssize_t arg_idx, - bool dispensable) { +egr::EagerTensor& GetEagerTensorFromArgs(const std::string& op_type, + const std::string& arg_name, + PyObject* args, ssize_t arg_idx, + bool dispensable) { PyObject* obj = PyTuple_GET_ITEM(args, arg_idx); if (PyTuple_Check(obj)) { @@ -494,7 +494,7 @@ egr::EagerTensor GetEagerTensorFromArgs(const std::string& op_type, "%s(): argument '%s' (position %d) must be Tensor, but got None", op_type, arg_name, arg_idx)); } - egr::EagerTensor emptytensor; + static egr::EagerTensor emptytensor; return emptytensor; } @@ -555,5 +555,82 @@ std::vector GetEagerTensorListFromArgs( return result; } +egr::EagerTensor* GetEagerTensorPtrFromArgs(const std::string& op_type, + const std::string& arg_name, + PyObject* args, ssize_t arg_idx, + bool dispensable) { + PyObject* obj = PyTuple_GET_ITEM(args, arg_idx); + + if 
(PyTuple_Check(obj)) { + obj = PyTuple_GET_ITEM(obj, 0); + } + + if (obj == nullptr || obj == Py_None) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be Tensor, but got None", + op_type, arg_name, arg_idx)); + } + static egr::EagerTensor emptytensor; + return &emptytensor; + } + + return &(reinterpret_cast(obj)->eager_tensor); +} + +std::vector GetEagerTensorPtrListFromArgs( + const std::string& op_type, const std::string& arg_name, PyObject* args, + ssize_t arg_idx, bool dispensable) { + PyObject* list = PyTuple_GET_ITEM(args, arg_idx); + + if (list == nullptr) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensor, but got " + "None", + op_type, arg_name, arg_idx)); + } + return {}; + } + + std::vector result; + + if (PyList_Check(list)) { + Py_ssize_t len = PyList_Size(list); + if (len == 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors, but got " + "empty list", + op_type, arg_name, arg_idx)); + } + for (Py_ssize_t i = 0; i < len; i++) { + result.emplace_back( + &(reinterpret_cast(PyList_GetItem(list, i)) + ->eager_tensor)); + } + } else if (PyTuple_Check(list)) { + Py_ssize_t len = PyTuple_Size(list); + if (len == 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors, but got " + "empty list", + op_type, arg_name, arg_idx)); + } + for (Py_ssize_t i = 0; i < len; i++) { + result.emplace_back( + &(reinterpret_cast(PyTuple_GetItem(list, i)) + ->eager_tensor)); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors, but got " + "%s", + op_type, arg_name, arg_idx, + (reinterpret_cast(list->ob_type))->tp_name)); + } + + return result; +} + } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index e493e06d7d7..7b7a88b5ac4 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -83,13 +83,21 @@ PyObject* ToPyObject(const std::tuple& out) { return result; } -egr::EagerTensor GetEagerTensorFromArgs(const std::string& op_type, - const std::string& arg_name, - PyObject* args, ssize_t arg_idx, - bool dispensable = false); +egr::EagerTensor& GetEagerTensorFromArgs(const std::string& op_type, + const std::string& arg_name, + PyObject* args, ssize_t arg_idx, + bool dispensable = false); std::vector GetEagerTensorListFromArgs( const std::string& op_type, const std::string& arg_name, PyObject* args, ssize_t arg_idx, bool dispensable = false); +egr::EagerTensor* GetEagerTensorPtrFromArgs(const std::string& op_type, + const std::string& arg_name, + PyObject* args, ssize_t arg_idx, + bool dispensable = false); +std::vector GetEagerTensorPtrListFromArgs( + const std::string& op_type, const std::string& arg_name, PyObject* args, + ssize_t arg_idx, bool dispensable = false); + } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index e981de44c5a..2190dafbff8 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -876,9 +876,8 @@ void BindImperative(py::module *m_ptr) { [](const std::shared_ptr &tracer) { if (egr::Controller::Instance().InEagerMode()) { egr::Controller::Instance().SetCurrentTracer(tracer); - } else { - 
imperative::SetCurrentTracer(tracer); } + imperative::SetCurrentTracer(tracer); }); m.def("_enable_eager_mode", []() { egr::Controller::Instance().SetInEagerMode(true); }); @@ -2150,6 +2149,8 @@ void BindImperative(py::module *m_ptr) { if (py::isinstance(obj)) { auto p = obj.cast(); self.SetExpectedPlace(*p); + // TODO(jiabin): Support eager here when we need to make all + // dygraph in eager mode VLOG(4) << "Tracer(" << &self << ")" << " set expected place " << *p; } else if (py::isinstance(obj)) { diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 60f844b27be..26371d0d6ee 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -22,7 +22,7 @@ from six.moves import zip, range, xrange import multiprocessing import warnings -from .framework import Variable, default_main_program, _current_expected_place, in_dygraph_mode +from .framework import Variable, default_main_program, _current_expected_place, in_dygraph_mode, _in_eager_mode from .framework import _cpu_num, _cuda_ids __all__ = ['DataFeeder'] @@ -102,12 +102,20 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''): if not isinstance(expected_type, tuple): expected_type = (expected_type, ) expected_type += (core.VarBase, ) + # TODO(jiabin): uncomment it when we support declarative mode in eager + # if _in_eager_mode(): + # expected_type += (core.eager.EagerTensor, ) elif isinstance(input, core.VarBase): raise TypeError( "Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode firstly. " "Because received '{}' in {} is a imperative Variable.".format( input_name, op_name)) - + elif hasattr(core, "eager"): + if isinstance(input, core.eager.EagerTensor): + raise TypeError( + "Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode firstly. " + "Because received '{}' in {} is a imperative Variable.".format( + input_name, op_name)) if not isinstance(input, expected_type): raise TypeError( "The type of '%s' in %s must be %s, but received %s. 
%s" % diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index f54a1629196..9234577b8cc 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -25,7 +25,7 @@ from .tracer import Tracer import logging from ..data_feeder import convert_dtype import warnings -from ..framework import _get_paddle_place +from ..framework import _get_paddle_place, _in_eager_mode import paddle __all__ = [ @@ -720,10 +720,16 @@ def to_variable(value, name=None, zero_copy=None, dtype=None): if value.dtype != dtype: value = value.astype(dtype) - py_var = core.VarBase( - value=value, - place=framework._current_expected_place(), - persistable=False, - zero_copy=zero_copy, - name=name if name else '') - return py_var + if _in_eager_mode(): + return core.eager.EagerTensor(value, + framework._current_expected_place(), + False, zero_copy, name + if name else None, True) + else: + py_var = core.VarBase( + value=value, + place=framework._current_expected_place(), + persistable=False, + zero_copy=zero_copy, + name=name if name else '') + return py_var diff --git a/python/paddle/fluid/dygraph/layer_object_helper.py b/python/paddle/fluid/dygraph/layer_object_helper.py index 5bf5eda19a5..4ad575d325b 100644 --- a/python/paddle/fluid/dygraph/layer_object_helper.py +++ b/python/paddle/fluid/dygraph/layer_object_helper.py @@ -21,6 +21,7 @@ from ..param_attr import ParamAttr from .. import core from six.moves import zip from ..layer_helper_base import LayerHelperBase +from ..dygraph_utils import _append_activation_in_dygraph class LayerObjectHelper(LayerHelperBase): @@ -162,14 +163,18 @@ class LayerObjectHelper(LayerHelperBase): if (use_mkldnn is not None) and use_mkldnn: act['use_mkldnn'] = use_mkldnn act_type = act.pop('type') - - tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) - self.append_op( - type=act_type, - inputs={"X": [input_var]}, - outputs={"Out": [tmp]}, - attrs=act) - return tmp + if in_dygraph_mode(): + res = _append_activation_in_dygraph(input_var, act_type, use_cudnn, + use_mkldnn) + return res + else: + tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) + self.append_op( + type=act_type, + inputs={"X": [input_var]}, + outputs={"Out": [tmp]}, + attrs=act) + return tmp def is_instance(self, param, cls): """Check if the input parameter is instance of input class diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 3731976ad18..92fbc89a46e 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -18,6 +18,7 @@ from .. import core from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator from ..layers.layer_function_generator import OpProtoHolder from . import no_grad +from ..framework import _in_eager_mode import numpy as np import warnings @@ -332,6 +333,7 @@ def monkey_patch_math_varbase(): ] global _already_patch_varbase + if not _already_patch_varbase: for method in varbase_methods: method_name = method[0] diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index f308af04e5e..a2cecb8030d 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -233,7 +233,7 @@ def monkey_patch_varbase(): if grad_tensor is not None: assert isinstance( grad_tensor, paddle. 
- Tensor), "The type of grad_tensot must be paddle.Tensor" + Tensor), "The type of grad_tensor must be paddle.Tensor" assert grad_tensor.shape == self.shape, \ "Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format( grad_tensor.name, grad_tensor.shape, self.name, self.shape) diff --git a/python/paddle/fluid/eager/eager_tensor_patch_methods.py b/python/paddle/fluid/eager/eager_tensor_patch_methods.py index 547a948da40..f820d02d1ab 100644 --- a/python/paddle/fluid/eager/eager_tensor_patch_methods.py +++ b/python/paddle/fluid/eager/eager_tensor_patch_methods.py @@ -13,6 +13,9 @@ # limitations under the License. from .. import core as core +from .. import framework as framework +from ..dygraph.parallel import scale_loss +import numpy as np def monkey_patch_eagertensor(): @@ -20,5 +23,112 @@ def monkey_patch_eagertensor(): from paddle.tensor.to_string import eager_tensor_to_string return eager_tensor_to_string(self) + @framework.dygraph_only + def backward(self, grad_tensor=None, retain_graph=False): + """ + Run backward of current Graph which starts from current Tensor. + + The new gradient will accumulat on previous gradient. + + You can clear gradient by ``Tensor.clear_grad()`` . + + Args: + grad_tensor(Tensor, optional): initial gradient values of the current Tensor. If `grad_tensor` is None, + the initial gradient values of the current Tensor would be Tensor filled with 1.0; + if `grad_tensor` is not None, it must have the same length as the current Tensor. + Teh default value is None. + + retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would + like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter + :code:`retain_graph` to True, then the grads will be retained. Thus, seting it to False is much more memory-efficient. + Defaults to False. + Returns: + NoneType: None + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor(5., stop_gradient=False) + for i in range(5): + y = paddle.pow(x, 4.0) + y.backward() + print("{}: {}".format(i, x.grad)) + # 0: [500.] + # 1: [1000.] + # 2: [1500.] + # 3: [2000.] + # 4: [2500.] + + x.clear_grad() + print("{}".format(x.grad)) + # 0. + + grad_tensor=paddle.to_tensor(2.) + for i in range(5): + y = paddle.pow(x, 4.0) + y.backward(grad_tensor) + print("{}: {}".format(i, x.grad)) + # 0: [1000.] + # 1: [2000.] + # 2: [3000.] + # 3: [4000.] + # 4: [5000.] + + """ + if framework.in_dygraph_mode(): + if grad_tensor is not None: + assert isinstance( + grad_tensor, core.eager.EagerTensor + ), "The type of grad_tensor must be paddle.Tensor" + assert grad_tensor.shape == self.shape, \ + "Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format( + grad_tensor.name, grad_tensor.shape, self.name, self.shape) + grad_tensor = [grad_tensor] + else: + grad_tensor = [] + + if core.is_compiled_with_xpu() or core.is_compiled_with_npu(): + # TODO(liuyuhui): Currently only for xpu. Will be removed in the future. + scaled_loss = scale_loss(self) + core.eager.run_backward([scaled_loss], grad_tensor, + retain_graph) + else: + core.eager.run_backward([self], grad_tensor, retain_graph) + else: + raise ValueError( + "Variable.backward() is only available in DyGraph mode") + + @framework.dygraph_only + def gradient(self): + """ + .. 
warning:: + This API will be deprecated in the future, it is recommended to use + :code:`x.grad` which returns the tensor value of the gradient. + + Get the Gradient of Current Tensor. + + Returns: + ndarray: Numpy value of the gradient of current Tensor + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor(5., stop_gradient=False) + y = paddle.pow(x, 4.0) + y.backward() + print("grad of x: {}".format(x.gradient())) + # [500.] + + """ + if self.grad is None: + return None + # TODO(wanghuancoder) support SELECTED_ROWS + return self.grad.numpy() + if hasattr(core, "eager"): setattr(core.eager.EagerTensor, "__str__", __str__) + setattr(core.eager.EagerTensor, "backward", backward) + setattr(core.eager.EagerTensor, "gradient", gradient) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index dd83fc58e00..cf148257c5f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -84,6 +84,7 @@ core._disable_eager_mode() def _test_eager_guard(): core._enable_eager_mode() _C_ops.switch_to_eager_ops() + core._switch_tracer(_dygraph_tracer_) try: yield finally: @@ -920,6 +921,14 @@ def _varbase_creator(type=core.VarDesc.VarType.LOD_TENSOR, if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) + if _in_eager_mode(): + eager_tensor = core.eager.EagerTensor( + dtype if dtype else core.VarDesc.VarType.FP32, + list(shape) if shape else [], name, type + if type else core.VarDesc.VarType.LOD_TENSOR, True + if persistable else False) + eager_tensor.retain_grads() + return eager_tensor return core.VarBase(dtype if dtype else core.VarDesc.VarType.FP32, list(shape) if shape else [], name, type if type else core.VarDesc.VarType.LOD_TENSOR, True @@ -931,6 +940,8 @@ class VariableMetaClass(type): def __instancecheck__(cls, instance): t = type(instance) if in_dygraph_mode(): + if _in_eager_mode(): + return issubclass(t, core.eager.EagerTensor) return issubclass(t, core.VarBase) else: return issubclass(t, Variable) @@ -941,6 +952,8 @@ class ParameterMetaClass(VariableMetaClass): def __instancecheck__(cls, instance): t = type(instance) if in_dygraph_mode(): + if _in_eager_mode(): + return issubclass(t, EagerParamBase) return issubclass(t, ParamBase) else: return issubclass(t, Parameter) @@ -3244,7 +3257,10 @@ class Block(object): global_block = self.program.global_block() param = None if in_dygraph_mode(): - param = ParamBase(*args, **kwargs) + if _in_eager_mode(): + param = EagerParamBase(*args, **kwargs) + else: + param = ParamBase(*args, **kwargs) else: param = Parameter(global_block, *args, **kwargs) @@ -6243,6 +6259,153 @@ class ParamBase(core.VarBase): __repr__ = __str__ +if hasattr(core, "eager"): + _core_eager_eagertensor = core.eager.EagerTensor +else: + _core_eager_eagertensor = object + + +class EagerParamBase(_core_eager_eagertensor): + """ + EagerParamBase is derived from Tensor( Which is the concept in Eager-Dygraph Mode). + A EagerParamBase is a persistable Tensor, and will be updated by optimizers + after each iteration. + The training of a neural network is essentially the updating of + its EagerParamBase. + + Relative to a general Tensor, a EagerParamBase has several its own + member variables: + + Args: + trainable(bool): True if the EagerParamBase need to be updated after + iterations. + optimize_attr(map): EagerParamBase attributes related with optimizing. + Currently, it only contains 'learning_rate'. 
+ Default: {'learning_rate': 1.0} + regularizer(WeightDecayRegularizer): The Regularizer which will + be applied on the EagerParamBase. Default: None + do_model_average(bool): True if the model average strategy will + be applied on this EagerParamBase. + need_clip (bool): Whether the parameter gradient need to be cliped + in optimizer. Default is True. + """ + + @dygraph_only + def __init__(self, shape, dtype, **kwargs): + if shape is None: + raise ValueError("The shape of Parameter should not be None") + if dtype is None: + raise ValueError("The dtype of Parameter should not be None") + + if len(shape) == 0: + raise ValueError( + "The dimensions of shape for Parameter must be greater than 0") + + for each in shape: + if each < 0: + raise ValueError( + "Each dimension of shape for Parameter must be greater than 0, but received %s" + % list(shape)) + + if dtype is not None: + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + name = kwargs.get('name', unique_name.generate('_eager_param_base')) + + super(EagerParamBase, self).__init__( + dtype if dtype else core.VarDesc.VarType.FP32, + list(shape) + if shape else [], name, core.VarDesc.VarType.LOD_TENSOR, True) + self.retain_grads() + + trainable = kwargs.get('trainable', True) + self.stop_gradient = not trainable + + self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0}) + + self.regularizer = kwargs.get('regularizer', None) + + self.do_model_average = kwargs.get('do_model_average', None) + + self.need_clip = kwargs.get('need_clip', True) + + self.is_distributed = kwargs.get('is_distributed', False) + # self.block = default_main_program().global_block() + + @property + def trainable(self): + return not self.stop_gradient + + @trainable.setter + def trainable(self, trainable): + if isinstance(trainable, bool): + self.stop_gradient = not trainable + else: + raise ValueError( + "The type of trainable MUST be bool, but the type is ", + type(trainable)) + + def __str__(self): + """ + Convert a EagerParamBase object to a readable string. + + Returns(str): A readable string. + + Examples: + .. code-block:: python + + import paddle + linear = paddle.nn.Linear(3, 3) + print(linear.weight) + # Parameter containing: + # Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [[ 0.48948765, 0.05829060, -0.25524026], + # [-0.70368278, 0.52986908, -0.68742192], + # [-0.54217887, 0.48439729, 0.34082305]]) + """ + return "Parameter containing:\n{tensor}".format( + tensor=super(EagerParamBase, self).__str__()) + + def __deepcopy__(self, memo): + """ + Deep copy parameter, it will always performs Tensor copy. + + Examples: + .. 
code-block:: python + + import paddle + import copy + linear = paddle.nn.Linear(1, 3) + linear_copy = copy.deepcopy(linear) + + print(linear.weight) + # Parameter containing: + # Tensor(shape=[1, 3], dtype=float32, place=CPUPlace, stop_gradient=False, + # [[-0.30929261, -0.90929240, -1.07851017]]) + + print(linear_copy.weight) + # Parameter containing: + # Tensor(shape=[1, 3], dtype=float32, place=CPUPlace, stop_gradient=False, + # [[-0.30929261, -0.90929240, -1.07851017]]) + + """ + state = copy.deepcopy(self.__dict__, memo) + state["name"] = self.name + unique_name.generate("_deepcopy") + new_param = EagerParamBase(self.shape, self.dtype, **state) + memo[id(self)] = new_param + new_param.copy_(self, True) + return new_param + + def _copy_to(self, device, blocking): + state = copy.deepcopy(self.__dict__) + new_param = EagerParamBase(self.shape, self.dtype, **state) + core.eager.tensor_copy(self, new_param, device, blocking) + return new_param + + __repr__ = __str__ + + # program is a global instance. _main_program_ = Program() _startup_program_ = Program() diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index a7631848cd3..27e4ef6fe28 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -22,6 +22,7 @@ import numpy as np from .core import VarDesc from . import unique_name from .data_feeder import check_variable_and_dtype, check_type, check_dtype +from paddle import _C_ops __all__ = [ 'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear', @@ -132,7 +133,8 @@ class ConstantInitializer(Initializer): """ block = self._check_block(block) - assert isinstance(var, framework.Variable) + assert (isinstance(var, framework.Variable) or + isinstance(var, framework.EagerParamBase)) assert isinstance(block, framework.Block) # to be compatible of fp16 initializers @@ -149,30 +151,42 @@ class ConstantInitializer(Initializer): out_dtype = var.dtype out_var = var - # fill constant should set the "str_value" to preserve precision - op = block.append_op( - type="fill_constant", - outputs={"Out": out_var}, - attrs={ - "shape": var.shape, - "dtype": int(out_dtype), - "value": float(self._value), - 'str_value': str(float(self._value)), - 'force_cpu': self._force_cpu - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + out_var = _C_ops.fill_constant( + out_var, 'value', + float(self._value), 'force_cpu', self._force_cpu, 'dtype', + int(out_dtype), 'str_value', + str(float(self._value)), 'shape', var.shape) + if var.dtype == VarDesc.VarType.FP16: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None + else: + # fill constant should set the "str_value" to preserve precision + op = block.append_op( + type="fill_constant", + outputs={"Out": out_var}, + attrs={ + "shape": var.shape, + "dtype": int(out_dtype), + "value": float(self._value), + 'str_value': str(float(self._value)), + 'force_cpu': self._force_cpu + }, + stop_gradient=True) - if var.dtype == VarDesc.VarType.FP16: - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - - if not framework.in_dygraph_mode(): + if var.dtype == VarDesc.VarType.FP16: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) var.op = op - return op + return op class 
UniformInitializer(Initializer): @@ -257,33 +271,45 @@ class UniformInitializer(Initializer): out_dtype = var.dtype out_var = var - op = block.append_op( - type="uniform_random", - inputs={}, - outputs={"Out": out_var}, - attrs={ - "shape": var.shape, - "dtype": out_dtype, - "min": self._low, - "max": self._high, - "seed": self._seed, - "diag_num": self._diag_num, - "diag_step": self._diag_step, - "diag_val": self._diag_val - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + out_var = _C_ops.uniform_random( + 'shape', var.shape, 'min', self._low, 'max', self._high, 'seed', + self._seed, 'dtype', out_dtype, 'diag_num', self._diag_num, + 'diag_step', self._diag_step, 'diag_val', self._diag_val) + if var.dtype == VarDesc.VarType.FP16: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None + else: + op = block.append_op( + type="uniform_random", + inputs={}, + outputs={"Out": out_var}, + attrs={ + "shape": var.shape, + "dtype": out_dtype, + "min": self._low, + "max": self._high, + "seed": self._seed, + "diag_num": self._diag_num, + "diag_step": self._diag_step, + "diag_val": self._diag_val + }, + stop_gradient=True) + + if var.dtype == VarDesc.VarType.FP16: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) - if var.dtype == VarDesc.VarType.FP16: - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - - if not framework.in_dygraph_mode(): var.op = op - return op + return op class NormalInitializer(Initializer): @@ -349,29 +375,40 @@ class NormalInitializer(Initializer): out_dtype = var.dtype out_var = var - op = block.append_op( - type="gaussian_random", - outputs={"Out": out_var}, - attrs={ - "shape": var.shape, - "dtype": out_dtype, - "mean": self._mean, - "std": self._std_dev, - "seed": self._seed, - "use_mkldnn": False - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + out_var = _C_ops.gaussian_random( + 'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean, + 'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False) + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None + else: + op = block.append_op( + type="gaussian_random", + outputs={"Out": out_var}, + attrs={ + "shape": var.shape, + "dtype": out_dtype, + "mean": self._mean, + "std": self._std_dev, + "seed": self._seed, + "use_mkldnn": False + }, + stop_gradient=True) - if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - if not framework.in_dygraph_mode(): + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) var.op = op - return op + return op class TruncatedNormalInitializer(Initializer): @@ -433,28 +470,39 @@ class TruncatedNormalInitializer(Initializer): out_dtype = var.dtype out_var = var - op = block.append_op( - type="truncated_gaussian_random", - outputs={"Out": out_var}, - attrs={ - "shape": var.shape, - "dtype": out_dtype, - 
"mean": self._mean, - "std": self._std_dev, - "seed": self._seed - }, - stop_gradient=True) + if framework.in_dygraph_mode(): + out_var = _C_ops.truncated_gaussian_random( + 'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean, + 'std', self._std_dev, 'seed', self._seed) + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None + else: + op = block.append_op( + type="truncated_gaussian_random", + outputs={"Out": out_var}, + attrs={ + "shape": var.shape, + "dtype": out_dtype, + "mean": self._mean, + "std": self._std_dev, + "seed": self._seed + }, + stop_gradient=True) - if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - if not framework.in_dygraph_mode(): + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) var.op = op - return op + return op class XavierInitializer(Initializer): @@ -553,47 +601,66 @@ class XavierInitializer(Initializer): out_dtype = var.dtype out_var = var - if self._uniform: - limit = np.sqrt(6.0 / float(fan_in + fan_out)) - op = block.append_op( - type="uniform_random", - inputs={}, - outputs={"Out": out_var}, - attrs={ - "shape": out_var.shape, - "dtype": out_dtype, - "min": -limit, - "max": limit, - "seed": self._seed - }, - stop_gradient=True) - + if framework.in_dygraph_mode(): + if self._uniform: + limit = np.sqrt(6.0 / float(fan_in + fan_out)) + out_var = _C_ops.uniform_random('shape', var.shape, 'min', + -limit, 'max', limit, 'seed', + self._seed, 'dtype', out_dtype) + else: + std = np.sqrt(2.0 / float(fan_in + fan_out)) + out_var = _C_ops.gaussian_random( + 'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0, + 'std', std, 'seed', self._seed) + + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None else: - std = np.sqrt(2.0 / float(fan_in + fan_out)) - op = block.append_op( - type="gaussian_random", - outputs={"Out": out_var}, - attrs={ - "shape": out_var.shape, - "dtype": out_dtype, - "mean": 0.0, - "std": std, - "seed": self._seed - }, - stop_gradient=True) + if self._uniform: + limit = np.sqrt(6.0 / float(fan_in + fan_out)) + op = block.append_op( + type="uniform_random", + inputs={}, + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": out_dtype, + "min": -limit, + "max": limit, + "seed": self._seed + }, + stop_gradient=True) + else: + std = np.sqrt(2.0 / float(fan_in + fan_out)) + op = block.append_op( + type="gaussian_random", + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": out_dtype, + "mean": 0.0, + "std": std, + "seed": self._seed + }, + stop_gradient=True) + + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) - if var.dtype == VarDesc.VarType.FP16 or ( - var.dtype == VarDesc.VarType.BF16 and not self._uniform): - 
block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - - if not framework.in_dygraph_mode(): var.op = op - return op + return op class MSRAInitializer(Initializer): @@ -686,47 +753,68 @@ class MSRAInitializer(Initializer): out_dtype = var.dtype out_var = var - if self._uniform: - limit = np.sqrt(6.0 / float(fan_in)) - op = block.append_op( - type="uniform_random", - inputs={}, - outputs={"Out": out_var}, - attrs={ - "shape": out_var.shape, - "dtype": int(out_dtype), - "min": -limit, - "max": limit, - "seed": self._seed - }, - stop_gradient=True) - + if framework.in_dygraph_mode(): + if self._uniform: + limit = np.sqrt(6.0 / float(fan_in)) + out_var = _C_ops.uniform_random('shape', out_var.shape, 'min', + -limit, 'max', limit, 'seed', + self._seed, 'dtype', + int(out_dtype)) + else: + std = np.sqrt(2.0 / float(fan_in)) + out_var = _C_ops.gaussian_random( + 'shape', out_var.shape, 'dtype', + int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed) + + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None else: - std = np.sqrt(2.0 / float(fan_in)) - op = block.append_op( - type="gaussian_random", - outputs={"Out": out_var}, - attrs={ - "shape": out_var.shape, - "dtype": int(out_dtype), - "mean": 0.0, - "std": std, - "seed": self._seed - }, - stop_gradient=True) + if self._uniform: + limit = np.sqrt(6.0 / float(fan_in)) + op = block.append_op( + type="uniform_random", + inputs={}, + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": int(out_dtype), + "min": -limit, + "max": limit, + "seed": self._seed + }, + stop_gradient=True) + + else: + std = np.sqrt(2.0 / float(fan_in)) + op = block.append_op( + type="gaussian_random", + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": int(out_dtype), + "mean": 0.0, + "std": std, + "seed": self._seed + }, + stop_gradient=True) + + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) - if var.dtype == VarDesc.VarType.FP16 or ( - var.dtype == VarDesc.VarType.BF16 and not self._uniform): - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - - if not framework.in_dygraph_mode(): var.op = op - return op + return op class BilinearInitializer(Initializer): @@ -839,28 +927,44 @@ class BilinearInitializer(Initializer): if np.prod(shape) > 1024 * 1024: raise ValueError("The size of input is too big. 
") - op = block.append_op( - type='assign_value', - outputs={'Out': [out_var]}, - attrs={ - 'dtype': out_dtype, - 'shape': list(shape), - value_name: values - }) - if var.dtype in [ - VarDesc.VarType.FP16, VarDesc.VarType.BF16, VarDesc.VarType.FP64 - ]: - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - - if not framework.in_dygraph_mode(): + if framework.in_dygraph_mode(): + out_var = _C_ops.assign_value('shape', + list(shape), 'dtype', out_dtype, + value_name, values) + if var.dtype in [ + VarDesc.VarType.FP16, VarDesc.VarType.BF16, + VarDesc.VarType.FP64 + ]: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None + else: + op = block.append_op( + type='assign_value', + outputs={'Out': [out_var]}, + attrs={ + 'dtype': out_dtype, + 'shape': list(shape), + value_name: values + }) + + if var.dtype in [ + VarDesc.VarType.FP16, VarDesc.VarType.BF16, + VarDesc.VarType.FP64 + ]: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) + var.op = op - return op + return op class NumpyArrayInitializer(Initializer): @@ -932,27 +1036,39 @@ class NumpyArrayInitializer(Initializer): if self._value.size > 1024 * 1024 * 1024: raise ValueError("The size of input is too big. Please consider " "saving it to file and 'load_op' to load it") - op = block.append_op( - type='assign_value', - outputs={'Out': out_var}, - attrs={ - 'dtype': out_dtype, - 'shape': list(self._value.shape), - value_name: values - }, - stop_gradient=True) - if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: - block.append_op( - type="cast", - inputs={"X": out_var}, - outputs={"Out": var}, - attrs={"in_dtype": out_var.dtype, - "out_dtype": var.dtype}) - - if not framework.in_dygraph_mode(): + if framework.in_dygraph_mode(): + out_var = _C_ops.assign_value('shape', + list(self._value.shape), 'dtype', + out_dtype, value_name, values) + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var.copy_(var_tmp, True) + else: + var.copy_(out_var, True) + return None + else: + op = block.append_op( + type='assign_value', + outputs={'Out': out_var}, + attrs={ + 'dtype': out_dtype, + 'shape': list(self._value.shape), + value_name: values + }, + stop_gradient=True) + + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) + var.op = op - return op + return op def set_global_initializer(weight_init, bias_init=None): diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index 2b677c11e9d..72cdd1f9ad5 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -24,6 +24,7 @@ from .param_attr import ParamAttr from . 
import core from six.moves import zip from .layer_helper_base import LayerHelperBase +from .dygraph_utils import _append_activation_in_dygraph class LayerHelper(LayerHelperBase): @@ -145,21 +146,27 @@ class LayerHelper(LayerHelperBase): else: raise TypeError(str(act) + " should be unicode or str") + use_cudnn = None if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'): - act['use_cudnn'] = self.kwargs.get('use_cudnn') + use_cudnn = self.kwargs.get('use_cudnn') + act['use_cudnn'] = use_cudnn use_mkldnn = self.kwargs.get( 'use_mkldnn', _global_flags().get("FLAGS_use_mkldnn", False)) if use_mkldnn: act['use_mkldnn'] = use_mkldnn act_type = act.pop('type') - - tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) - self.append_op( - type=act_type, - inputs={"X": [input_var]}, - outputs={"Out": [tmp]}, - attrs=act) - return tmp + if in_dygraph_mode(): + res = _append_activation_in_dygraph(input_var, act_type, use_cudnn, + use_mkldnn) + return res + else: + tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) + self.append_op( + type=act_type, + inputs={"X": [input_var]}, + outputs={"Out": [tmp]}, + attrs=act) + return tmp #TODO (jiabin): should we remove this since it has never be used def _get_default_initializer(self, dtype): diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py index c2de5670eb4..67fcd901ded 100644 --- a/python/paddle/fluid/layer_helper_base.py +++ b/python/paddle/fluid/layer_helper_base.py @@ -17,7 +17,7 @@ from __future__ import print_function import copy import numpy as np -from .framework import Variable, default_main_program, default_startup_program, in_dygraph_mode, _current_expected_place +from .framework import Variable, default_main_program, default_startup_program, in_dygraph_mode, _current_expected_place, _in_eager_mode from . import unique_name from .param_attr import ParamAttr, WeightNormParamAttr from . 
import core @@ -84,13 +84,19 @@ class LayerHelperBase(object): if isinstance(value, np.ndarray): assert in_dygraph_mode( ), "to_variable could only be called in dygraph mode" - py_var = core.VarBase( - value=value, - name=name if name else '', - persistable=False, - place=_current_expected_place(), - zero_copy=False) - return py_var + if _in_eager_mode(): + return core.eager.EagerTensor(value, + _current_expected_place(), False, + False, name + if name else None, True) + else: + py_var = core.VarBase( + value=value, + name=name if name else '', + persistable=False, + place=_current_expected_place(), + zero_copy=False) + return py_var elif isinstance(value, (core.VarBase, Variable)): return value else: diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index 08a68ca246f..803631a4d2c 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -16,9 +16,10 @@ import paddle.fluid.core as core import paddle.fluid.eager.eager_tensor_patch_methods as eager_tensor_patch_methods import paddle import numpy as np -from paddle.fluid.framework import _test_eager_guard +from paddle.fluid.framework import _test_eager_guard, EagerParamBase from paddle.fluid.data_feeder import convert_dtype import unittest +import copy class EagerScaleTestCase(unittest.TestCase): @@ -46,14 +47,42 @@ class EagerScaleTestCase(unittest.TestCase): grad_data = np.ones([4, 16, 16, 32]).astype('float32') grad_eager = paddle.to_tensor(grad_data, 'float32', core.CPUPlace()) - core.eager.retain_grad_for_tensor(data_eager) + data_eager.retain_grads() out_eager = core.eager.scale(data_eager, 1.0, 0.9, True, True) self.assertFalse(data_eager.grad._is_initialized()) - core.eager.run_backward([out_eager], [grad_eager], False) + out_eager.backward(grad_eager, False) self.assertTrue(data_eager.grad._is_initialized()) self.assertTrue(np.array_equal(data_eager.grad.numpy(), input_data)) + def test_retain_grad_and_run_backward_raises(self): + with _test_eager_guard(): + paddle.set_device("cpu") + + input_data = np.ones([4, 16, 16, 32]).astype('float32') + data_eager = paddle.to_tensor(input_data, 'float32', + core.CPUPlace(), False) + + grad_data = np.ones([4, 16, 16, 32]).astype('float32') + grad_data2 = np.ones([4, 16]).astype('float32') + grad_eager = paddle.to_tensor(grad_data, 'float32', core.CPUPlace()) + grad_eager2 = paddle.to_tensor(grad_data2, 'float32', + core.CPUPlace()) + + data_eager.retain_grads() + + out_eager = core.eager.scale(data_eager, 1.0, 0.9, True, True) + self.assertFalse(data_eager.grad._is_initialized()) + with self.assertRaisesRegexp( + AssertionError, + "The type of grad_tensor must be paddle.Tensor"): + out_eager.backward(grad_data, False) + + with self.assertRaisesRegexp( + AssertionError, + "Tensor shape not match, Tensor of grad_tensor /*"): + out_eager.backward(grad_eager2, False) + class EagerDtypeTestCase(unittest.TestCase): def check_to_tesnsor_and_numpy(self, dtype, proto_dtype): @@ -192,6 +221,34 @@ class EagerTensorPropertiesTestCase(unittest.TestCase): self.assertTrue(egr_tensor9.place._equals(place)) self.assertTrue(np.array_equal(egr_tensor9.numpy(), arr4)) + with self.assertRaisesRegexp( + ValueError, "The shape of Parameter should not be None"): + eager_param = EagerParamBase(shape=None, dtype="float32") + + with self.assertRaisesRegexp( + ValueError, "The dtype of Parameter should not be None"): + eager_param = EagerParamBase(shape=[1, 
1], dtype=None) + + with self.assertRaisesRegexp( + ValueError, + "The dimensions of shape for Parameter must be greater than 0"): + eager_param = EagerParamBase(shape=[], dtype="float32") + + with self.assertRaisesRegexp( + ValueError, + "Each dimension of shape for Parameter must be greater than 0, but received /*" + ): + eager_param = EagerParamBase(shape=[-1], dtype="float32") + + eager_param = EagerParamBase(shape=[1, 1], dtype="float32") + self.assertTrue(eager_param.trainable) + eager_param.trainable = False + self.assertFalse(eager_param.trainable) + with self.assertRaisesRegexp( + ValueError, + "The type of trainable MUST be bool, but the type is /*"): + eager_param.trainable = "False" + def test_constructor(self): print("Test_constructor") paddle.set_device("cpu") @@ -291,5 +348,80 @@ class EagerTensorPropertiesTestCase(unittest.TestCase): core._disable_eager_mode() +class EagerParamBaseUsageTestCase(unittest.TestCase): + def test_print(self): + with _test_eager_guard(): + linear = paddle.nn.Linear(3, 3, bias_attr=False) + print(linear.weight) + + def test_copy(self): + with _test_eager_guard(): + linear = paddle.nn.Linear(1, 3) + linear_copy = copy.deepcopy(linear) + linear_copy2 = linear.weight._copy_to(core.CPUPlace(), True) + self.assertTrue( + np.array_equal(linear.weight.numpy(), + linear_copy.weight.numpy())) + self.assertTrue( + np.array_equal(linear.weight.numpy(), linear_copy2.numpy())) + + def func_fp16_initilaizer(self): + paddle.set_default_dtype("float16") + linear1 = paddle.nn.Linear(1, 3, bias_attr=False) + linear2 = paddle.nn.Linear( + 1, + 3, + bias_attr=False, + weight_attr=paddle.fluid.initializer.Uniform()) + linear3 = paddle.nn.Linear( + 1, + 3, + bias_attr=False, + weight_attr=paddle.fluid.initializer.TruncatedNormalInitializer()) + linear4 = paddle.nn.Linear( + 1, + 3, + bias_attr=False, + weight_attr=paddle.fluid.initializer.MSRAInitializer()) + res = [ + linear1.weight.numpy(), linear2.weight.numpy(), + linear3.weight.numpy(), linear4.weight.numpy() + ] + paddle.set_default_dtype("float32") + return res + + def test_fp16_initializer(self): + res1 = list() + res2 = list() + paddle.seed(102) + paddle.framework.random._manual_program_seed(102) + with _test_eager_guard(): + res1 = self.func_fp16_initilaizer() + res2 = self.func_fp16_initilaizer() + + for i in range(len(res1)): + self.assertTrue(np.array_equal(res1[i], res2[i])) + + def func_layer_helper_base(self, value): + base = paddle.fluid.layer_helper_base.LayerHelperBase("test_layer", + "test_layer") + return base.to_variable(value).numpy() + + def func_base_to_variable(self, value): + paddle.fluid.dygraph.base.to_variable(value) + + def test_to_variable(self): + value = np.random.rand(4, 16, 16, 32).astype('float32') + res1 = None + res3 = None + with _test_eager_guard(): + res1 = self.func_layer_helper_base(value) + res3 = self.func_base_to_variable(value) + res2 = self.func_layer_helper_base(value) + res4 = self.func_base_to_variable(value) + self.assertTrue(np.array_equal(res1, res2)) + self.assertTrue(np.array_equal(res3, res4)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 8f8ecca105e..c8836cd7767 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -24,6 +24,7 @@ from test_imperative_base import new_program_scope import paddle.fluid.dygraph_utils as dygraph_utils 
from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper import paddle +from paddle.fluid.framework import _test_eager_guard class MyLayer(fluid.Layer): @@ -180,12 +181,12 @@ class SimpleRNN(fluid.Layer): class TestImperative(unittest.TestCase): - def test_functional_dygraph_context(self): + def functional_dygraph_context(self): self.assertFalse(fluid.dygraph.enabled()) fluid.enable_dygraph() self.assertTrue(fluid.dygraph.enabled()) np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) - var_inp = fluid.dygraph.base.to_variable(np_inp) + var_inp = paddle.to_tensor(np_inp) mlp = MLP(input_size=2) out = mlp(var_inp) dy_out1 = out.numpy() @@ -195,7 +196,7 @@ class TestImperative(unittest.TestCase): self.assertFalse(fluid.dygraph.enabled()) with fluid.dygraph.guard(): self.assertTrue(fluid.dygraph.enabled()) - var_inp = fluid.dygraph.base.to_variable(np_inp) + var_inp = paddle.to_tensor(np_inp) mlp = MLP(input_size=2) out = mlp(var_inp) dy_out2 = out.numpy() @@ -205,7 +206,12 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.array_equal(dy_out1, dy_out2)) self.assertTrue(np.array_equal(dy_grad1, dy_grad2)) - def test_functional_paddle_imperative_dygraph_context(self): + def test_functional_dygraph_context(self): + with _test_eager_guard(): + self.functional_dygraph_context() + self.functional_dygraph_context() + + def functional_paddle_imperative_dygraph_context(self): self.assertFalse(paddle.in_dynamic_mode()) paddle.disable_static() self.assertTrue(paddle.in_dynamic_mode()) @@ -231,13 +237,27 @@ class TestImperative(unittest.TestCase): self.assertTrue(np.array_equal(dy_out1, dy_out2)) self.assertTrue(np.array_equal(dy_grad1, dy_grad2)) - def test_isinstance(self): + def test_functional_paddle_imperative_dygraph_context(self): + with _test_eager_guard(): + self.functional_paddle_imperative_dygraph_context() + self.functional_paddle_imperative_dygraph_context() + + def func_isinstance(self): var = fluid.layers.data(shape=[1], name='x', dtype='float32') self.assertTrue(isinstance(var, fluid.Variable)) with fluid.dygraph.guard(): - var_base = fluid.dygraph.base.to_variable(np.array([3, 4, 5])) - self.assertTrue(isinstance(var_base, core.VarBase)) - self.assertTrue(isinstance(var_base, fluid.Variable)) + if fluid.framework._in_eager_mode(): + var_base = paddle.to_tensor(np.array([3, 4, 5])) + self.assertTrue(isinstance(var_base, core.eager.EagerTensor)) + else: + var_base = paddle.to_tensor(np.array([3, 4, 5])) + self.assertTrue(isinstance(var_base, core.VarBase)) + self.assertTrue(isinstance(var_base, fluid.Variable)) + + def test_isinstance(self): + with _test_eager_guard(): + self.func_isinstance() + self.func_isinstance() def test_create_VarBase(self): x = np.ones([2, 2], np.float32) @@ -247,7 +267,7 @@ class TestImperative(unittest.TestCase): with fluid.dygraph.guard(): tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace()) tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace()) - tmp3 = fluid.dygraph.base.to_variable(x) + tmp3 = paddle.to_tensor(x) tmp4 = fluid.core.VarBase(y) tmp5 = fluid.core.VarBase(value=x) tmp6 = fluid.core.VarBase(t) @@ -269,7 +289,7 @@ class TestImperative(unittest.TestCase): self.assertTrue(l1.weight.stop_gradient is False) tmp = l1.weight * 2 self.assertTrue(tmp.stop_gradient) - x = fluid.dygraph.to_variable(data) + x = paddle.to_tensor(data) y = l0(x) + tmp o = l1(y) o.backward() @@ -287,7 +307,7 @@ class TestImperative(unittest.TestCase): self.assertTrue(l1.weight.stop_gradient is False) tmp = l1.weight * 2 
self.assertTrue(tmp.stop_gradient) - x = fluid.dygraph.to_variable(data) + x = paddle.to_tensor(data) y = l0(x) + tmp o = l1(y) o.backward() @@ -308,7 +328,7 @@ class TestImperative(unittest.TestCase): tmp2 = l1.weight * 2 self.assertTrue(tmp.stop_gradient) self.assertTrue(tmp2.stop_gradient is False) - x = fluid.dygraph.to_variable(data) + x = paddle.to_tensor(data) y = l0(x) + tmp2 o = l1(y) o.backward() @@ -329,7 +349,7 @@ class TestImperative(unittest.TestCase): with fluid.dygraph.guard(): inputs = [] for _ in range(10): - tmp = fluid.dygraph.base.to_variable(x) + tmp = paddle.to_tensor(x) tmp.stop_gradient = False inputs.append(tmp) ret = fluid.layers.sums(inputs) @@ -338,7 +358,7 @@ class TestImperative(unittest.TestCase): with fluid.dygraph.guard(): inputs2 = [] for _ in range(10): - tmp = fluid.dygraph.base.to_variable(x) + tmp = paddle.to_tensor(x) tmp.stop_gradient = False inputs2.append(tmp) ret2 = fluid.layers.sums(inputs2) @@ -376,7 +396,7 @@ class TestImperative(unittest.TestCase): def test_empty_grad(self): with fluid.dygraph.guard(): x = np.ones([2, 2], np.float32) - new_var = fluid.dygraph.base.to_variable(x) + new_var = paddle.to_tensor(x) try: new_var.gradient() except Exception as e: @@ -400,7 +420,7 @@ class TestImperative(unittest.TestCase): def test_set_persistable(self): with fluid.dygraph.guard(): x = np.ones([2, 2], np.float32) - new_var = fluid.dygraph.base.to_variable(x) + new_var = paddle.to_tensor(x) self.assertFalse(new_var.persistable) new_var.persistable = True self.assertTrue(new_var.persistable) @@ -413,7 +433,7 @@ class TestImperative(unittest.TestCase): def test_layer_in_out(self): np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32) with fluid.dygraph.guard(): - var_inp = fluid.dygraph.base.to_variable(np_inp) + var_inp = paddle.to_tensor(np_inp) var_inp.stop_gradient = False l = MyLayer() x = l(var_inp)[0] @@ -423,7 +443,7 @@ class TestImperative(unittest.TestCase): dy_grad = l._x_for_debug.gradient() with fluid.dygraph.guard(): - var_inp2 = fluid.dygraph.base.to_variable(np_inp) + var_inp2 = paddle.to_tensor(np_inp) var_inp2.stop_gradient = False l2 = MyLayer() x2 = l2(var_inp2)[0] @@ -455,7 +475,7 @@ class TestImperative(unittest.TestCase): def test_mlp(self): np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) with fluid.dygraph.guard(): - var_inp = fluid.dygraph.base.to_variable(np_inp) + var_inp = paddle.to_tensor(np_inp) mlp = MLP(input_size=2) out = mlp(var_inp) dy_out = out.numpy() @@ -463,7 +483,7 @@ class TestImperative(unittest.TestCase): dy_grad = mlp._linear1.weight.gradient() with fluid.dygraph.guard(): - var_inp2 = fluid.dygraph.base.to_variable(np_inp) + var_inp2 = paddle.to_tensor(np_inp) mlp2 = MLP(input_size=2) out2 = mlp2(var_inp2) dy_out2 = out2.numpy() @@ -641,8 +661,8 @@ class TestImperative(unittest.TestCase): # dynamic graph with fluid.dygraph.guard(): - inp1 = fluid.dygraph.to_variable(np_inp1) - inp2 = fluid.dygraph.to_variable(np_inp2) + inp1 = paddle.to_tensor(np_inp1) + inp2 = paddle.to_tensor(np_inp2) if np.sum(np_inp1) < np.sum(np_inp2): x = fluid.layers.elementwise_add(inp1, inp2) else: @@ -692,7 +712,7 @@ class TestImperative(unittest.TestCase): np_inp = np_inp.reshape((1, 4, 3)) np_inp = np_inp.astype(np.float32) with fluid.dygraph.guard(): - var_inp = fluid.dygraph.base.to_variable(np_inp) + var_inp = paddle.to_tensor(np_inp) var_inp = fluid.layers.reshape(var_inp, shape=[1, 4, 3]) simple_rnn = SimpleRNN() outs, pre_hiddens = simple_rnn.forward(var_inp) @@ -703,7 +723,7 @@ class 
TestImperative(unittest.TestCase): dy_grad_i2h = simple_rnn._cell._i2h_w.gradient() with fluid.dygraph.guard(): - var_inp2 = fluid.dygraph.base.to_variable(np_inp) + var_inp2 = paddle.to_tensor(np_inp) var_inp2 = fluid.layers.reshape(var_inp2, shape=[1, 4, 3]) simple_rnn2 = SimpleRNN() outs2, pre_hiddens2 = simple_rnn2.forward(var_inp2) @@ -760,58 +780,83 @@ class TestImperative(unittest.TestCase): class TestDygraphUtils(unittest.TestCase): - def test_append_activation_in_dygraph_exception(self): + def func_append_activation_in_dygraph_exception(self): with new_program_scope(): np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32) a = fluid.layers.data("a", [10, 20]) func = dygraph_utils._append_activation_in_dygraph self.assertRaises(AssertionError, func, a, act="sigmoid") - def test_append_activation_in_dygraph1(self): + def test_append_activation_in_dygraph_exception(self): + with _test_eager_guard(): + self.func_append_activation_in_dygraph_exception() + self.func_append_activation_in_dygraph_exception() + + def func_append_activation_in_dygraph1(self): a_np = np.random.random(size=(10, 20, 30)).astype(np.float32) func = dygraph_utils._append_activation_in_dygraph with fluid.dygraph.guard(): - a = fluid.dygraph.to_variable(a_np) + a = paddle.to_tensor(a_np) res1 = func(a, act="hard_sigmoid") res2 = fluid.layers.hard_sigmoid(a) self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) - def test_append_activation_in_dygraph2(self): + def test_append_activation_in_dygraph1(self): + with _test_eager_guard(): + self.func_append_activation_in_dygraph1() + self.func_append_activation_in_dygraph1() + + def func_append_activation_in_dygraph2(self): a_np = np.random.random(size=(10, 20, 30)).astype(np.float32) func = dygraph_utils._append_activation_in_dygraph with fluid.dygraph.guard(): - a = fluid.dygraph.to_variable(a_np) + a = paddle.to_tensor(a_np) res1 = func(a, act="sigmoid", use_mkldnn=True, use_cudnn=True) res2 = fluid.layers.sigmoid(a) self.assertTrue(np.allclose(res1.numpy(), res2.numpy())) - def test_append_activation_in_dygraph3(self): + def test_append_activation_in_dygraph2(self): + with _test_eager_guard(): + self.func_append_activation_in_dygraph2() + self.func_append_activation_in_dygraph2() + + def func_append_activation_in_dygraph3(self): a_np = np.random.random(size=(10, 20, 30)).astype(np.float32) helper = LayerObjectHelper(fluid.unique_name.generate("test")) func = helper.append_activation with fluid.dygraph.guard(): - a = fluid.dygraph.to_variable(a_np) + a = paddle.to_tensor(a_np) res1 = func(a, act="sigmoid", use_cudnn=True) res2 = fluid.layers.sigmoid(a) self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) - def test_append_activation_in_dygraph_use_mkldnn(self): + def test_append_activation_in_dygraph3(self): + with _test_eager_guard(): + self.func_append_activation_in_dygraph3() + self.func_append_activation_in_dygraph3() + + def func_append_activation_in_dygraph_use_mkldnn(self): a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32) helper = LayerHelper( fluid.unique_name.generate("test"), act="relu", use_mkldnn=True) func = helper.append_activation with fluid.dygraph.guard(): - a = fluid.dygraph.to_variable(a_np) + a = paddle.to_tensor(a_np) res1 = func(a) res2 = fluid.layers.relu(a) self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) - def test_append_activation_in_dygraph_global_use_mkldnn(self): + def test_append_activation_in_dygraph_use_mkldnn(self): + with _test_eager_guard(): + 
self.func_append_activation_in_dygraph_use_mkldnn() + self.func_append_activation_in_dygraph_use_mkldnn() + + def func_append_activation_in_dygraph_global_use_mkldnn(self): a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32) helper = LayerHelper(fluid.unique_name.generate("test"), act="relu") func = helper.append_activation with fluid.dygraph.guard(fluid.core.CPUPlace()): - a = fluid.dygraph.to_variable(a_np) + a = paddle.to_tensor(a_np) fluid.set_flags({'FLAGS_use_mkldnn': True}) try: res1 = func(a) @@ -820,38 +865,67 @@ class TestDygraphUtils(unittest.TestCase): res2 = fluid.layers.relu(a) self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) - def test_append_bias_in_dygraph_exception(self): + def test_append_activation_in_dygraph_global_use_mkldnn(self): + with _test_eager_guard(): + self.func_append_activation_in_dygraph_global_use_mkldnn() + self.func_append_activation_in_dygraph_global_use_mkldnn() + + def func_append_bias_in_dygraph_exception(self): with new_program_scope(): np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32) a = fluid.layers.data("a", [10, 20]) func = dygraph_utils._append_bias_in_dygraph self.assertRaises(AssertionError, func, a) - def test_append_bias_in_dygraph(self): + def test_append_bias_in_dygraph_exception(self): + with _test_eager_guard(): + self.func_append_bias_in_dygraph_exception() + self.func_append_bias_in_dygraph_exception() + + def func_append_bias_in_dygraph(self): a_np = np.random.random(size=(10, 20, 30)).astype(np.float32) func = dygraph_utils._append_bias_in_dygraph with fluid.dygraph.guard(): - a = fluid.dygraph.to_variable(a_np) + a = paddle.to_tensor(a_np) res1 = func(a, bias=a) - res2 = a + a + res2 = paddle.add(a, a) self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) + def test_append_bias_in_dygraph(self): + with _test_eager_guard(): + self.func_append_bias_in_dygraph() + self.func_append_bias_in_dygraph() + class TestDygraphGuardWithError(unittest.TestCase): - def test_without_guard(self): + def func_without_guard(self): with fluid.dygraph.guard(): - x = fluid.dygraph.to_variable(np.zeros([10, 10])) + x = paddle.to_tensor(np.zeros([10, 10])) with self.assertRaisesRegexp(TypeError, "Please use `with fluid.dygraph.guard()"): y = fluid.layers.matmul(x, x) + def test_without_guard(self): + with _test_eager_guard(): + self.func_without_guard() + self.func_without_guard() + class TestMetaclass(unittest.TestCase): - def test_metaclass(self): + def func_metaclass(self): self.assertEqual(type(MyLayer).__name__, 'type') self.assertNotEqual(type(MyLayer).__name__, 'pybind11_type') - self.assertEqual( - type(paddle.fluid.core.VarBase).__name__, 'pybind11_type') + if core._in_eager_mode(): + self.assertEqual( + type(paddle.fluid.core.eager.EagerTensor).__name__, 'type') + else: + self.assertEqual( + type(paddle.fluid.core.VarBase).__name__, 'pybind11_type') + + def test_metaclass(self): + with _test_eager_guard(): + self.func_metaclass() + self.func_metaclass() if __name__ == '__main__': diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index facec0975b6..cd1faf64f3e 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -117,12 +117,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): ) != _current_expected_place()._get_device_id(): place = _current_expected_place() - if _in_eager_mode(): - if dtype is None: - dtype = paddle.get_default_dtype() - return core.eager.to_tensor(data, - convert_dtype(dtype), place, stop_gradient) - 
     if not isinstance(data, np.ndarray):
 
         def _handle_dtype(data, dtype):
@@ -172,12 +166,17 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
         if dtype and convert_dtype(dtype) != data.dtype:
             data = data.astype(convert_dtype(dtype))
 
-    return paddle.Tensor(
-        value=data,
-        place=place,
-        persistable=False,
-        zero_copy=False,
-        stop_gradient=stop_gradient)
+    # TODO(jiabin): Support kwargs in eager tensor constructor
+    if _in_eager_mode() and isinstance(data, np.ndarray):
+        return core.eager.EagerTensor(data, place, False, False, None,
+                                      stop_gradient)
+    else:
+        return paddle.Tensor(
+            value=data,
+            place=place,
+            persistable=False,
+            zero_copy=False,
+            stop_gradient=stop_gradient)
 
 
 def full_like(x, fill_value, dtype=None, name=None):
--
GitLab
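
Reviewer note: the test changes in this patch all follow one pattern: each original test_* method body is moved into a func_* helper that runs once inside _test_eager_guard() (the new core.eager.EagerTensor path) and once outside it (the legacy core.VarBase path). The sketch below is illustrative only and is not part of the patch; the ExamplePatternTest class and func_example name are hypothetical, while paddle.to_tensor, paddle.add and _test_eager_guard are the APIs exercised by the tests above.

    # Illustrative only -- mirrors the func_*/test_* split used in
    # test_imperative_basic.py and test_egr_python_api.py.
    import unittest

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard


    class ExamplePatternTest(unittest.TestCase):
        def func_example(self):
            # Shared body: build a tensor and run a simple op in whichever
            # dygraph mode is currently active.
            x = paddle.to_tensor(np.ones([2, 2], dtype='float32'))
            y = paddle.add(x, x)
            self.assertTrue(
                np.array_equal(y.numpy(), np.full([2, 2], 2.0, 'float32')))

        def test_example(self):
            # Run once under the eager guard (EagerTensor path), then once
            # more in legacy dygraph mode (VarBase path).
            with _test_eager_guard():
                self.func_example()
            self.func_example()


    if __name__ == '__main__':
        unittest.main()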