From 13d757362c6ba045bb2dace130175e5f9a90870f Mon Sep 17 00:00:00 2001 From: pangyoki Date: Fri, 15 Jan 2021 17:28:19 +0800 Subject: [PATCH] Add Inplace strategy (Output reuse Input Varbase) in dygraph (#30103) * add view strategy on squeeze,unsqueeze,reshape,flatten * add squeeze unittest * add unittests * use View strategy as name rather than Reuse Allacation * fix view api doc * fix format * use core.ops when input of reshape2 is Tensor * fix test_cross_entropy_loss error because of reshape2 * fix test_cross_entropy_loss error because of reshape2 * add inplace strategy * add elementwise_add sub * let backward op not use inplace * grad op do not use inplace * fix memory increase error and add leaf error message * delete selected_rows * change op_function * little change * solve HandleViewBetweenInputAndOutput * add unittest and leaf error message * merge view error * optimize op_function_generator format and support sum inplace op * fix format of basic_engine * fix format for framework * little change of variable wrapper * add reshape, squeeze, unsqueeze, scatter api * add relu elu tanh softmax inplace api * fix test_squeeze_op unittest * fix test_relu_op unittest * fix comment problems * delete sample code of inplace api * add reference of grad_pending_nodes in basic_engine * fix unittest name * add inplace apis into wlist * fix error message * add PADDLE_ENFORCE for set grad op twice * fix head file error --- paddle/fluid/framework/details/op_registry.h | 6 +- paddle/fluid/framework/grad_op_desc_maker.h | 4 + paddle/fluid/framework/type_defs.h | 3 +- paddle/fluid/imperative/basic_engine.cc | 191 ++++++-- paddle/fluid/imperative/basic_engine.h | 20 +- paddle/fluid/imperative/dygraph_grad_maker.h | 34 +- paddle/fluid/imperative/layer.cc | 6 +- paddle/fluid/imperative/layer.h | 3 +- paddle/fluid/imperative/op_base.h | 21 + .../fluid/imperative/partial_grad_engine.cc | 2 +- paddle/fluid/imperative/tracer.cc | 13 +- paddle/fluid/imperative/tracer.h | 7 +- paddle/fluid/imperative/variable_wrapper.h | 10 +- paddle/fluid/pybind/op_function_generator.cc | 413 +++++++++++------- python/paddle/__init__.py | 5 + .../tests/unittests/test_activation_op.py | 64 ++- .../fluid/tests/unittests/test_inplace.py | 201 +++++++++ .../fluid/tests/unittests/test_reshape_op.py | 48 +- .../fluid/tests/unittests/test_scatter_op.py | 13 +- .../fluid/tests/unittests/test_softmax_op.py | 25 +- .../fluid/tests/unittests/test_squeeze_op.py | 34 +- .../tests/unittests/test_unsqueeze2_op.py | 24 +- .../tests/unittests/test_unsqueeze_op.py | 21 +- python/paddle/nn/functional/__init__.py | 4 + python/paddle/nn/functional/activation.py | 50 +++ python/paddle/tensor/__init__.py | 5 + python/paddle/tensor/manipulation.py | 87 ++++ python/paddle/tensor/math.py | 13 + tools/wlist.json | 32 ++ 29 files changed, 1102 insertions(+), 257 deletions(-) diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 453a25166b5..df5370e42ee 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -14,6 +14,7 @@ limitations under the License. 
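A minimal usage sketch of the behaviour this patch adds (assuming a Paddle build that includes it): view ops such as paddle.squeeze return a new VarBase that reuses the input's allocation, while the new trailing-underscore inplace APIs (paddle.squeeze_, paddle.reshape_, paddle.scatter_, paddle.tanh_, paddle.nn.functional.relu_, ...) return the input VarBase itself and bump its inplace_version.

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.random.rand(2, 3, 1).astype("float32"))

    y = paddle.squeeze(x)        # view op: new VarBase that shares x's allocation
    z = paddle.squeeze_(x)       # inplace op: z is the very same VarBase as x

    print(id(z) == id(x))        # True
    print(x.inplace_version)     # 1, bumped by the inplace call
    z[0] = 2.                    # the write is visible through x as well
    print(np.array_equal(x.numpy(), z.numpy()))  # True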
*/ #pragma once +#include #include #include #include @@ -247,8 +248,9 @@ struct OpInfoFiller { const std::string& type, const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, - const framework::AttributeMap& attrs) { - T maker(type, var_base_map_in, var_base_map_out, attrs); + const framework::AttributeMap& attrs, + const std::map& inplace_map) { + T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map); return maker(); }; } diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 27575878f2e..b0247fe795b 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -221,6 +221,10 @@ class SingleGradOpMaker std::shared_ptr operator()() const final { auto node = this->NewGradNode(); + auto& inplace_map = this->GetInplaceMap(); + if (!inplace_map.empty()) { + node->SetInplaceGradNameMap(inplace_map); + } { imperative::TracedGradOp traced_grad_op(node); try { diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 4d2f07fa494..a2b5a98401e 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -59,7 +59,8 @@ using DygraphGradOpMakerFN = const std::string& /*op_type*/, const imperative::NameVarBaseMap& /*var_base_map_in*/, const imperative::NameVarBaseMap& /*var_base_map_out*/, - const framework::AttributeMap& /*attributes*/)>; + const framework::AttributeMap& /*attributes*/, + const std::map& /*inplace_map*/)>; using InferVarTypeFN = std::function; diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index 731cf121534..a34ac72ec16 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -114,7 +114,9 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) { } } -void BasicEngine::PrepareGradAccumulators(const OpBase& op) { +void BasicEngine::PrepareGradAccumulators( + const OpBase& op, + const std::vector>& grad_pending_nodes) { for (const auto& pair : op.GetOutsMap()) { if (!pair.second.IsGrad()) { continue; @@ -123,29 +125,94 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) { for (const auto& var : pair.second) { if (!var) continue; - auto& accumulator = accumulators_[var.get()]; - if (!accumulator) { - if (FLAGS_sort_sum_gradient) { - accumulator.reset(new SortedGradientAccumulator(var.get())); - } else { - accumulator.reset(new EagerGradientAccumulator(var.get())); + if (!var->HasGradNode()) { + auto& accumulator = accumulators_[var.get()]; + if (!accumulator) { + if (FLAGS_sort_sum_gradient) { + accumulator.reset(new SortedGradientAccumulator(var.get())); + } else { + accumulator.reset(new EagerGradientAccumulator(var.get())); + } } - } - accumulator->IncreaseRefCnt(); + accumulator->IncreaseRefCnt(); - VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" - << var.get() << ") with reference count " - << accumulator->RefCnt(); + VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" + << var.get() + << ") that don't have grad node with reference count " + << accumulator->RefCnt(); + + if (var->HasLeafHooks()) { + VLOG(3) << "Grad variable wrapper (" << var->Name() + << ") has leaf grad hooks."; + PADDLE_ENFORCE_NE( + var->HasGradNode(), true, + platform::errors::PermissionDenied( + "Only leaf Tensor's gradient can append hook to " + "Gradientaccumulator.")); + accumulator->SetPostHooks(var->GetLeafHooks()); + } + } else { 
+ // Because Inplace op overwrites the grad_node of the input grad_var. So + // only the information of grad_pending_node can be used to find the + // grad_node of grad_var. + bool find_grad_node_of_var = false; + for (auto& grad_pending_node : grad_pending_nodes) { + PADDLE_ENFORCE_NOT_NULL( + grad_pending_node, + platform::errors::NotFound("Grad pending node is nullptr.")); + for (auto& grad_pending_op : *grad_pending_node) { + VLOG(6) << "Determine whether var (" << var->Name() + << ") is the input var of grad_pending_op (" + << grad_pending_op.Type() << ")."; + grad_pending_op.EnforceHasInOut(); + for (const auto& grad_pending_op_ins_pair : + grad_pending_op.GetInsMap()) { + if (!grad_pending_op_ins_pair.second.IsGrad()) { + continue; + } + for (const auto& pending_in_var : + grad_pending_op_ins_pair.second) { + if (var == pending_in_var) { + VLOG(6) << "Var (" << var->Name() + << ") is the input var of grad_pending_op (" + << grad_pending_op.Type() << ")."; + find_grad_node_of_var = true; + break; + } + } + if (find_grad_node_of_var) { + break; + } + } + } - if (var->HasLeafHooks()) { - VLOG(3) << "Grad variable wrapper (" << var->Name() - << ") has leaf grad hooks."; - PADDLE_ENFORCE_NE(var->HasGradNode(), true, - platform::errors::PermissionDenied( - "Only leaf Tensor's gradient can append hook to " - "Gradientaccumulator.")); - accumulator->SetPostHooks(var->GetLeafHooks()); + if (find_grad_node_of_var) { + auto& accumulator = + accumulators_with_grad_node_[grad_pending_node][var.get()]; + + if (!accumulator) { + if (FLAGS_sort_sum_gradient) { + accumulator.reset(new SortedGradientAccumulator(var.get())); + } else { + accumulator.reset(new EagerGradientAccumulator(var.get())); + } + } + + accumulator->IncreaseRefCnt(); + + VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() + << "(" << var.get() + << ") that has grad node with reference count " + << accumulator->RefCnt(); + break; + } + } + PADDLE_ENFORCE_EQ( + find_grad_node_of_var, true, + platform::errors::NotFound( + "No grad node corresponding to grad Tensor (%s) was found.", + var->Name())); } } } @@ -154,10 +221,13 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) { void BasicEngine::PrepareDeps() { PADDLE_ENFORCE_EQ( node_deps_.empty(), true, - platform::errors::AlreadyExists("Op deps must be initialized here")); + platform::errors::AlreadyExists("Op deps must be initialized.")); PADDLE_ENFORCE_EQ( accumulators_.empty(), true, - platform::errors::AlreadyExists("Accumulators must be initialized here")); + platform::errors::AlreadyExists("Accumulators must be initialized.")); + PADDLE_ENFORCE_EQ( + accumulators_with_grad_node_.empty(), true, + platform::errors::AlreadyExists("Accumulators must be initialized.")); std::queue q; std::unordered_set visited; @@ -169,16 +239,17 @@ void BasicEngine::PrepareDeps() { auto* cur_node = q.front(); q.pop(); + const auto& grad_pending_nodes = cur_node->GradPendingNodes(); + for (auto& cur_op : *cur_node) { cur_op.EnforceHasInOut(); - PrepareGradAccumulators(cur_op); + PrepareGradAccumulators(cur_op, grad_pending_nodes); } - const auto& grad_pending_nodes = cur_node->GradPendingNodes(); for (auto& grad_pending_node : grad_pending_nodes) { PADDLE_ENFORCE_NOT_NULL( grad_pending_node, - platform::errors::NotFound("Grad pending node should not be null")); + platform::errors::NotFound("Grad pending node is nullptr.")); ++node_deps_[grad_pending_node.get()]; if (visited.count(grad_pending_node.get()) == 0) { visited.insert(grad_pending_node.get()); @@ -204,6 +275,8 @@ 
void BasicEngine::Execute() { auto shared_cur_node = std::move(q.front()); q.pop(); + auto& inplace_grad_name_map = shared_cur_node->InplaceGradNameMap(); + for (auto& cur_op : *shared_cur_node) { ++op_num; @@ -228,11 +301,38 @@ void BasicEngine::Execute() { continue; } - auto iter = accumulators_.find(var.get()); - PADDLE_ENFORCE_EQ( - iter != accumulators_.end(), true, - platform::errors::NotFound("Cannot find gradient of variable %s", - var->Name())); + std::unordered_map>::iterator + iter; + if (!var->HasGradNode()) { + VLOG(10) << "Find gradient of var (" << var->Name() + << ") with no grad_node."; + iter = accumulators_.find(var.get()); + PADDLE_ENFORCE_EQ( + iter != accumulators_.end(), true, + platform::errors::NotFound( + "Cannot find gradient of variable %s", var->Name())); + } else { + bool flag_find_grad = false; + VLOG(10) << "Find gradient of var (" << var->Name() + << ") with grad_node."; + for (auto& grad_pending_node : + shared_cur_node->GradPendingNodes()) { + const auto& iter_grad_node = + accumulators_with_grad_node_.find(grad_pending_node); + if (iter_grad_node != accumulators_with_grad_node_.end()) { + iter = iter_grad_node->second.find(var.get()); + if (iter != iter_grad_node->second.end()) { + flag_find_grad = true; + break; + } + } + } + PADDLE_ENFORCE_EQ( + flag_find_grad, true, + platform::errors::NotFound( + "Cannot find gradient of variable %s", var->Name())); + } // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor if (var->IsLeafGrad()) { @@ -251,6 +351,25 @@ void BasicEngine::Execute() { need_accu_var_list_.emplace_back(iter->second.get(), var); VLOG(10) << "create temporary var of " << var->Name() << " for sum gradient within this graph!"; + } else if (!inplace_grad_name_map.empty() && + inplace_grad_name_map.count(pair.first)) { + // When calculate Inplace grad op, create a new output var. + // If a tmp var has been created, there is no need to create it + // again. 
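+          // Note: the grad op of an inplace forward op maps grad(out) and
+          // grad(in) to the same variable. Instead of letting the grad op
+          // write into that variable directly, a temporary output var is
+          // created here and its content is moved back into the original
+          // grad var after the op has run (see inplace_output_grad_var_list_).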
+ for (auto& in_var : + bwd_ins.at(inplace_grad_name_map.at(pair.first))) { + if (in_var == var) { + auto tmp_var = std::make_shared(var->Name()); + tmp_var->SetType(var->Type()); + tmp_var->SetForwardDataType(var->ForwardDataType()); + inplace_output_grad_var_list_.emplace_back(var, tmp_var); + var = tmp_var; + VLOG(10) << "Inplace grad op does not use the Inplace " + "strategy, a temporary output var (" + << var->Name() << ") will be created."; + break; + } + } } } } @@ -286,6 +405,10 @@ void BasicEngine::Execute() { cur_op.place()); } + for (auto& pair : inplace_output_grad_var_list_) { + *pair.first = std::move(*pair.second); + } + // Step 2: Sum Gradient of This graph for (auto& pair : need_accu_var_list_) { pair.first->SumGrad(std::move(pair.second), cur_op.id()); @@ -308,6 +431,7 @@ void BasicEngine::Execute() { } need_accu_var_list_.clear(); + inplace_output_grad_var_list_.clear(); leaf_accumulators_.clear(); if (!retain_graph_) { @@ -318,9 +442,9 @@ void BasicEngine::Execute() { // Step 3: Collect ready ops for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) { - PADDLE_ENFORCE_NOT_NULL(grad_pending_node, - platform::errors::NotFound( - "Grad pending node should not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + grad_pending_node, + platform::errors::NotFound("Grad pending node is nullptr.")); auto iter = node_deps_.find(grad_pending_node.get()); if (iter == node_deps_.end()) { continue; @@ -340,6 +464,7 @@ void BasicEngine::Clear() { init_node_.reset(); node_deps_.clear(); accumulators_.clear(); + accumulators_with_grad_node_.clear(); need_accu_var_list_.clear(); leaf_accumulators_.clear(); } diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h index d7ac7594ef0..87c4ea380f3 100644 --- a/paddle/fluid/imperative/basic_engine.h +++ b/paddle/fluid/imperative/basic_engine.h @@ -39,15 +39,33 @@ class BasicEngine : public Engine { void CheckBackwardInputs(const OpBase& op); - void PrepareGradAccumulators(const OpBase& op); + void PrepareGradAccumulators( + const OpBase& op, + const std::vector>& grad_pending_nodes); void Clear(); private: std::shared_ptr init_node_; std::unordered_map node_deps_; + // The input and output of Inplace op are the same. If only `var` is used + // as the key, then the input and output of inplace op must be gradient + // accumulated. Therefore, add the `grad_node` as the key to prevent the + // problem of gradient accumulation in inplace op. + std::unordered_map, + std::unordered_map>> + accumulators_with_grad_node_; + // Leaf var doesn't have grad_node, and leaf var with `stop_gradient=False` + // can't use Inplace strategy. If a var doesn't have grad_node, only use + // `var` as the key. std::unordered_map> accumulators_; + // The output grad var of Inplace grad op. Because Inplace grad op does not + // use the Inplace strategy, a new output grad var needs to be created. 
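+  // Each element pairs the original output grad var with the temporary var
+  // that actually receives the grad op's result; Execute() moves the
+  // temporary's content back into the original var after the grad op runs.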
+ std::vector, + std::shared_ptr>> + inplace_output_grad_var_list_; std::vector>> need_accu_var_list_; // leaf_accumulators_ is only for leaf tensor(hooks/accumulate grad) diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index d650452ad9a..a3678404728 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -43,14 +44,16 @@ class TracedVarList : public std::vector> { class GradOpBaseMakerBase { public: - explicit GradOpBaseMakerBase(const std::string& type, - const NameVarBaseMap& var_base_map_in, - const NameVarBaseMap& var_base_map_out, - const framework::AttributeMap& attrs) + explicit GradOpBaseMakerBase( + const std::string& type, const NameVarBaseMap& var_base_map_in, + const NameVarBaseMap& var_base_map_out, + const framework::AttributeMap& attrs, + const std::map& inplace_map) : type_(type), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out), - attrs_(attrs) {} + attrs_(attrs), + inplace_map_(inplace_map) {} virtual ~GradOpBaseMakerBase() = default; @@ -141,6 +144,10 @@ class GradOpBaseMakerBase { return std::make_shared(); } + const std::map& GetInplaceMap() const { + return inplace_map_; + } + private: template TracedVarList GetVarBaseList(const std::string& name, @@ -192,6 +199,7 @@ class GradOpBaseMakerBase { const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; const framework::AttributeMap& attrs_; + const std::map& inplace_map_; }; class TracedGradOp { @@ -220,6 +228,10 @@ class TracedGradOp { for (auto& var : vars) { if (var && !var->OverridedStopGradient()) { var->SetGraphIsFreed(false); + auto dirty_grad_node = var->GradNode(); + if (dirty_grad_node) { + map_dirty_grad_node_[var] = dirty_grad_node; + } var->SetGradNode(node_); } } @@ -246,7 +258,11 @@ class TracedGradOp { } else { for (auto& var : vars) { if (var && !var->OverridedStopGradient() && var->GradNode()) { - node_->InsertGradPendingNode(var->GradNode()); + if (map_dirty_grad_node_.find(var) != map_dirty_grad_node_.end()) { + node_->InsertGradPendingNode(map_dirty_grad_node_[var]); + } else { + node_->InsertGradPendingNode(var->GradNode()); + } } } } @@ -329,6 +345,12 @@ class TracedGradOp { private: const std::shared_ptr& node_; OpBase* op_; + // Inplace op has recursion problems when performing grad calculation. + // Because the input and output of inplace op are the same, the grad + // node of inplace var will be overwritten. + // This map is used to store the grad node of inplace var in temporary. 
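+  // When a var's grad node is about to be overwritten with node_, the
+  // previous node is cached here; the grad-pending lookup then prefers the
+  // cached node over var->GradNode(), so node_ does not end up depending on
+  // itself for inplace vars.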
+ std::unordered_map, std::shared_ptr> + map_dirty_grad_node_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index b43414c5021..3123d4b5077 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -451,13 +451,15 @@ static void ClearNoNeedBufferInputs(OpBase* op) { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const platform::Place& place) { + const platform::Place& place, + const std::map& inplace_map) { const auto& info = op.Info(); if (!info.dygraph_grad_op_maker_) { return nullptr; } - auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs); + auto grad_node = + info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, inplace_map); if (grad_node && !grad_node->empty()) { for (auto& grad_op : *grad_node) { grad_op.SetId(OpBase::GenerateUniqueId()); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index adec67c8067..e218033eae0 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -256,7 +256,8 @@ class Layer { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const platform::Place& place); + const platform::Place& place, + const std::map& inplace_map); } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 36185af3a25..2b7642ae7cf 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -15,11 +15,13 @@ #pragma once #include +#include #include #include #include #include #include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/imperative/saved_variable_wrapper_list.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/platform/place.h" @@ -227,6 +229,22 @@ class GradOpNode { } } + void SetInplaceGradNameMap( + const std::map& inplace_input_map) { + for (auto& pair : inplace_input_map) { + VLOG(10) << "Set mapping relationship (" + << framework::GradVarName(pair.first) << ", " + << framework::GradVarName(pair.second) + << ") for Inplace grad node."; + inplace_grad_name_map_[framework::GradVarName(pair.first)] = + framework::GradVarName(pair.second); + } + } + + const std::map& InplaceGradNameMap() const { + return inplace_grad_name_map_; + } + const std::vector>& GradPendingNodes() const { return grad_pending_nodes_; } @@ -237,6 +255,9 @@ class GradOpNode { private: std::vector ops_; std::vector> grad_pending_nodes_; + // Mapping relationship between grad output and grad input of the grad node of + // Inplace op. + std::map inplace_grad_name_map_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 149a38e2586..8dd8cafc835 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -884,7 +884,7 @@ void PartialGradTask::RunEachOp(OpBase *op) { if (create_graph_) { auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, - op->Attrs(), op->place()); + op->Attrs(), op->place(), {}); PADDLE_ENFORCE_NOT_NULL( double_grad_node, platform::errors::NotFound("The Op %s doesn't have any grad op. 
If you " diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 68c79f77e56..e5d664070e1 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/imperative/tracer.h" +#include #include #include #include @@ -130,7 +131,8 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists( void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, const NameVarBaseMap& outs, framework::AttributeMap attrs, - const platform::Place& place, bool trace_backward) { + const platform::Place& place, bool trace_backward, + const std::map& inplace_map) { VLOG(1) << "Trace Op: " << type; if (FLAGS_use_mkldnn) { // if both lists are empty all ops are enabled (default for @@ -182,16 +184,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, } if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { - CreateGradOpNode(*op, new_ins, outs, attrs, place); + CreateGradOpNode(*op, new_ins, outs, attrs, place, inplace_map); } else { VLOG(3) << "No Grad to track for Op: " << type; } } void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, - const NameVarBaseMap& outs, - framework::AttributeMap attrs) { - TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_); + const NameVarBaseMap& outs, framework::AttributeMap attrs, + const std::map& inplace_map) { + TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_, + inplace_map); } bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 601645a8445..d8c825666e7 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -21,7 +21,6 @@ #include #include #include - #include "ThreadPool.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/imperative/basic_engine.h" @@ -63,10 +62,12 @@ class Tracer { void TraceOp(const std::string& type, const NameVarBaseMap& ins, const NameVarBaseMap& outs, framework::AttributeMap attrs, - const platform::Place& place, bool trace_bacward); + const platform::Place& place, bool trace_bacward, + const std::map& inplace_map = {}); void TraceOp(const std::string& type, const NameVarBaseMap& ins, - const NameVarBaseMap& outs, framework::AttributeMap attrs); + const NameVarBaseMap& outs, framework::AttributeMap attrs, + const std::map& inplace_map = {}); bool ComputeRequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs, bool trace_backward); diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index 6f99b330595..d4192de519a 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/hooks.h" +#include "paddle/fluid/imperative/op_base.h" namespace paddle { namespace imperative { @@ -258,8 +259,13 @@ class VariableWrapper { auto shared_node = grad_node_.lock(); if (shared_node != grad_node) { PADDLE_ENFORCE_EQ( - shared_node, nullptr, - platform::errors::PermissionDenied("Cannot set gradient op twice")); + !shared_node || !grad_node->InplaceGradNameMap().empty(), true, + platform::errors::PermissionDenied( + "Cannot set gradient op twice unless using Inplace Strategy.")); + if (shared_node) { + 
VLOG(3) << "The gradient op of Var (" << Name() + << ") has been set twice. Because Inplace Strategy is used."; + } grad_node_ = grad_node; } } diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 349162c2e5a..03f66208ea5 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -142,9 +142,9 @@ std::map> op_passing_outs_map = { // NOTE(pangyoki): Tensor View Strategy. // In this case, a new output varbase will be created, and this varbase will // reuse the input varbase's allocation. -// It's a 2-layer map. The key of outer map is the view op name, the value is -// also a map which implies the mapping relationship between the output and -// input varbase. +// It's a map. The key of outer map is the view op name, the value is +// a pair which implies the mapping relationship between the input and +// output varbase. std::map> view_op_map = { {"squeeze2", {"X", "Out"}}, // "X" -> "Out" {"unsqueeze2", {"X", "Out"}}, @@ -152,6 +152,14 @@ std::map> view_op_map = { {"flatten_contiguous_range", {"X", "Out"}}, }; +// NOTE(pangyoki): Inplace OP with duplicable input. +// The set includes inplace ops that have duplicable input. +// The first Varbase in input needs to be specified for the inplace strategy +// and share Varbase with the output. +std::set inplace_op_duplicable_ins_set = { + "sum", +}; + // clang-format off const char* OUT_INITIALIZER_TEMPLATE = R"({"%s", {std::shared_ptr(new imperative::VarBase(tracer->GenerateUniqueName()))}})"; @@ -207,11 +215,26 @@ const char* RETURN_TEMPLATE = R"(outs["%s"][0])"; const char* FUNCTION_ARGS = R"(%s, const py::args& args)"; const char* FUNCTION_ARGS_NO_INPUT = R"(const py::args& args)"; -const char* HandleViewBetweenInputAndOutput = R"( +const char* HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT = R"( if (ins.count("%s") && outs.count("%s")) { HandleViewBetweenInputAndOutput(ins["%s"][0], outs["%s"][0]); })"; +const char* INPLACE_DUPLICABLE_INPUT = R"([0])"; + +const char* INPLACE_LEAF_ERROR_MESSAGE = R"(Leaf Var (%s) that doesn't stop gradient can't use inplace strategy.)"; + +const char* INPLACE_STRATEGY_TEMPLATE = +R"( + PADDLE_ENFORCE_EQ( + %s->IsLeaf() && !%s->OverridedStopGradient(), false, + platform::errors::InvalidArgument("%s", %s->Name())); + %s->BumpInplaceVersion(); + VLOG(3) << "Var(" << %s->Name() << ") uses Inplace Strategy."; +)"; + +const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})"; + const char* OP_FUNCTION_TEMPLATE = R"( %s %s(%s) @@ -222,10 +245,11 @@ R"( { py::gil_scoped_release release; auto tracer = imperative::GetCurrentTracer(); + %s imperative::NameVarBaseMap outs = %s; imperative::NameVarBaseMap ins = %s; %s - tracer->TraceOp("%s", ins, outs, attrs); + tracer->TraceOp("%s", ins, outs, attrs, {%s}); return %s; } })"; @@ -248,6 +272,10 @@ static inline bool FindPassingOutsMap(const std::string& op_type, return op_passing_outs_map[op_type].count(out_name); } +static inline bool FindDuplicableInputInplaceOpSet(const std::string& op_type) { + return inplace_op_duplicable_ins_set.count(op_type); +} + static inline bool FindViewOpMap(const std::string& op_type) { return view_op_map.count(op_type); } @@ -256,6 +284,202 @@ static inline std::string TempName(const std::string& name) { return name + '_'; } +std::string GenerateOpFunctionsBody( + const paddle::framework::proto::OpProto* op_proto, std::string func_name, + bool use_inplace_strategy = false, + std::map inplace_map = {}) { + auto& op_type = op_proto->type(); + 
std::string input_args = ""; + std::string ins_initializer = "{"; + std::string ins_initializer_with_null = ""; + std::string py_arg = ""; + int arg_idx = 0; + int input_args_num = 0; + std::string ins_cast_str = ""; + std::string view_strategy_str = ""; + std::string inplace_strategy_str = ""; + for (auto& input : op_proto->inputs()) { + auto& in_name = input.name(); + // skip those dispensable inputs, like ResidualData in conv2d + if (input.dispensable() && !FindInsMap(op_type, in_name)) { + continue; + } + const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; + auto input_arg = + paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); + input_args += input_arg; + input_args += ","; + input_args_num++; + const auto in_cast_type = + input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; + auto dispensable = input.dispensable() ? "true" : "false"; + ins_cast_str += + paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, + arg_idx++, TempName(in_name), dispensable); + + if (input.dispensable()) { + const auto in_template = input.duplicable() + ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST + : INPUT_INITIALIZER_TEMPLATE_WITH_NULL; + ins_initializer_with_null += + paddle::string::Sprintf(in_template, in_name, in_name, in_name); + } else { + const auto in_template = input.duplicable() + ? INPUT_LIST_INITIALIZER_TEMPLATE + : INPUT_INITIALIZER_TEMPLATE; + ins_initializer += paddle::string::Sprintf(in_template, in_name, in_name); + ins_initializer += ","; + } + } + if (ins_initializer.back() == ',') { + ins_initializer.pop_back(); + } + ins_initializer += "}"; + + if (input_args.back() == ',') { + input_args.pop_back(); + } + + // Generate outs initializer + std::string outs_initializer = "{"; + std::string outs_initializer_with_null = ""; + std::string return_type = ""; + std::string inplace_mapping_str = ""; + std::string return_str = ""; + + int outs_num = 0; + for (auto& output : op_proto->outputs()) { + auto& out_name = output.name(); + // skip those dispensable oututs + if (output.dispensable() && !FindOutsMap(op_type, out_name)) { + continue; + } + const auto out_type = + output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE; + const auto return_template = + output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE; + + if (FindPassingOutsMap(op_type, out_name)) { + if (input_args != "") { + input_args += ","; + } + input_args += out_type; + input_args += out_name; + input_args_num++; + + if (output.dispensable()) { + const auto out_template = + output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST + : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL; + outs_initializer_with_null += + paddle::string::Sprintf(out_template, out_name, out_name); + } else { + const auto out_template = output.duplicable() + ? INPUT_LIST_INITIALIZER_TEMPLATE + : INPUT_INITIALIZER_TEMPLATE; + outs_initializer += + paddle::string::Sprintf(out_template, out_name, out_name); + outs_initializer += ","; + } + } else if (use_inplace_strategy && inplace_map.count(out_name)) { + PADDLE_ENFORCE_NE( + inplace_map[out_name], "", + paddle::platform::errors::InvalidArgument( + "Inplace op %s has no input corresponding to output %s.", op_type, + out_name)); + + // TODO(pangyoki): Inplace op don't have duplicable output in temporary, + // so don't support duplicable output now. 
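+      // For an inplace output: record the (input, output) name pair that is
+      // later passed to TraceOp as the inplace_map, emit the leaf-var check
+      // and the BumpInplaceVersion() call via INPLACE_STRATEGY_TEMPLATE, and
+      // initialize the output slot with the input VarBase itself rather than
+      // a freshly created one.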
+ const auto out_template = INPUT_INITIALIZER_TEMPLATE; + + auto inplace_input_name = inplace_map[out_name]; + inplace_mapping_str += paddle::string::Sprintf( + INPLACE_MAPPING_TEMPLATE, inplace_input_name, out_name); + inplace_mapping_str += ","; + + // If inplace op has duplicable input, the first Varbase in input will + // share Varbase with output. + if (FindDuplicableInputInplaceOpSet(op_type)) { + inplace_input_name += INPLACE_DUPLICABLE_INPUT; + } + + // Leaf Var that doesn't stop gradient can't use inplace strategy. + // Increase inplace_version. + inplace_strategy_str += paddle::string::Sprintf( + INPLACE_STRATEGY_TEMPLATE, inplace_input_name, inplace_input_name, + INPLACE_LEAF_ERROR_MESSAGE, inplace_input_name, inplace_input_name, + inplace_input_name); + outs_initializer += + paddle::string::Sprintf(out_template, out_name, inplace_input_name); + outs_initializer += ","; + } else { + // There are few Operators that have duplicable output, like `Out` in + // split op. We need to specify the number of variables for the + // duplicable output, as the argument OutNum; + if (output.duplicable()) { + if (input_args != "") { + input_args += ","; + } + auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); + input_args += ARG_OUT_NUM_TYPE; + input_args += out_num_str; + input_args_num++; + outs_initializer += paddle::string::Sprintf( + OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); + } else { + outs_initializer += + paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name); + } + outs_initializer += ","; + } + + return_type += out_type; + return_type += ","; + return_str += paddle::string::Sprintf(return_template, out_name); + return_str += ","; + outs_num += 1; + } + if (outs_initializer.back() == ',') { + outs_initializer.pop_back(); + return_type.pop_back(); + return_str.pop_back(); + } + outs_initializer += "}"; + if (inplace_mapping_str.back() == ',') { + inplace_mapping_str.pop_back(); + } + if (!use_inplace_strategy && FindViewOpMap(op_type)) { + std::string viwe_input_name = view_op_map[op_type].first; + std::string viwe_output_name = view_op_map[op_type].second; + view_strategy_str += paddle::string::Sprintf( + HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT, viwe_input_name, viwe_output_name, + viwe_input_name, viwe_output_name); + } + if (outs_num == 0) { + return_type = "void"; + } + if (outs_num > 1) { + return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str); + return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type); + } + std::string function_args = ""; + if (input_args == "") { + function_args = FUNCTION_ARGS_NO_INPUT; + } else { + function_args = paddle::string::Sprintf(FUNCTION_ARGS, input_args); + } + + // generate op funtcion body + auto op_function_str = paddle::string::Sprintf( + OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, ins_cast_str, + op_type, input_args_num, inplace_strategy_str, outs_initializer, + ins_initializer, ins_initializer_with_null + outs_initializer_with_null + + view_strategy_str, + op_type, inplace_mapping_str, return_str); + + return op_function_str; +} + static std::tuple, std::vector> GenerateOpFunctions(const std::string& module_name) { auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); @@ -275,158 +499,26 @@ GenerateOpFunctions(const std::string& module_name) { if (!all_kernels.count(op_type)) { continue; } - std::string input_args = ""; - std::string ins_initializer = "{"; - std::string ins_initializer_with_null = ""; - std::string py_arg = ""; - int arg_idx = 0; 
- int input_args_num = 0; - std::string ins_cast_str = ""; - std::string view_strategy_str = ""; - for (auto& input : op_proto->inputs()) { - auto& in_name = input.name(); - // skip those dispensable inputs, like ResidualData in conv2d - if (input.dispensable() && !FindInsMap(op_type, in_name)) { - continue; - } - const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; - auto input_arg = - paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); - input_args += input_arg; - input_args += ","; - input_args_num++; - const auto in_cast_type = - input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; - auto dispensable = input.dispensable() ? "true" : "false"; - ins_cast_str += - paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, - arg_idx++, TempName(in_name), dispensable); - - if (input.dispensable()) { - const auto in_template = input.duplicable() - ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST - : INPUT_INITIALIZER_TEMPLATE_WITH_NULL; - ins_initializer_with_null += - paddle::string::Sprintf(in_template, in_name, in_name, in_name); - } else { - const auto in_template = input.duplicable() - ? INPUT_LIST_INITIALIZER_TEMPLATE - : INPUT_INITIALIZER_TEMPLATE; - ins_initializer += - paddle::string::Sprintf(in_template, in_name, in_name); - ins_initializer += ","; - } - } - if (ins_initializer.back() == ',') { - ins_initializer.pop_back(); - } - ins_initializer += "}"; - - if (input_args.back() == ',') { - input_args.pop_back(); - } - // Generate outs initializer - std::string outs_initializer = "{"; - std::string outs_initializer_with_null = ""; - std::string return_type = ""; - std::string return_str = ""; - - int outs_num = 0; - for (auto& output : op_proto->outputs()) { - auto& out_name = output.name(); - // skip those dispensable oututs - if (output.dispensable() && !FindOutsMap(op_type, out_name)) { - continue; + // NOTE(pangyoki): Inplace Strategy. + // In this case, output will reuse input varbase. + // Dygraph mode needs to be aligned with the in-place strategy in static + // mode, and the mapping relationships between output and input that have + // been defined in static mode should be used in dygraph mode. + // Find which ops need to use Inplace strategy in static mode, and get the + // mapping relationship between Inplace output and input. + auto& infer_inplace = + paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_; + std::map inplace_map; + if (infer_inplace) { + auto in_to_outs = infer_inplace(true); + for (auto& inplace_pair : in_to_outs) { + inplace_map[inplace_pair.second] = inplace_pair.first; } - const auto out_type = - output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE; - const auto return_template = - output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE; - if (FindPassingOutsMap(op_type, out_name)) { - if (input_args != "") { - input_args += ","; - } - input_args += out_type; - input_args += out_name; - input_args_num++; - - if (output.dispensable()) { - const auto out_template = - output.duplicable() ? OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST - : OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL; - outs_initializer_with_null += - paddle::string::Sprintf(out_template, out_name, out_name); - } else { - const auto out_template = output.duplicable() - ? 
INPUT_LIST_INITIALIZER_TEMPLATE - : INPUT_INITIALIZER_TEMPLATE; - outs_initializer += - paddle::string::Sprintf(out_template, out_name, out_name); - outs_initializer += ","; - } - } else { - // There are few Operators that have duplicable output, like `Out` in - // split op. We need to specify the number of variables for the - // duplicable output, as the argument OutNum; - if (output.duplicable()) { - if (input_args != "") { - input_args += ","; - } - auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); - input_args += ARG_OUT_NUM_TYPE; - input_args += out_num_str; - input_args_num++; - outs_initializer += paddle::string::Sprintf( - OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); - } else { - outs_initializer += - paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name); - } - outs_initializer += ","; - } - - return_type += out_type; - return_type += ","; - return_str += paddle::string::Sprintf(return_template, out_name); - return_str += ","; - outs_num += 1; - } - if (outs_initializer.back() == ',') { - outs_initializer.pop_back(); - return_type.pop_back(); - return_str.pop_back(); - } - outs_initializer += "}"; - if (FindViewOpMap(op_type)) { - std::string viwe_input_name = view_op_map[op_type].first; - std::string viwe_output_name = view_op_map[op_type].second; - view_strategy_str += paddle::string::Sprintf( - HandleViewBetweenInputAndOutput, viwe_input_name, viwe_output_name, - viwe_input_name, viwe_output_name); - } - if (outs_num == 0) { - return_type = "void"; - } - if (outs_num > 1) { - return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str); - return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type); - } - std::string function_args = ""; - if (input_args == "") { - function_args = FUNCTION_ARGS_NO_INPUT; - } else { - function_args = paddle::string::Sprintf(FUNCTION_ARGS, input_args); } std::string func_name = "imperative_" + op_type; - // generate op funtcion body - auto op_function_str = paddle::string::Sprintf( - OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, - ins_cast_str, op_type, input_args_num, outs_initializer, - ins_initializer, ins_initializer_with_null + - outs_initializer_with_null + view_strategy_str, - op_type, return_str); + std::string op_function_str = GenerateOpFunctionsBody(op_proto, func_name); // generate pybind item auto bind_function_str = paddle::string::Sprintf( @@ -434,6 +526,23 @@ GenerateOpFunctions(const std::string& module_name) { op_function_list.emplace_back(std::move(op_function_str)); bind_function_list.emplace_back(std::move(bind_function_str)); + + if (infer_inplace) { + // Reuse Varbase Inplace OP: op_type_. + // The inplace OP needs a new implementation method. 
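+      // For example, if "tanh" registers an inplace inferer, a second
+      // function imperative_tanh_ is generated with use_inplace_strategy set
+      // to true and bound under the name "tanh_", next to the ordinary
+      // non-inplace "tanh" binding.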
+ std::string inplace_op_type = op_type + "_"; + std::string inplace_func_name = "imperative_" + inplace_op_type; + std::string inplace_op_function_str = GenerateOpFunctionsBody( + op_proto, inplace_func_name, true, inplace_map); + + // generate pybind item + auto inplace_bind_function_str = + paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, module_name, + inplace_op_type, inplace_func_name); + + op_function_list.emplace_back(std::move(inplace_op_function_str)); + bind_function_list.emplace_back(std::move(inplace_bind_function_str)); + } } return std::make_tuple(op_function_list, bind_function_list); } diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 50043a9b3cf..8dabe19f57c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -113,19 +113,23 @@ from .tensor.manipulation import flatten #DEFINE_ALIAS from .tensor.manipulation import gather #DEFINE_ALIAS from .tensor.manipulation import gather_nd #DEFINE_ALIAS from .tensor.manipulation import reshape #DEFINE_ALIAS +from .tensor.manipulation import reshape_ #DEFINE_ALIAS from .tensor.manipulation import flip as reverse #DEFINE_ALIAS from .tensor.manipulation import scatter #DEFINE_ALIAS +from .tensor.manipulation import scatter_ #DEFINE_ALIAS from .tensor.manipulation import scatter_nd_add #DEFINE_ALIAS from .tensor.manipulation import scatter_nd #DEFINE_ALIAS from .tensor.manipulation import shard_index #DEFINE_ALIAS from .tensor.manipulation import slice #DEFINE_ALIAS from .tensor.manipulation import split #DEFINE_ALIAS from .tensor.manipulation import squeeze #DEFINE_ALIAS +from .tensor.manipulation import squeeze_ #DEFINE_ALIAS from .tensor.manipulation import stack #DEFINE_ALIAS from .tensor.manipulation import strided_slice #DEFINE_ALIAS from .tensor.manipulation import transpose #DEFINE_ALIAS from .tensor.manipulation import unique #DEFINE_ALIAS from .tensor.manipulation import unsqueeze #DEFINE_ALIAS +from .tensor.manipulation import unsqueeze_ #DEFINE_ALIAS from .tensor.manipulation import unstack #DEFINE_ALIAS from .tensor.manipulation import flip #DEFINE_ALIAS from .tensor.manipulation import unbind #DEFINE_ALIAS @@ -172,6 +176,7 @@ from .tensor.math import square #DEFINE_ALIAS from .tensor.math import stanh #DEFINE_ALIAS from .tensor.math import sum #DEFINE_ALIAS from .tensor.math import tanh #DEFINE_ALIAS +from .tensor.math import tanh_ #DEFINE_ALIAS from .tensor.math import add_n #DEFINE_ALIAS from .tensor.math import max #DEFINE_ALIAS from .tensor.math import maximum #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index a9982dc1329..3042248f69c 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -221,12 +221,16 @@ class TestTanhAPI(unittest.TestCase): self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ else paddle.CPUPlace() + self.executed_api() + + def executed_api(self): + self.tanh = F.tanh def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('X', [10, 12], self.dtype) - out1 = F.tanh(x) + out1 = self.tanh(x) th = paddle.nn.Tanh() out2 = th(x) exe = paddle.static.Executor(self.place) @@ -261,15 +265,21 @@ class TestTanhAPI(unittest.TestCase): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): # The input type must 
be Variable. - self.assertRaises(TypeError, F.tanh, 1) + self.assertRaises(TypeError, self.tanh, 1) # The input dtype must be float16, float32. x_int32 = paddle.fluid.data( name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, F.tanh, x_int32) + self.assertRaises(TypeError, self.tanh, x_int32) # support the input dtype is float16 x_fp16 = paddle.fluid.data( name='x_fp16', shape=[12, 10], dtype='float16') - F.tanh(x_fp16) + self.tanh(x_fp16) + + +class TestTanhInplaceAPI(TestTanhAPI): + # test paddle.tanh_ + def executed_api(self): + self.tanh = paddle.tanh_ class TestAtan(TestActivation, TestParameter): @@ -1044,12 +1054,16 @@ class TestReluAPI(unittest.TestCase): self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32') self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ else paddle.CPUPlace() + self.executed_api() + + def executed_api(self): + self.relu = F.relu def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('X', [10, 12]) - out1 = F.relu(x) + out1 = self.relu(x) m = paddle.nn.ReLU() out2 = m(x) exe = paddle.static.Executor(self.place) @@ -1061,9 +1075,9 @@ class TestReluAPI(unittest.TestCase): def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x_np) - out1 = F.relu(x) m = paddle.nn.ReLU() - out2 = m(x) + out1 = m(x) + out2 = self.relu(x) out_ref = np.maximum(self.x_np, 0) for r in [out1, out2]: self.assertEqual(np.allclose(out_ref, r.numpy()), True) @@ -1073,15 +1087,21 @@ class TestReluAPI(unittest.TestCase): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, F.relu, 1) + self.assertRaises(TypeError, self.relu, 1) # The input dtype must be float16, float32, float64. 
x_int32 = paddle.fluid.data( name='x_int32', shape=[10, 12], dtype='int32') - self.assertRaises(TypeError, F.relu, x_int32) + self.assertRaises(TypeError, self.relu, x_int32) # support the input dtype is float16 x_fp16 = paddle.fluid.data( name='x_fp16', shape=[10, 12], dtype='float16') - F.relu(x_fp16) + self.relu(x_fp16) + + +class TestReluInplaceAPI(TestReluAPI): + # test paddle.nn.functional.relu_ + def executed_api(self): + self.relu = F.relu_ def ref_leaky_relu(x, alpha=0.01): @@ -1609,12 +1629,16 @@ class TestELUAPI(unittest.TestCase): self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32') self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ else paddle.CPUPlace() + self.executed_api() + + def executed_api(self): + self.elu = F.elu def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('X', [10, 12]) - out1 = F.elu(x) + out1 = self.elu(x) m = paddle.nn.ELU() out2 = m(x) exe = paddle.static.Executor(self.place) @@ -1626,14 +1650,16 @@ class TestELUAPI(unittest.TestCase): def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x_np) - out1 = F.elu(x) + out1 = self.elu(x) + x = paddle.to_tensor(self.x_np) m = paddle.nn.ELU() out2 = m(x) out_ref = elu(self.x_np, 1.0) for r in [out1, out2]: self.assertEqual(np.allclose(out_ref, r.numpy()), True) - out1 = F.elu(x, 0.2) + out1 = self.elu(x, 0.2) + x = paddle.to_tensor(self.x_np) m = paddle.nn.ELU(0.2) out2 = m(x) out_ref = elu(self.x_np, 0.2) @@ -1645,15 +1671,21 @@ class TestELUAPI(unittest.TestCase): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, F.elu, 1) + self.assertRaises(TypeError, self.elu, 1) # The input dtype must be float16, float32, float64. x_int32 = paddle.fluid.data( name='x_int32', shape=[10, 12], dtype='int32') - self.assertRaises(TypeError, F.elu, x_int32) + self.assertRaises(TypeError, self.elu, x_int32) # support the input dtype is float16 x_fp16 = paddle.fluid.data( name='x_fp16', shape=[10, 12], dtype='float16') - F.elu(x_fp16) + self.elu(x_fp16) + + +class TestELUInplaceAPI(TestELUAPI): + # test paddle.nn.functional.elu_ + def executed_api(self): + self.elu = F.elu_ class TestReciprocal(TestActivation): diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index 08a7fe80ea1..2c6507c486e 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -95,5 +95,206 @@ class TestInplace(unittest.TestCase): loss.backward() +class TestDygraphInplace(unittest.TestCase): + def setUp(self): + self.init_data() + + def init_data(self): + self.input_var_numpy = np.random.rand(2, 3, 1) + self.dtype = "float32" + + def non_inplace_api_processing(self, var): + return paddle.squeeze(var) + + def inplace_api_processing(self, var): + return paddle.squeeze_(var) + + def test_inplace_api(self): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + inplace_var = self.inplace_api_processing(var) + self.assertTrue(id(var) == id(inplace_var)) + + inplace_var[0] = 2. 
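+        # var and inplace_var are the same VarBase, so the write above is
+        # visible through both names.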
+ self.assertTrue(np.array_equal(var.numpy(), inplace_var.numpy())) + + def test_forward_version(self): + with paddle.fluid.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 1) + + inplace_var[0] = 2. + self.assertEqual(var.inplace_version, 2) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 3) + + def test_leaf_inplace_var_error(self): + with paddle.fluid.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var.stop_gradient = False + + def leaf_inplace_error(): + self.inplace_api_processing(var) + + self.assertRaises(ValueError, leaf_inplace_error) + + def test_backward_error(self): + # It raises an error because the inplace operator will result + # in incorrect gradient computation. + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + + # Here, the gradient computation will use the value of var_b + var_c = var_b**2 + self.inplace_api_processing(var_b) + + loss = paddle.nn.functional.relu(var_c) + with self.assertRaisesRegexp( + RuntimeError, + "received tensor_version:{} != wrapper_version_snapshot:{}". + format(1, 0)): + loss.backward() + + def test_backward_success_1(self): + # var_b is modified inplace before using it, the inplace operator doesn't result + # in incorrect gradient computation. + grad_var_a, grad_var_a_inplace = 0, 1 + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + var_c = self.inplace_api_processing( + var_b) # var_b is modified inplace before using it + + # Here, the gradient computation will use the value of var_b + var_d = var_c**2 + loss = var_d.sum() + loss.backward() + grad_var_a_inplace = var_a.grad + + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + var_c = self.non_inplace_api_processing(var_b) + var_d = var_c**2 + loss = var_d.sum() + loss.backward() + grad_var_a = var_a.grad + + self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a)) + + def test_backward_success_2(self): + # Although var_b is modified inplace after using it, it does not used in gradient computation. + # The inplace operator doesn't result in incorrect gradient computation. 
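+        # The grad op of the add (var_c + var_c) does not need the forward
+        # value of its inputs, so the backward pass never reads var_b and the
+        # inplace modification cannot trip the tensor_version check; contrast
+        # with test_backward_error, where the grad of var_c = var_b**2 does
+        # need var_b's value.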
+ grad_var_a, grad_var_a_inplace = 0, 1 + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + + var_c = self.inplace_api_processing( + var_b) # var_b is modified inplace before using it + + var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b + loss = var_d.sum() + + loss.backward() + grad_var_a_inplace = var_a.grad + + with paddle.fluid.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + + var_c = self.non_inplace_api_processing( + var_b) # var_b is modified inplace before using it + + var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b + loss = var_d.sum() + + loss.backward() + grad_var_a = var_a.grad + self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a)) + + +class TestDygraphInplaceUnsqueeze(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.unsqueeze(var, -1) + + def inplace_api_processing(self, var): + return paddle.unsqueeze_(var, -1) + + +class TestDygraphInplaceReshape(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.reshape(var, [-1]) + + def inplace_api_processing(self, var): + return paddle.reshape_(var, [-1]) + + +class TestDygraphInplaceScatter(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]]) + self.dtype = "float32" + + def non_inplace_api_processing(self, var): + index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') + updates = paddle.to_tensor( + [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') + + return paddle.scatter(var, index, updates, overwrite=False) + + def inplace_api_processing(self, var): + index = paddle.to_tensor([2, 1, 0, 1], dtype='int64') + updates = paddle.to_tensor( + [[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32') + + return paddle.scatter_(var, index, updates, overwrite=False) + + +class TestDygraphInplaceElu(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.nn.functional.elu(var) + + def inplace_api_processing(self, var): + return paddle.nn.functional.elu_(var) + + +class TestDygraphInplaceRelu(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.nn.functional.relu(var) + + def inplace_api_processing(self, var): + return paddle.nn.functional.relu_(var) + + +class TestDygraphInplaceSoftmax(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.nn.functional.softmax(var) + + def inplace_api_processing(self, var): + return paddle.nn.functional.softmax_(var) + + +class TestDygraphInplaceTanh(TestDygraphInplace): + def non_inplace_api_processing(self, var): + return paddle.tanh(var) + + def inplace_api_processing(self, var): + return paddle.tanh_(var) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index d4a6ae4965e..4e296e7a889 100755 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -250,8 +250,11 @@ class TestReshapeAPI(unittest.TestCase): def _set_paddle_api(self): self.fill_constant = paddle.fluid.layers.fill_constant self.data = paddle.static.data - self.reshape = paddle.reshape self.to_tensor = paddle.to_tensor + self._executed_api() + + def _executed_api(self): + self.reshape = paddle.reshape 
def _set_fluid_api(self): self.fill_constant = fluid.layers.fill_constant @@ -322,6 +325,30 @@ class TestReshapeAPI(unittest.TestCase): assert np.array_equal(out_3.numpy(), input.reshape(shape)) +class TestStaticReshape_(TestReshapeAPI): + def _executed_api(self): + self.reshape = paddle.reshape_ + + def test_imperative(self): + self._set_paddle_api() + input = np.random.random([2, 25]).astype("float32") + shape = [2, 5, 5] + with fluid.dygraph.guard(): + x = self.to_tensor(input) + positive_five = self.fill_constant([1], "int32", 5) + + out_1 = self.reshape(x, shape) + + out_2 = self.reshape(x, shape=[positive_five, 10]) + + shape_tensor = self.to_tensor(np.array([2, 5, 5]).astype("int32")) + out_3 = self.reshape(x, shape=shape_tensor) + + assert np.array_equal(out_1.numpy(), input.reshape(shape)) + assert np.array_equal(out_2.numpy(), input.reshape(shape)) + assert np.array_equal(out_3.numpy(), input.reshape(shape)) + + # Test Input Error class TestReshapeOpError(unittest.TestCase): def _set_paddle_api(self): @@ -397,12 +424,18 @@ class TestReshapeOpError(unittest.TestCase): self._test_errors() -class API_TestDygraphReshape(unittest.TestCase): +class TestDygraphReshapeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.reshape = paddle.reshape + def test_out(self): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.reshape(x=input, shape=[5, 10]) + output = self.reshape(x=input, shape=[5, 10]) out_np = output.numpy() expected_out = np.reshape(input_1, newshape=[5, 10]) self.assertTrue(np.allclose(expected_out, out_np)) @@ -411,7 +444,7 @@ class API_TestDygraphReshape(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("uint8") input = paddle.to_tensor(input_1) - output = paddle.reshape(x=input, shape=[5, 10]) + output = self.reshape(x=input, shape=[5, 10]) out_np = output.numpy() expected_out = np.reshape(input_1, newshape=[5, 10]) self.assertTrue(np.allclose(expected_out, out_np)) @@ -420,11 +453,16 @@ class API_TestDygraphReshape(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("float32") input = paddle.to_tensor(input_1) - output = paddle.reshape(x=input, shape=[5, 10]) + output = self.reshape(x=input, shape=[5, 10]) out_np = output.numpy() expected_out = np.reshape(input_1, newshape=[5, 10]) self.assertTrue(np.allclose(expected_out, out_np)) +class TestDygraphReshapeInplaceAPI(TestDygraphReshapeAPI): + def executed_api(self): + self.reshape = paddle.reshape_ + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py index ce3b060828a..e2f012e9a63 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py @@ -180,13 +180,17 @@ class TestScatterAPI(unittest.TestCase): self.places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): self.places.append(fluid.CUDAPlace(0)) + self.executed_api() + + def executed_api(self): + self.scatter = paddle.scatter def check_static_result(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): input = fluid.data(name="input", shape=[3, 2], dtype="float64") index = fluid.data(name="index", shape=[4], dtype="int64") updates = fluid.data(name="updates", shape=[4, 2], dtype="float64") - result = paddle.scatter(input, index, updates, False) + result = 
self.scatter(input, index, updates, False) input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64) index_data = np.array([2, 1, 0, 1]).astype(np.int64) @@ -220,10 +224,15 @@ class TestScatterAPI(unittest.TestCase): index = fluid.dygraph.to_variable(index_data) updates = fluid.dygraph.to_variable(updates_data) - output1 = paddle.scatter(x, index, updates, overwrite=False) + output1 = self.scatter(x, index, updates, overwrite=False) self.assertEqual((output1.numpy() == \ np.array([[3., 3.],[6., 6.],[1., 1.]])).all(), True) +class TestScatterInplaceAPI(TestScatterAPI): + def executed_api(self): + self.scatter = paddle.scatter_ + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 71c4e9c495e..9b0de4e59b4 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -301,11 +301,15 @@ class TestSoftmaxAPI(unittest.TestCase): ) else paddle.CPUPlace() self.x_np = np.random.uniform(-1., 1., [2, 3, 4, 5]).astype('float32') self.out_ref = np.apply_along_axis(stable_softmax, -1, self.x_np) + self.executed_api() + + def executed_api(self): + self.softmax = F.softmax def test_static_check(self): with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('X', self.x_np.shape, 'float32') - out1 = F.softmax(x) + out1 = self.softmax(x) m = paddle.nn.Softmax() out2 = m(x) exe = paddle.static.Executor(self.place) @@ -318,21 +322,23 @@ class TestSoftmaxAPI(unittest.TestCase): paddle.disable_static(self.place) x = paddle.to_tensor(self.x_np) - out1 = F.softmax(x) + out1 = self.softmax(x) + x = paddle.to_tensor(self.x_np) m = paddle.nn.Softmax() out2 = m(x) out_ref = ref_softmax(self.x_np, axis=-1, dtype=None) for r in [out1, out2]: self.assertEqual(np.allclose(out_ref, r.numpy()), True) - out1 = F.softmax(x, axis=0) + out1 = self.softmax(x, axis=0) + x = paddle.to_tensor(self.x_np) m = paddle.nn.Softmax(axis=0) out2 = m(x) out_ref = ref_softmax(self.x_np, axis=0, dtype=None) for r in [out1, out2]: self.assertEqual(np.allclose(out_ref, r.numpy()), True) - out = F.softmax(x, dtype=np.float64) + out = self.softmax(x, dtype=np.float64) out_ref = ref_softmax(self.x_np, axis=-1, dtype=np.float64) self.assertEqual(np.allclose(out_ref, out.numpy()), True) @@ -341,15 +347,20 @@ class TestSoftmaxAPI(unittest.TestCase): def test_error(self): with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, F.softmax, 1) + self.assertRaises(TypeError, self.softmax, 1) # The input dtype must be float16, float32, float64. 
x_int32 = paddle.fluid.data( name='x_int32', shape=[2, 3], dtype='int32') - self.assertRaises(TypeError, F.softmax, x_int32) + self.assertRaises(TypeError, self.softmax, x_int32) # support the input dtype is float16 x_fp16 = paddle.fluid.data( name='x_fp16', shape=[2, 3], dtype='float16') - F.softmax(x_fp16) + self.softmax(x_fp16) + + +class TestSoftmaxInplaceAPI(TestSoftmaxAPI): + def executed_api(self): + self.softmax = F.softmax_ if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py index 3a26f967e9b..a048293c8da 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -98,13 +98,19 @@ class TestSqueezeOpError(unittest.TestCase): class API_TestSqueeze(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.squeeze = paddle.squeeze + def test_out(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): data1 = paddle.static.data( 'data1', shape=[-1, 1, 10], dtype='float64') - result_squeeze = paddle.squeeze(data1, axis=[1]) + result_squeeze = self.squeeze(data1, axis=[1]) place = paddle.CPUPlace() exe = paddle.static.Executor(place) input1 = np.random.random([5, 1, 10]).astype('float64') @@ -114,12 +120,23 @@ class API_TestSqueeze(unittest.TestCase): self.assertTrue(np.allclose(expected_result, result)) +class API_TestStaticSqueeze_(API_TestSqueeze): + def executed_api(self): + self.squeeze = paddle.squeeze_ + + class API_TestDygraphSqueeze(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.squeeze = paddle.squeeze + def test_out(self): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.squeeze(input, axis=[1]) + output = self.squeeze(input, axis=[1]) out_np = output.numpy() expected_out = np.squeeze(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -128,7 +145,7 @@ class API_TestDygraphSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int8") input = paddle.to_tensor(input_1) - output = paddle.squeeze(input, axis=[1]) + output = self.squeeze(input, axis=[1]) out_np = output.numpy() expected_out = np.squeeze(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -137,7 +154,7 @@ class API_TestDygraphSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("uint8") input = paddle.to_tensor(input_1) - output = paddle.squeeze(input, axis=[1]) + output = self.squeeze(input, axis=[1]) out_np = output.numpy() expected_out = np.squeeze(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -146,7 +163,7 @@ class API_TestDygraphSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.squeeze(input, axis=1) + output = self.squeeze(input, axis=1) out_np = output.numpy() expected_out = np.squeeze(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -155,11 +172,16 @@ class API_TestDygraphSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.squeeze(input, axis=(1, 2)) + output = self.squeeze(input, axis=(1, 0)) out_np = output.numpy() expected_out = 
np.squeeze(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) +class API_TestDygraphSqueezeInplace(API_TestDygraphSqueeze): + def executed_api(self): + self.squeeze = paddle.squeeze_ + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py index 7a57f8a3825..b75e32f2bad 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze2_op.py @@ -208,6 +208,12 @@ class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor): # test api class TestUnsqueezeAPI(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.unsqueeze = paddle.unsqueeze + def test_api(self): input = np.random.random([3, 2, 5]).astype("float64") x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64") @@ -218,12 +224,11 @@ class TestUnsqueezeAPI(unittest.TestCase): axes_tensor_int64 = paddle.static.data( name='axes_tensor_int64', shape=[3], dtype="int64") - out_1 = paddle.unsqueeze(x, axis=[3, 1, 1]) - out_2 = paddle.unsqueeze( - x, axis=[positive_3_int32, positive_1_int64, 1]) - out_3 = paddle.unsqueeze(x, axis=axes_tensor_int32) - out_4 = paddle.unsqueeze(x, axis=3) - out_5 = paddle.unsqueeze(x, axis=axes_tensor_int64) + out_1 = self.unsqueeze(x, axis=[3, 1, 1]) + out_2 = self.unsqueeze(x, axis=[positive_3_int32, positive_1_int64, 1]) + out_3 = self.unsqueeze(x, axis=axes_tensor_int32) + out_4 = self.unsqueeze(x, axis=3) + out_5 = self.unsqueeze(x, axis=axes_tensor_int64) exe = paddle.static.Executor(place=paddle.CPUPlace()) res_1, res_2, res_3, res_4, res_5 = exe.run( @@ -244,10 +249,15 @@ class TestUnsqueezeAPI(unittest.TestCase): def test_error(self): def test_axes_type(): x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32") - paddle.unsqueeze(x2, axis=2.1) + self.unsqueeze(x2, axis=2.1) self.assertRaises(TypeError, test_axes_type) +class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI): + def executed_api(self): + self.unsqueeze = paddle.unsqueeze_ + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 98cb5cdb550..9c705837334 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -203,11 +203,17 @@ class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase): class API_TestDygraphUnSqueeze(unittest.TestCase): + def setUp(self): + self.executed_api() + + def executed_api(self): + self.unsqueeze = paddle.unsqueeze + def test_out(self): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.unsqueeze(input, axis=[1]) + output = self.unsqueeze(input, axis=[1]) out_np = output.numpy() expected_out = np.expand_dims(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -216,7 +222,7 @@ class API_TestDygraphUnSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int8") input = paddle.to_tensor(input_1) - output = paddle.unsqueeze(input, axis=[1]) + output = self.unsqueeze(input, axis=[1]) out_np = output.numpy() expected_out = np.expand_dims(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -225,7 +231,7 @@ class API_TestDygraphUnSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 
1, 10]).astype("uint8") input = paddle.to_tensor(input_1) - output = paddle.unsqueeze(input, axis=1) + output = self.unsqueeze(input, axis=1) out_np = output.numpy() expected_out = np.expand_dims(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -234,7 +240,7 @@ class API_TestDygraphUnSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.unsqueeze(input, axis=1) + output = self.unsqueeze(input, axis=1) out_np = output.numpy() expected_out = np.expand_dims(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) @@ -243,11 +249,16 @@ class API_TestDygraphUnSqueeze(unittest.TestCase): paddle.disable_static() input_1 = np.random.random([5, 1, 10]).astype("int32") input = paddle.to_tensor(input_1) - output = paddle.unsqueeze(input, axis=(1, 2)) + output = self.unsqueeze(input, axis=(1, 2)) out_np = output.numpy() expected_out = np.expand_dims(input_1, axis=1) self.assertTrue(np.allclose(expected_out, out_np)) +class API_TestDygraphUnSqueezeInplace(API_TestDygraphUnSqueeze): + def executed_api(self): + self.unsqueeze = paddle.unsqueeze_ + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 501d9fcfd40..36f39a5056e 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -30,6 +30,7 @@ __all__ += pooling.__all__ from . import loss __all__ += loss.__all__ from .activation import elu #DEFINE_ALIAS +from .activation import elu_ #DEFINE_ALIAS # from .activation import erf #DEFINE_ALIAS from .activation import gelu #DEFINE_ALIAS from .activation import hardshrink #DEFINE_ALIAS @@ -41,16 +42,19 @@ from .activation import log_sigmoid #DEFINE_ALIAS from .activation import maxout #DEFINE_ALIAS from .activation import prelu #DEFINE_ALIAS from .activation import relu #DEFINE_ALIAS +from .activation import relu_ #DEFINE_ALIAS from .activation import relu6 #DEFINE_ALIAS from .activation import selu #DEFINE_ALIAS from .activation import sigmoid #DEFINE_ALIAS # from .activation import soft_relu #DEFINE_ALIAS from .activation import softmax #DEFINE_ALIAS +from .activation import softmax_ #DEFINE_ALIAS from .activation import softplus #DEFINE_ALIAS from .activation import softshrink #DEFINE_ALIAS from .activation import softsign #DEFINE_ALIAS from .activation import swish #DEFINE_ALIAS from .activation import tanh #DEFINE_ALIAS +from .activation import tanh_ #DEFINE_ALIAS from .activation import tanhshrink #DEFINE_ALIAS from .activation import thresholded_relu #DEFINE_ALIAS from .activation import log_softmax #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 34f44fb2390..3553a93dfab 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -20,10 +20,14 @@ from ...fluid.layers import maxout #DEFINE_ALIAS from ...fluid.layers import swish #DEFINE_ALIAS from ...fluid.layers import sigmoid #DEFINE_ALIAS from ...tensor.math import tanh #DEFINE_ALIAS +from ...tensor.math import tanh_ #DEFINE_ALIAS + +from ...tensor.manipulation import _print_warning_in_static_mode __all__ = [ 'brelu', 'elu', + 'elu_', 'gelu', 'hardshrink', 'hardtanh', @@ -34,15 +38,18 @@ __all__ = [ 'maxout', 'prelu', 'relu', + 'relu_', 'relu6', 'selu', 'softmax', + 'softmax_', 'softplus', 'softshrink', 'softsign', 'sigmoid', 'swish', 'tanh', + 'tanh_', 
'tanhshrink', 'thresholded_relu', 'log_softmax', @@ -99,6 +106,19 @@ def elu(x, alpha=1.0, name=None): return out +def elu_(x, alpha=1.0, name=None): + r""" + Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_nn_cn_elu`. + """ + + if in_dygraph_mode(): + return core.ops.elu_(x, 'alpha', alpha) + + _print_warning_in_static_mode("elu") + return elu(x, alpha, name) + + def gelu(x, approximate=False, name=None): r""" gelu activation. @@ -514,6 +534,19 @@ def relu(x, name=None): return out +def relu_(x, name=None): + """ + Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_nn_cn_relu`. + """ + + if in_dygraph_mode(): + return core.ops.relu_(x) + + _print_warning_in_static_mode("relu") + return relu(x, name) + + def log_sigmoid(x, name=None): r""" log_sigmoid activation. @@ -879,6 +912,23 @@ def softmax(x, axis=-1, dtype=None, name=None): return outs_softmax +def softmax_(x, axis=-1, dtype=None, name=None): + r""" + Inplace version of ``softmax`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_nn_cn_softmax`. + """ + + if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): + dtype = convert_np_dtype_to_dtype_(dtype) + use_cudnn = True + + if in_dygraph_mode(): + return core.ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn) + + _print_warning_in_static_mode("softmax") + return softmax(x, axis, dtype, name) + + def softplus(x, beta=1, threshold=20, name=None): r""" softplus activation diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 957042e263e..0a75f6fd7ba 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -82,19 +82,23 @@ from .manipulation import flatten #DEFINE_ALIAS from .manipulation import gather #DEFINE_ALIAS from .manipulation import gather_nd #DEFINE_ALIAS from .manipulation import reshape #DEFINE_ALIAS +from .manipulation import reshape_ #DEFINE_ALIAS from .manipulation import flip as reverse #DEFINE_ALIAS from .manipulation import scatter #DEFINE_ALIAS +from .manipulation import scatter_ #DEFINE_ALIAS from .manipulation import scatter_nd_add #DEFINE_ALIAS from .manipulation import scatter_nd #DEFINE_ALIAS from .manipulation import shard_index #DEFINE_ALIAS from .manipulation import slice #DEFINE_ALIAS from .manipulation import split #DEFINE_ALIAS from .manipulation import squeeze #DEFINE_ALIAS +from .manipulation import squeeze_ #DEFINE_ALIAS from .manipulation import stack #DEFINE_ALIAS from .manipulation import strided_slice #DEFINE_ALIAS from .manipulation import transpose #DEFINE_ALIAS from .manipulation import unique #DEFINE_ALIAS from .manipulation import unsqueeze #DEFINE_ALIAS +from .manipulation import unsqueeze_ #DEFINE_ALIAS from .manipulation import unstack #DEFINE_ALIAS from .manipulation import flip #DEFINE_ALIAS from .manipulation import unbind #DEFINE_ALIAS @@ -138,6 +142,7 @@ from .math import square #DEFINE_ALIAS from .math import stanh #DEFINE_ALIAS from .math import sum #DEFINE_ALIAS from .math import tanh #DEFINE_ALIAS +from .math import tanh_ #DEFINE_ALIAS from .math import add_n #DEFINE_ALIAS from .math import max #DEFINE_ALIAS from .math import maximum #DEFINE_ALIAS diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index adb3f5a3c5f..2583c4b95d9 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -32,6 +32,7 @@ from 
..fluid.layers import scatter_nd #DEFINE_ALIAS from ..fluid.layers import shard_index #DEFINE_ALIAS from ..fluid import layers import paddle +import warnings __all__ = [ 'cast', @@ -43,8 +44,10 @@ __all__ = [ 'gather', 'gather_nd', 'reshape', + 'reshape_', 'reverse', 'scatter', + 'scatter_', 'scatter_nd_add', 'scatter_nd', 'shard_index', @@ -52,11 +55,13 @@ __all__ = [ 'split', 'chunk', 'squeeze', + 'squeeze_', 'stack', 'strided_slice', 'transpose', 'unique', 'unsqueeze', + 'unsqueeze_', 'unstack', 'flip', 'unbind', @@ -65,6 +70,12 @@ __all__ = [ ] +def _print_warning_in_static_mode(api_name): + warnings.warn( + "In static mode, {}_() is the same as {}() and does not perform inplace operation.". + format(api_name, api_name)) + + def concat(x, axis=0, name=None): """ @@ -567,6 +578,26 @@ def squeeze(x, axis=None, name=None): return layers.squeeze(x, axis, name) +def squeeze_(x, axis=None, name=None): + """ + Inplace version of ``squeeze`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_tensor_squeeze`. + """ + if axis is None: + axis = [] + elif isinstance(axis, int): + axis = [axis] + elif isinstance(axis, tuple): + axis = list(axis) + + if in_dygraph_mode(): + out, _ = core.ops.squeeze2_(x, 'axes', axis) + return out + + _print_warning_in_static_mode("squeeze") + return squeeze(x, axis, name) + + def unique(x, return_index=False, return_inverse=False, @@ -740,6 +771,28 @@ def unsqueeze(x, axis, name=None): return layers.unsqueeze(x, axis, name) +def unsqueeze_(x, axis, name=None): + """ + Inplace version of ``unsqueeze`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_tensor_unsqueeze`. + """ + if in_dygraph_mode(): + if isinstance(axis, int): + axis = [axis] + elif isinstance(axis, Variable): + axis = axis.numpy().tolist() + elif isinstance(axis, (list, tuple)): + axis = [ + item.numpy().item(0) if isinstance(item, Variable) else item + for item in axis + ] + out, _ = core.ops.unsqueeze2_(x, 'axes', axis) + return out + + _print_warning_in_static_mode("unsqueeze") + return unsqueeze(x, axis, name) + + def gather(x, index, axis=None, name=None): """ Output is obtained by gathering entries of ``axis`` @@ -966,6 +1019,18 @@ def scatter(x, index, updates, overwrite=True, name=None): return out +def scatter_(x, index, updates, overwrite=True, name=None): + """ + Inplace version of ``scatter`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_tensor_scatter`. + """ + if in_dygraph_mode(): + return core.ops.scatter_(x, index, updates, 'overwrite', overwrite) + + _print_warning_in_static_mode("scatter") + return scatter(x, index, updates, overwrite, name) + + def scatter_nd_add(x, index, updates, name=None): r""" **Scatter_nd_add Layer** @@ -1485,6 +1550,28 @@ def reshape(x, shape, name=None): return paddle.fluid.layers.reshape(x=x, shape=shape, name=name) +def reshape_(x, shape, name=None): + """ + Inplace version of ``reshape`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_tensor_reshape`. 
+ """ + if in_dygraph_mode(): + if isinstance(shape, (list, tuple)): + shape = [ + item.numpy().item(0) if isinstance(item, Variable) else item + for item in shape + ] + out, _ = core.ops.reshape2_(x, None, 'shape', shape) + return out + elif isinstance(shape, Variable): + shape.stop_gradient = True + out, _ = core.ops.reshape2_(x, shape) + return out + + _print_warning_in_static_mode("reshape") + return reshape(x, shape, name) + + def gather_nd(x, index, name=None): """ diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index fc99eabc7da..87efa9ac442 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -25,6 +25,7 @@ from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable, from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn +from .manipulation import _print_warning_in_static_mode # TODO: define math functions # yapf: disable @@ -99,6 +100,7 @@ __all__ = [ 'stanh', 'sum', 'tanh', + 'tanh_', 'add_n', 'max', 'maximum', @@ -1969,6 +1971,17 @@ def tanh(x, name=None): helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out}) return out +def tanh_(x, name=None): + r""" + Inplace version of ``tanh`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_tensor_tanh`. + """ + if in_dygraph_mode(): + return core.ops.tanh_(x) + + _print_warning_in_static_mode("tanh") + return tanh(x, name) + def increment(x, value=1.0, name=None): """ The OP is usually used for control flow to increment the data of :attr:`x` by an amount :attr:`value`. diff --git a/tools/wlist.json b/tools/wlist.json index f907d609898..e8ec83b49db 100644 --- a/tools/wlist.json +++ b/tools/wlist.json @@ -21,6 +21,38 @@ { "name":"xxxxx", "annotation":"not a real api, just for example" + }, + { + "name":"squeeze_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"unsqueeze_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"reshape_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"scatter_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"elu_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"relu_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"softmax_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" + }, + { + "name":"tanh_", + "annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy" } ], "wlist_temp_api":[ -- GitLab